package shujubao

import (
	"context"
	"crypto/md5"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"sort"
	"strconv"
	"strings"
	"time"
	"unicode/utf8"

	"tyapi-server/internal/shared/external_logger"
)

const (
	// maxLogParamValueLen caps the length (in bytes) of a single parameter value
	// written to error logs, so base64 blobs and similar payloads do not flood them.
	maxLogParamValueLen = 300
)

var (
	// ErrDatasource marks failures attributed to the upstream data source.
	ErrDatasource = errors.New("数据源异常")
	// ErrSystem marks local failures (request building, transport, decoding).
	ErrSystem = errors.New("系统异常")
	// ErrQueryEmpty marks an empty query result.
	ErrQueryEmpty = errors.New("查询为空")
)

// truncateForLog shortens s to at most maxLen bytes for error logs and appends a
// marker with the original length in bytes. The cut point is moved back to a
// UTF-8 rune boundary so multi-byte characters (e.g. Chinese text) are never
// split into invalid bytes. A non-positive maxLen disables truncation.
func truncateForLog(s string, maxLen int) string {
	if maxLen <= 0 || len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// A UTF-8 rune is at most utf8.UTFMax bytes, so at most UTFMax-1 continuation
	// bytes need to be skipped to reach the start of the rune containing s[cut].
	for back := 0; back < utf8.UTFMax-1 && cut > 0 && !utf8.RuneStart(s[cut]); back++ {
		cut--
	}
	return s[:cut] + "...[truncated, total " + strconv.Itoa(len(s)) + " chars]"
}

// stringifyParamValue renders a request parameter value as text. Strings pass
// through unchanged, nil becomes "", and any other value is formatted with
// fmt.Sprint. It is shared by signing, form encoding and log sanitising so the
// three always agree on a value's textual form.
func stringifyParamValue(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case nil:
		return ""
	default:
		return fmt.Sprint(val)
	}
}

// paramsForLog returns a copy of params suitable for error logs. Every non-nil
// value is converted to a string and truncated by truncateForLog; nil values
// stay nil. A nil map yields nil.
func paramsForLog(params map[string]interface{}) map[string]interface{} {
	if params == nil {
		return nil
	}
	out := make(map[string]interface{}, len(params))
	for k, v := range params {
		if v == nil {
			out[k] = nil
			continue
		}
		out[k] = truncateForLog(stringifyParamValue(v), maxLogParamValueLen)
	}
	return out
}

// ShujubaoResp is the generic Shujubao API response envelope (adjust the fields
// to match the actual API documentation).
type ShujubaoResp struct {
	Code    string      `json:"code"`
	Message string      `json:"message"`
	Data    interface{} `json:"data"`
	Success bool        `json:"success"`
}

// ShujubaoConfig holds the Shujubao service configuration.
type ShujubaoConfig struct {
	URL        string
	AppSecret  string
	SignMethod SignMethod
	Timeout    time.Duration
}

// ShujubaoService is a client for the Shujubao data API.
type ShujubaoService struct {
	config ShujubaoConfig
	logger *external_logger.ExternalServiceLogger
}

// NewShujubaoService creates a Shujubao service instance.
//
// baseURL is the API base address that endpoint paths are appended to.
// appSecret is the signing key. An empty signMethod falls back to
// SignMethodHMACMD5, and a non-positive timeout falls back to 60 seconds.
// logger may be nil, in which case nothing is logged.
func NewShujubaoService(baseURL, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *ShujubaoService {
	if signMethod == "" {
		signMethod = SignMethodHMACMD5
	}
	// A negative timeout would give the HTTP client an already-expired deadline,
	// so treat it like zero and use the default.
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	return &ShujubaoService{
		config: ShujubaoConfig{
			URL:        baseURL,
			AppSecret:  appSecret,
			SignMethod: signMethod,
			Timeout:    timeout,
		},
		logger: logger,
	}
}

// generateRequestID returns an ID used to correlate the log entries of a call:
// "shujubao_" followed by the hex encoding of the first 8 bytes of
// MD5("<unix-nanoseconds>_<app secret>").
// NOTE(review): the app secret is hashed into a value that is written to logs,
// and two calls in the same nanosecond get the same ID; consider a random ID.
func (s *ShujubaoService) generateRequestID() string {
	timestamp := time.Now().UnixNano()
	hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, s.config.AppSecret)))
	return fmt.Sprintf("shujubao_%x", hash[:8])
}

// buildSortedParamStr joins params as key1=value1&key2=value2&... with keys in
// ascending byte (ASCII) order. Neither keys nor values are escaped, because this
// string is the signature base, not a wire payload. Empty params yield "".
func buildSortedParamStr(params map[string]interface{}) string {
	if len(params) == 0 {
		return ""
	}
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var b strings.Builder
	for i, k := range keys {
		if i > 0 {
			b.WriteByte('&')
		}
		b.WriteString(k)
		b.WriteByte('=')
		b.WriteString(stringifyParamValue(params[k]))
	}
	return b.String()
}

// buildFormUrlEncodedBody builds an application/x-www-form-urlencoded body from
// params. Keys are sorted in ascending byte order, and keys and values are
// escaped with url.QueryEscape. Empty params yield "".
func buildFormUrlEncodedBody(params map[string]interface{}) string {
	if len(params) == 0 {
		return ""
	}
	form := make(url.Values, len(params))
	for k, v := range params {
		form.Set(k, stringifyParamValue(v))
	}
	// url.Values.Encode sorts by key and QueryEscapes both keys and values.
	return form.Encode()
}

// generateSign computes the request signature. params plus the given timestamp
// (under the key "timestamp") are sorted by key and joined by
// buildSortedParamStr. The result is then signed with the app secret using MD5
// when SignMethod is SignMethodMD5, and HMAC otherwise. Endpoints with
// encryption enabled expect the sign and timestamp; plaintext endpoints do not
// need them. params itself is not modified.
func (s *ShujubaoService) generateSign(timestamp string, params map[string]interface{}) string {
	merged := make(map[string]interface{}, len(params)+1)
	for k, v := range params {
		merged[k] = v
	}
	merged["timestamp"] = timestamp
	sortedStr := buildSortedParamStr(merged)

	switch s.config.SignMethod {
	case SignMethodMD5:
		return GenerateSignFromParamsMD5(s.config.AppSecret, sortedStr)
	default:
		return GenerateSignFromParamsHMAC(s.config.AppSecret, sortedStr)
	}
}
URL,如 https://api.chinadatapay.com/communication/personal/197 func (s *ShujubaoService) buildRequestURL(apiPath string) string { base := strings.TrimSuffix(s.config.URL, "/") if apiPath == "" { return base } return base + "/" + strings.TrimPrefix(apiPath, "/") } // CallAPI 调用数据宝 API(POST)。最终请求地址 = url + 拼接接口地址值;body 为业务参数;sign、timestamp 按原样传 header。 func (s *ShujubaoService) CallAPI(ctx context.Context, apiPath string, params map[string]interface{}) (data interface{}, err error) { startTime := time.Now() requestID := s.generateRequestID() timestamp := strconv.FormatInt(time.Now().Unix(), 10) // 最终请求 URL = https://api.chinadatapay.com/communication + 拼接接口地址值,如 /personal/197 requestURL := s.buildRequestURL(apiPath) var transactionID string if id, ok := ctx.Value("transaction_id").(string); ok { transactionID = id } if s.logger != nil { s.logger.LogRequest(requestID, transactionID, apiPath, requestURL) } // 使用 application/x-www-form-urlencoded,贵司接口暂不支持 JSON 入参 formBody := buildFormUrlEncodedBody(params) req, err := http.NewRequestWithContext(ctx, "POST", requestURL, strings.NewReader(formBody)) if err != nil { err = errors.Join(ErrSystem, err) if s.logger != nil { s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params)) } return nil, err } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("timestamp", timestamp) req.Header.Set("sign", s.generateSign(timestamp, params)) client := &http.Client{Timeout: s.config.Timeout} response, err := client.Do(req) if err != nil { isTimeout := false if ctx.Err() == context.DeadlineExceeded { isTimeout = true } else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() { isTimeout = true } else if errStr := err.Error(); errStr == "context deadline exceeded" || errStr == "timeout" || errStr == "Client.Timeout exceeded" || errStr == "net/http: request canceled" { isTimeout = true } if isTimeout { err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err)) 
} else { err = errors.Join(ErrSystem, err) } if s.logger != nil { s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params)) } return nil, err } defer response.Body.Close() respBody, err := io.ReadAll(response.Body) if err != nil { err = errors.Join(ErrSystem, err) if s.logger != nil { s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params)) } return nil, err } if s.logger != nil { duration := time.Since(startTime) s.logger.LogResponse(requestID, transactionID, apiPath, response.StatusCode, duration) } if response.StatusCode != http.StatusOK { err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode)) if s.logger != nil { s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params)) } return nil, err } var shujubaoResp ShujubaoResp if err := json.Unmarshal(respBody, &shujubaoResp); err != nil { err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err)) if s.logger != nil { s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params)) } return nil, err } code := shujubaoResp.Code if code == "10001" || code == "10006" { // 查空/查无:返回空数组,不视为错误 return []interface{}{}, nil } if code != "10000" { shujubaoErr := NewShujubaoErrorFromCode(code, shujubaoResp.Message) if s.logger != nil { s.logger.LogError(requestID, transactionID, apiPath, shujubaoErr, paramsForLog(params)) } return nil, errors.Join(ErrDatasource, shujubaoErr) } return shujubaoResp.Data, nil }