Files
tyapi-server/internal/infrastructure/external/shujubao/shujubao_service.go
2026-01-30 18:25:30 +08:00

266 lines
7.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package shujubao
import (
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"tyapi-server/internal/shared/external_logger"
)
// ErrDatasource is the sentinel joined into errors attributed to the upstream
// data source. CallAPI uses it for request timeouts, non-200 HTTP statuses and
// non-success business codes. Match it with errors.Is.
var ErrDatasource = errors.New("数据源异常")

// ErrSystem is the sentinel joined into errors raised on the local side.
// CallAPI uses it for request-construction, transport, body-read and
// response-decoding failures. Match it with errors.Is.
var ErrSystem = errors.New("系统异常")
// ShujubaoResp is the common response envelope returned by the 数据宝 API.
// NOTE(review): field set follows the original "adjust to the actual API
// documentation" remark — confirm against the vendor docs.
type ShujubaoResp struct {
	// Code is the business status code. CallAPI treats "10000" and "10006"
	// as success and everything else as a data-source error.
	Code string `json:"code"`
	// Message is the human-readable status text that accompanies Code.
	Message string `json:"message"`
	// Data is the endpoint-specific payload, decoded generically.
	Data any `json:"data"`
	// Success is the API's own success flag. CallAPI does not consult it.
	Success bool `json:"success"`
}
// ShujubaoConfig holds the endpoint, signing and timeout settings of a
// ShujubaoService.
type ShujubaoConfig struct {
	// URL is the base endpoint; request paths are appended to it.
	URL string
	// AppSecret is the secret used to sign requests. It is also mixed into
	// generated request IDs.
	AppSecret string
	// SignMethod selects the signing routine. SignMethodMD5 selects the MD5
	// signer; any other value falls back to the HMAC signer.
	SignMethod SignMethod
	// Timeout is applied as http.Client.Timeout for every API call.
	Timeout time.Duration
}
// ShujubaoService is a client for the 数据宝 (Shujubao) data API.
// Construct it with NewShujubaoService.
type ShujubaoService struct {
	// config carries the endpoint, secret, signing method and timeout.
	config ShujubaoConfig
	// logger records requests, responses and errors; a nil logger disables
	// logging.
	logger *external_logger.ExternalServiceLogger
}
// NewShujubaoService creates a 数据宝 API client.
//
// Parameters:
//   - url: the base endpoint that API paths are appended to.
//   - appSecret: the key used to sign requests.
//   - signMethod: the signing algorithm; an empty value defaults to
//     SignMethodHMACMD5.
//   - timeout: the per-request HTTP timeout. Zero or negative values default
//     to 60 seconds. http.Client only enforces a deadline when Timeout > 0, so
//     a non-positive value would otherwise leave calls unbounded.
//   - logger: the request logger; may be nil to disable logging.
func NewShujubaoService(url, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *ShujubaoService {
	if signMethod == "" {
		signMethod = SignMethodHMACMD5
	}
	if timeout <= 0 {
		timeout = 60 * time.Second
	}
	return &ShujubaoService{
		config: ShujubaoConfig{
			URL:        url,
			AppSecret:  appSecret,
			SignMethod: signMethod,
			Timeout:    timeout,
		},
		logger: logger,
	}
}
// generateRequestID returns a per-call identifier of the form
// "shujubao_<16 hex chars>".
//
// The hex part is the first 8 bytes of MD5("<unix nanoseconds>_<app secret>").
// Two calls within the same nanosecond therefore yield the same ID.
//
// NOTE(review): the app secret feeds a value that is written to logs; consider
// a random source instead.
func (s *ShujubaoService) generateRequestID() string {
	seed := strconv.FormatInt(time.Now().UnixNano(), 10) + "_" + s.config.AppSecret
	digest := md5.Sum([]byte(seed))
	prefix := digest[:8]
	return fmt.Sprintf("shujubao_%x", prefix)
}
// stringifyParam renders one request parameter value as text:
//   - strings are used verbatim;
//   - nil becomes "";
//   - anything else is formatted with fmt.Sprint.
//
// The signature string and the form body both use it, so they always see the
// same values.
func stringifyParam(v interface{}) string {
	switch val := v.(type) {
	case string:
		return val
	case nil:
		return ""
	default:
		return fmt.Sprint(val)
	}
}

// buildSortedParamStr joins params as "key1=value1&key2=value2&..." with keys
// in ascending byte (ASCII) order. Keys and values are NOT escaped: this is the
// plain-text input to request signing, not a wire format.
// It returns "" for an empty or nil map.
func buildSortedParamStr(params map[string]interface{}) string {
	if len(params) == 0 {
		return ""
	}
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var b strings.Builder
	for i, k := range keys {
		if i > 0 {
			b.WriteByte('&')
		}
		b.WriteString(k)
		b.WriteByte('=')
		b.WriteString(stringifyParam(params[k]))
	}
	return b.String()
}

// buildFormUrlEncodedBody encodes params as an
// application/x-www-form-urlencoded request body.
//
// url.Values.Encode sorts pairs by key in the same ascending order as
// buildSortedParamStr, and it query-escapes both keys and values.
// It returns "" for an empty or nil map.
func buildFormUrlEncodedBody(params map[string]interface{}) string {
	form := make(url.Values, len(params))
	for k, v := range params {
		form.Set(k, stringifyParam(v))
	}
	return form.Encode()
}
// generateSign computes the request signature.
//
// The business params plus a "timestamp" entry are joined by
// buildSortedParamStr in ASCII key order. The result is signed with the app
// secret: GenerateSignFromParamsMD5 when SignMethod is SignMethodMD5, otherwise
// GenerateSignFromParamsHMAC.
//
// Only endpoints with encryption enabled require sign and timestamp;
// plaintext endpoints do not need them.
func (s *ShujubaoService) generateSign(timestamp string, params map[string]interface{}) string {
	signParams := make(map[string]interface{}, len(params)+1)
	for key, val := range params {
		signParams[key] = val
	}
	// The timestamp takes part in signing and overrides any caller-supplied
	// "timestamp" key.
	signParams["timestamp"] = timestamp
	payload := buildSortedParamStr(signParams)

	if s.config.SignMethod == SignMethodMD5 {
		return GenerateSignFromParamsMD5(s.config.AppSecret, payload)
	}
	return GenerateSignFromParamsHMAC(s.config.AppSecret, payload)
}
// buildRequestURL joins the configured base URL and apiPath with exactly one
// slash at the seam. At most one trailing slash of the base and one leading
// slash of apiPath are dropped. An empty apiPath yields the base alone.
// Example: ".../communication" + "/personal/197" gives
// ".../communication/personal/197".
func (s *ShujubaoService) buildRequestURL(apiPath string) string {
	root := strings.TrimSuffix(s.config.URL, "/")
	if apiPath != "" {
		root += "/" + strings.TrimPrefix(apiPath, "/")
	}
	return root
}
// CallAPI sends a signed POST to a 数据宝 endpoint and returns the "data"
// field of the decoded response.
//
// The request URL is the configured base URL joined with apiPath, for example
// ".../communication" + "/personal/197". params are sent as an
// application/x-www-form-urlencoded body with keys in ASCII order, because the
// upstream API does not currently accept JSON bodies. The "timestamp" (Unix
// seconds) and "sign" values are sent as HTTP headers.
//
// If ctx carries a string under the key "transaction_id", it is attached to
// every log entry. Returned errors are joined with one of the package
// sentinels:
//   - ErrDatasource: request timeout (the original cause stays in the chain),
//     a non-200 HTTP status, or a business code other than "10000"/"10006".
//   - ErrSystem: request construction, transport, body read or JSON decode
//     failure.
func (s *ShujubaoService) CallAPI(ctx context.Context, apiPath string, params map[string]interface{}) (data interface{}, err error) {
	startTime := time.Now()
	requestID := s.generateRequestID()
	timestamp := strconv.FormatInt(startTime.Unix(), 10)
	requestURL := s.buildRequestURL(apiPath)

	// NOTE(review): a plain string context key is collision-prone (staticcheck
	// SA1029). It is kept because callers presumably set it with this exact key.
	transactionID, _ := ctx.Value("transaction_id").(string)

	// logErr records a failure for this request when a logger is configured.
	logErr := func(e error) {
		if s.logger != nil {
			s.logger.LogError(requestID, transactionID, apiPath, e, params)
		}
	}

	if s.logger != nil {
		s.logger.LogRequest(requestID, transactionID, apiPath, requestURL)
	}

	formBody := buildFormUrlEncodedBody(params)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, strings.NewReader(formBody))
	if err != nil {
		err = errors.Join(ErrSystem, err)
		logErr(err)
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("timestamp", timestamp)
	req.Header.Set("sign", s.generateSign(timestamp, params))

	// A nil Transport means http.DefaultTransport, so connections are still
	// pooled across calls.
	client := &http.Client{Timeout: s.config.Timeout}
	response, err := client.Do(req)
	if err != nil {
		// client.Do wraps failures in *url.Error, whose message is prefixed with
		// the method and URL. Timeouts must therefore be detected on the error
		// chain, not by comparing message strings.
		var timeoutErr interface{ Timeout() bool }
		isTimeout := errors.Is(ctx.Err(), context.DeadlineExceeded) ||
			errors.Is(err, context.DeadlineExceeded) ||
			(errors.As(err, &timeoutErr) && timeoutErr.Timeout())
		if isTimeout {
			// %w keeps the transport error (e.g. context.DeadlineExceeded)
			// reachable via errors.Is.
			err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %w", err))
		} else {
			err = errors.Join(ErrSystem, err)
		}
		logErr(err)
		return nil, err
	}
	defer response.Body.Close()

	respBody, err := io.ReadAll(response.Body)
	if err != nil {
		err = errors.Join(ErrSystem, err)
		logErr(err)
		return nil, err
	}

	if s.logger != nil {
		s.logger.LogResponse(requestID, transactionID, apiPath, response.StatusCode, time.Since(startTime))
	}

	if response.StatusCode != http.StatusOK {
		err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
		logErr(err)
		return nil, err
	}

	var shujubaoResp ShujubaoResp
	if decodeErr := json.Unmarshal(respBody, &shujubaoResp); decodeErr != nil {
		err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", decodeErr))
		logErr(err)
		return nil, err
	}

	// "10000" and "10006" are the codes accepted as success. Any other code
	// is mapped to a typed error.
	if code := shujubaoResp.Code; code != "10000" && code != "10006" {
		shujubaoErr := NewShujubaoErrorFromCode(code, shujubaoResp.Message)
		logErr(shujubaoErr)
		return nil, errors.Join(ErrDatasource, shujubaoErr)
	}
	return shujubaoResp.Data, nil
}