Files
tyapi-server/internal/infrastructure/external/shujubao/shujubao_service.go
2026-03-05 11:05:01 +08:00

309 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package shujubao
import (
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"tyapi-server/internal/shared/external_logger"
)
// maxLogParamValueLen caps how many bytes of any single request-parameter
// value are written to error logs, so long values (e.g. base64 payloads)
// cannot flood the log.
const maxLogParamValueLen = 300
// Sentinel errors of this package. CallAPI joins ErrDatasource or ErrSystem
// with the concrete cause via errors.Join, so callers should match them with
// errors.Is rather than ==.
var (
	// ErrDatasource marks failures attributed to the upstream data source
	// (timeouts, non-200 HTTP status, non-success business codes).
	ErrDatasource = errors.New("数据源异常")

	// ErrSystem marks local failures (building the request, other transport
	// errors, reading or decoding the response).
	ErrSystem = errors.New("系统异常")

	// ErrQueryEmpty signals an empty query result. CallAPI in this file does
	// not return it (empty results come back as an empty slice); presumably
	// used by callers elsewhere — verify before removing.
	ErrQueryEmpty = errors.New("查询为空")
)
// truncateForLog shortens s for error logging so that very long values
// (e.g. base64 content) do not flood the log.
//
// If maxLen <= 0 or len(s) <= maxLen, s is returned unchanged. Otherwise at
// most maxLen bytes of s are kept, followed by a marker that reports the
// original length in bytes. The cut point is moved back (by at most 3 bytes,
// the longest UTF-8 continuation run) so the kept prefix never ends in the
// middle of a multi-byte UTF-8 character; this matters for the Chinese text
// this service handles, which would otherwise produce invalid UTF-8 in logs.
func truncateForLog(s string, maxLen int) string {
	if maxLen <= 0 || len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// s[cut] is in range because cut <= maxLen < len(s). A byte of the form
	// 10xxxxxx is a UTF-8 continuation byte, i.e. not the start of a rune.
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}
	return s[:cut] + "...[truncated, total " + strconv.Itoa(len(s)) + " bytes]"
}
// paramsForLog returns a copy of params that is safe to attach to error logs:
// string values are truncated with truncateForLog, other non-nil values are
// first formatted with fmt.Sprint and then truncated, and nil values stay nil.
// A nil map yields nil. The input map is not modified.
func paramsForLog(params map[string]interface{}) map[string]interface{} {
	if params == nil {
		return nil
	}
	sanitized := make(map[string]interface{}, len(params))
	for key, raw := range params {
		switch typed := raw.(type) {
		case nil:
			sanitized[key] = nil
		case string:
			sanitized[key] = truncateForLog(typed, maxLogParamValueLen)
		default:
			sanitized[key] = truncateForLog(fmt.Sprint(typed), maxLogParamValueLen)
		}
	}
	return sanitized
}
// ShujubaoResp is the generic envelope of a Shujubao API response.
// CallAPI treats Code "10000" as success, "10001"/"10006" as an empty result
// and any other code as an error. The field set should be kept in sync with
// the vendor's API documentation.
type ShujubaoResp struct {
	Code    string `json:"code"`    // business status code
	Message string `json:"message"` // status message accompanying Code
	Data    any    `json:"data"`    // business payload, decoded generically
	Success bool   `json:"success"` // success flag as reported by the API
}
// ShujubaoConfig holds the settings a ShujubaoService uses for every call.
type ShujubaoConfig struct {
	// URL is the API base address; buildRequestURL appends the per-endpoint
	// path to it (a trailing "/" is ignored).
	URL string
	// AppSecret is the secret used by generateSign (it is also mixed into the
	// hash that generateRequestID derives request IDs from).
	AppSecret string
	// SignMethod selects the signing algorithm: generateSign uses MD5 for
	// SignMethodMD5 and HMAC for any other value.
	SignMethod SignMethod
	// Timeout is used as http.Client.Timeout for each CallAPI request.
	Timeout time.Duration
}
// ShujubaoService is a client for the Shujubao data API. Construct it with
// NewShujubaoService.
type ShujubaoService struct {
	// config holds the base URL, signing secret/method and request timeout.
	config ShujubaoConfig
	// logger receives request, response and error events. It may be nil, in
	// which case CallAPI skips logging.
	logger *external_logger.ExternalServiceLogger
}
// NewShujubaoService creates a Shujubao API client.
//
// Parameters:
//   - url: API base address; CallAPI appends the endpoint path to it.
//   - appSecret: secret used to sign requests.
//   - signMethod: signing algorithm; SignMethodHMACMD5 is used when empty.
//   - timeout: per-request HTTP timeout; 60s is used when zero or negative.
//   - logger: external-call logger; may be nil to disable logging.
func NewShujubaoService(url, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *ShujubaoService {
	if signMethod == "" {
		signMethod = SignMethodHMACMD5
	}
	// http.Client treats any non-positive Timeout as "no timeout at all", so a
	// negative value would let requests hang forever. Treat it as unset.
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	return &ShujubaoService{
		config: ShujubaoConfig{
			URL:        url,
			AppSecret:  appSecret,
			SignMethod: signMethod,
			Timeout:    timeout,
		},
		logger: logger,
	}
}
// generateRequestID returns an ID used to correlate log entries of one call.
// It has the form "shujubao_" followed by 16 hex digits: the first 8 bytes of
// an MD5 over the current Unix-nanosecond timestamp and the app secret.
func (s *ShujubaoService) generateRequestID() string {
	seed := fmt.Sprintf("%d_%s", time.Now().UnixNano(), s.config.AppSecret)
	digest := md5.Sum([]byte(seed))
	return fmt.Sprintf("shujubao_%x", digest[:8])
}
// buildSortedParamStr 将入参按 key 的 ASCII 排序组合为 key1=value1&key2=value2&...
func buildSortedParamStr(params map[string]interface{}) string {
if len(params) == 0 {
return ""
}
keys := make([]string, 0, len(params))
for k := range params {
keys = append(keys, k)
}
sort.Strings(keys)
var b strings.Builder
for i, k := range keys {
if i > 0 {
b.WriteByte('&')
}
v := params[k]
var vs string
switch val := v.(type) {
case string:
vs = val
case nil:
vs = ""
default:
vs = fmt.Sprint(val)
}
b.WriteString(k)
b.WriteByte('=')
b.WriteString(vs)
}
return b.String()
}
// buildFormUrlEncodedBody renders params as an application/x-www-form-urlencoded
// body. Keys are sorted in ascending byte (ASCII) order. Both keys and values
// are escaped with url.QueryEscape. Values are stringified as in
// buildSortedParamStr: nil becomes "", strings are used verbatim, and anything
// else goes through fmt.Sprint. An empty or nil map yields "".
func buildFormUrlEncodedBody(params map[string]interface{}) string {
	names := make([]string, 0, len(params))
	for name := range params {
		names = append(names, name)
	}
	sort.Strings(names)

	fields := make([]string, 0, len(names))
	for _, name := range names {
		var text string
		switch raw := params[name].(type) {
		case nil:
		case string:
			text = raw
		default:
			text = fmt.Sprint(raw)
		}
		fields = append(fields, url.QueryEscape(name)+"="+url.QueryEscape(text))
	}
	return strings.Join(fields, "&")
}
// generateSign computes the request signature.
//
// The params are copied, a "timestamp" entry is added (it overrides any
// caller-supplied "timestamp"), and the result is serialized with
// buildSortedParamStr. That string is then signed with the app secret:
// MD5 when SignMethod is SignMethodMD5, HMAC otherwise. The caller's map is
// not modified.
// Endpoints with signing enabled expect the sign and timestamp values; plain
// (unsigned) endpoints do not need them.
func (s *ShujubaoService) generateSign(timestamp string, params map[string]interface{}) string {
	signed := make(map[string]interface{}, len(params)+1)
	for key, value := range params {
		signed[key] = value
	}
	signed["timestamp"] = timestamp
	payload := buildSortedParamStr(signed)

	if s.config.SignMethod == SignMethodMD5 {
		return GenerateSignFromParamsMD5(s.config.AppSecret, payload)
	}
	return GenerateSignFromParamsHMAC(s.config.AppSecret, payload)
}
// buildRequestURL joins the configured base URL and apiPath with exactly one
// "/" between them, e.g. "https://api.chinadatapay.com/communication" +
// "/personal/197". An empty apiPath returns the base URL without its trailing
// slash.
func (s *ShujubaoService) buildRequestURL(apiPath string) string {
	joined := strings.TrimSuffix(s.config.URL, "/")
	if apiPath != "" {
		joined += "/" + strings.TrimPrefix(apiPath, "/")
	}
	return joined
}
// CallAPI sends a POST request to the Shujubao API and returns the "data"
// field of a successful response.
//
// Request format:
//   - URL: the configured base URL joined with apiPath, e.g. the base URL
//     followed by "/personal/197".
//   - Body: params encoded as application/x-www-form-urlencoded, because the
//     upstream does not accept JSON bodies.
//   - Headers: the "timestamp" and "sign" values, sent as-is.
//
// Results:
//   - code "10000": the response's Data and a nil error;
//   - code "10001" or "10006" (empty / no match): an empty []interface{} and a
//     nil error — deliberately not treated as a failure;
//   - anything else: nil data and an error joined (errors.Join) with either
//     ErrDatasource (timeouts, non-200 status, non-success business codes) or
//     ErrSystem (request construction, other transport errors, body read or
//     JSON decode failures).
//
// If ctx carries a string under the key "transaction_id", it is included in
// every log entry. Failures are logged together with a truncated copy of
// params (see paramsForLog).
func (s *ShujubaoService) CallAPI(ctx context.Context, apiPath string, params map[string]interface{}) (data interface{}, err error) {
	startTime := time.Now()
	requestID := s.generateRequestID()
	timestamp := strconv.FormatInt(time.Now().Unix(), 10)
	requestURL := s.buildRequestURL(apiPath)

	var transactionID string
	if id, ok := ctx.Value("transaction_id").(string); ok {
		transactionID = id
	}

	// logFailure records e together with a log-safe (truncated) copy of params.
	logFailure := func(e error) {
		if s.logger != nil {
			s.logger.LogError(requestID, transactionID, apiPath, e, paramsForLog(params))
		}
	}

	// wrapTransportErr classifies an error from sending the request or reading
	// the response body. Deadlines are attributed to the data source; anything
	// else is a system error. A deadline can come from ctx's own deadline or
	// from http.Client.Timeout. http.Client reports failures as *url.Error,
	// which implements Timeout() and unwraps to its cause, so errors.Is and
	// errors.As see through it. %w keeps the cause inspectable by callers.
	wrapTransportErr := func(e error) error {
		var timeoutErr interface{ Timeout() bool }
		switch {
		case errors.Is(ctx.Err(), context.DeadlineExceeded),
			errors.Is(e, context.DeadlineExceeded),
			errors.As(e, &timeoutErr) && timeoutErr.Timeout():
			return errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %w", e))
		default:
			return errors.Join(ErrSystem, e)
		}
	}

	if s.logger != nil {
		s.logger.LogRequest(requestID, transactionID, apiPath, requestURL)
	}

	formBody := buildFormUrlEncodedBody(params)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(formBody))
	if err != nil {
		err = errors.Join(ErrSystem, err)
		logFailure(err)
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("timestamp", timestamp)
	req.Header.Set("sign", s.generateSign(timestamp, params))

	// With a nil Transport the client uses http.DefaultTransport, so
	// connections are still pooled across calls even though a new Client
	// value is built per call.
	client := &http.Client{Timeout: s.config.Timeout}
	response, err := client.Do(req)
	if err != nil {
		err = wrapTransportErr(err)
		logFailure(err)
		return nil, err
	}
	defer response.Body.Close()

	// http.Client.Timeout also covers reading the body, so a slow upstream can
	// time out here rather than in client.Do; classify it the same way.
	respBody, err := io.ReadAll(response.Body)
	if err != nil {
		err = wrapTransportErr(err)
		logFailure(err)
		return nil, err
	}

	if s.logger != nil {
		s.logger.LogResponse(requestID, transactionID, apiPath, response.StatusCode, time.Since(startTime))
	}

	if response.StatusCode != http.StatusOK {
		err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
		logFailure(err)
		return nil, err
	}

	var shujubaoResp ShujubaoResp
	if err = json.Unmarshal(respBody, &shujubaoResp); err != nil {
		err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
		logFailure(err)
		return nil, err
	}

	switch code := shujubaoResp.Code; code {
	case "10000":
		return shujubaoResp.Data, nil
	case "10001", "10006":
		// Empty / no-match result: return an empty slice, not an error.
		return []interface{}{}, nil
	default:
		shujubaoErr := NewShujubaoErrorFromCode(code, shujubaoResp.Message)
		logFailure(shujubaoErr)
		return nil, errors.Join(ErrDatasource, shujubaoErr)
	}
}