309 lines
8.4 KiB
Go
309 lines
8.4 KiB
Go
package shujubao
|
||
|
||
import (
|
||
"context"
|
||
"crypto/md5"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"net/url"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"tyapi-server/internal/shared/external_logger"
|
||
)
|
||
|
||
// maxLogParamValueLen is the maximum length of a single request-parameter
// value written to the error log. Longer values (for example base64 payloads)
// are truncated so they do not flood the log.
const maxLogParamValueLen = 300
|
||
|
||
// Sentinel errors returned (joined with the underlying cause) by this package.
// Match them with errors.Is.
var (
	// ErrDatasource marks a failure attributed to the upstream data source:
	// a request timeout, a non-200 HTTP status, or a business error code.
	ErrDatasource = errors.New("数据源异常")

	// ErrSystem marks a failure on the calling side: the request could not be
	// built or sent, or the response could not be read or decoded.
	ErrSystem = errors.New("系统异常")

	// ErrQueryEmpty denotes a query that produced no result.
	ErrQueryEmpty = errors.New("查询为空")
)
|
||
|
||
// truncateForLog shortens s for use in error logs so that very long values
// (for example base64 payloads) do not bloat a log entry.
//
// If maxLen is not positive, or s is at most maxLen bytes long, s is returned
// unchanged. Otherwise at most maxLen bytes of s are kept, followed by a
// marker that reports the original length of s in bytes. The cut never falls
// inside a multi-byte UTF-8 sequence, so valid UTF-8 input always yields
// valid UTF-8 output.
func truncateForLog(s string, maxLen int) string {
	if maxLen <= 0 || len(s) <= maxLen {
		return s
	}

	// A byte of the form 10xxxxxx is a UTF-8 continuation byte. If the byte at
	// the cut position is one, the cut would split a character, so step back
	// to the start of that character. A UTF-8 sequence is at most 4 bytes
	// long, hence at most 3 steps; the bound also keeps non-UTF-8 input from
	// being trimmed away entirely.
	cut := maxLen
	for back := 0; back < 3 && cut > 0 && s[cut]&0xC0 == 0x80; back++ {
		cut--
	}

	return s[:cut] + "...[truncated, total " + strconv.Itoa(len(s)) + " chars]"
}
|
||
|
||
// paramsForLog 返回适合写入错误日志的入参副本(长字符串会被截断)
|
||
func paramsForLog(params map[string]interface{}) map[string]interface{} {
|
||
if params == nil {
|
||
return nil
|
||
}
|
||
out := make(map[string]interface{}, len(params))
|
||
for k, v := range params {
|
||
if v == nil {
|
||
out[k] = nil
|
||
continue
|
||
}
|
||
switch val := v.(type) {
|
||
case string:
|
||
out[k] = truncateForLog(val, maxLogParamValueLen)
|
||
default:
|
||
s := fmt.Sprint(v)
|
||
out[k] = truncateForLog(s, maxLogParamValueLen)
|
||
}
|
||
}
|
||
return out
|
||
}
|
||
|
||
// ShujubaoResp is the generic response envelope of the Shujubao API.
// Adjust the fields to the provider's actual documentation if they differ.
type ShujubaoResp struct {
	// Code is the business status code returned by the provider.
	Code string `json:"code"`
	// Message is the human-readable text that accompanies Code.
	Message string `json:"message"`
	// Data is the business payload, decoded into generic JSON values.
	Data interface{} `json:"data"`
	// Success is the provider's boolean success flag.
	Success bool `json:"success"`
}
|
||
|
||
// ShujubaoConfig 数据宝服务配置
|
||
type ShujubaoConfig struct {
|
||
URL string
|
||
AppSecret string
|
||
SignMethod SignMethod
|
||
Timeout time.Duration
|
||
}
|
||
|
||
// ShujubaoService 数据宝服务
|
||
type ShujubaoService struct {
|
||
config ShujubaoConfig
|
||
logger *external_logger.ExternalServiceLogger
|
||
}
|
||
|
||
// NewShujubaoService 创建数据宝服务实例
|
||
func NewShujubaoService(url, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *ShujubaoService {
|
||
if signMethod == "" {
|
||
signMethod = SignMethodHMACMD5
|
||
}
|
||
if timeout == 0 {
|
||
timeout = 60 * time.Second
|
||
}
|
||
return &ShujubaoService{
|
||
config: ShujubaoConfig{
|
||
URL: url,
|
||
AppSecret: appSecret,
|
||
SignMethod: signMethod,
|
||
Timeout: timeout,
|
||
},
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// generateRequestID 生成请求 ID
|
||
func (s *ShujubaoService) generateRequestID() string {
|
||
timestamp := time.Now().UnixNano()
|
||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, s.config.AppSecret)))
|
||
return fmt.Sprintf("shujubao_%x", hash[:8])
|
||
}
|
||
|
||
// buildSortedParamStr joins params into "key1=value1&key2=value2&..." with
// the keys in ascending byte (ASCII) order. Nothing is URL-escaped.
//
// String values are used as-is, nil values become the empty string and all
// other values are rendered with fmt.Sprint. An empty or nil map yields "".
func buildSortedParamStr(params map[string]interface{}) string {
	if len(params) == 0 {
		return ""
	}

	names := make([]string, 0, len(params))
	for name := range params {
		names = append(names, name)
	}
	sort.Strings(names)

	pairs := make([]string, len(names))
	for idx, name := range names {
		var text string
		switch typed := params[name].(type) {
		case nil:
			// A nil value contributes an empty value.
		case string:
			text = typed
		default:
			text = fmt.Sprint(typed)
		}
		pairs[idx] = name + "=" + text
	}
	return strings.Join(pairs, "&")
}
|
||
|
||
// buildFormUrlEncodedBody builds an application/x-www-form-urlencoded request
// body from params. Keys appear in ascending byte (ASCII) order and both keys
// and values are query-escaped.
//
// String values are used as-is, nil values become the empty string and all
// other values are rendered with fmt.Sprint. An empty or nil map yields "".
func buildFormUrlEncodedBody(params map[string]interface{}) string {
	form := make(url.Values, len(params))
	for name, raw := range params {
		switch typed := raw.(type) {
		case nil:
			form.Set(name, "")
		case string:
			form.Set(name, typed)
		default:
			form.Set(name, fmt.Sprint(typed))
		}
	}
	// Encode sorts by key and query-escapes keys and values.
	return form.Encode()
}
|
||
|
||
// generateSign 根据入参与时间戳生成签名。入参按 ASCII 排序组合后与 app_secret 做 MD5/HMAC。
|
||
// 对于开启了加密的接口需传 sign 与 timestamp;明文传输的接口则无需传这两个参数。
|
||
func (s *ShujubaoService) generateSign(timestamp string, params map[string]interface{}) string {
|
||
// 合并 timestamp 到入参后参与排序
|
||
merged := make(map[string]interface{}, len(params)+1)
|
||
for k, v := range params {
|
||
merged[k] = v
|
||
}
|
||
merged["timestamp"] = timestamp
|
||
sortedStr := buildSortedParamStr(merged)
|
||
switch s.config.SignMethod {
|
||
case SignMethodMD5:
|
||
return GenerateSignFromParamsMD5(s.config.AppSecret, sortedStr)
|
||
default:
|
||
return GenerateSignFromParamsHMAC(s.config.AppSecret, sortedStr)
|
||
}
|
||
}
|
||
|
||
// buildRequestURL 拼接接口地址得到最终请求 URL,如 https://api.chinadatapay.com/communication/personal/197
|
||
func (s *ShujubaoService) buildRequestURL(apiPath string) string {
|
||
base := strings.TrimSuffix(s.config.URL, "/")
|
||
if apiPath == "" {
|
||
return base
|
||
}
|
||
return base + "/" + strings.TrimPrefix(apiPath, "/")
|
||
}
|
||
|
||
// CallAPI 调用数据宝 API(POST)。最终请求地址 = url + 拼接接口地址值;body 为业务参数;sign、timestamp 按原样传 header。
|
||
func (s *ShujubaoService) CallAPI(ctx context.Context, apiPath string, params map[string]interface{}) (data interface{}, err error) {
|
||
startTime := time.Now()
|
||
requestID := s.generateRequestID()
|
||
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||
|
||
// 最终请求 URL = https://api.chinadatapay.com/communication + 拼接接口地址值,如 /personal/197
|
||
requestURL := s.buildRequestURL(apiPath)
|
||
|
||
var transactionID string
|
||
if id, ok := ctx.Value("transaction_id").(string); ok {
|
||
transactionID = id
|
||
}
|
||
|
||
if s.logger != nil {
|
||
s.logger.LogRequest(requestID, transactionID, apiPath, requestURL)
|
||
}
|
||
|
||
// 使用 application/x-www-form-urlencoded,贵司接口暂不支持 JSON 入参
|
||
formBody := buildFormUrlEncodedBody(params)
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, strings.NewReader(formBody))
|
||
if err != nil {
|
||
err = errors.Join(ErrSystem, err)
|
||
if s.logger != nil {
|
||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||
req.Header.Set("timestamp", timestamp)
|
||
req.Header.Set("sign", s.generateSign(timestamp, params))
|
||
|
||
client := &http.Client{Timeout: s.config.Timeout}
|
||
response, err := client.Do(req)
|
||
if err != nil {
|
||
isTimeout := false
|
||
if ctx.Err() == context.DeadlineExceeded {
|
||
isTimeout = true
|
||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||
isTimeout = true
|
||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||
errStr == "timeout" ||
|
||
errStr == "Client.Timeout exceeded" ||
|
||
errStr == "net/http: request canceled" {
|
||
isTimeout = true
|
||
}
|
||
if isTimeout {
|
||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
|
||
} else {
|
||
err = errors.Join(ErrSystem, err)
|
||
}
|
||
if s.logger != nil {
|
||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||
}
|
||
return nil, err
|
||
}
|
||
defer response.Body.Close()
|
||
|
||
respBody, err := io.ReadAll(response.Body)
|
||
if err != nil {
|
||
err = errors.Join(ErrSystem, err)
|
||
if s.logger != nil {
|
||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
if s.logger != nil {
|
||
duration := time.Since(startTime)
|
||
s.logger.LogResponse(requestID, transactionID, apiPath, response.StatusCode, duration)
|
||
}
|
||
|
||
if response.StatusCode != http.StatusOK {
|
||
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
|
||
if s.logger != nil {
|
||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
var shujubaoResp ShujubaoResp
|
||
if err := json.Unmarshal(respBody, &shujubaoResp); err != nil {
|
||
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
|
||
if s.logger != nil {
|
||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
code := shujubaoResp.Code
|
||
if code == "10001" || code == "10006" {
|
||
// 查空/查无:返回空数组,不视为错误
|
||
return []interface{}{}, nil
|
||
}
|
||
if code != "10000" {
|
||
shujubaoErr := NewShujubaoErrorFromCode(code, shujubaoResp.Message)
|
||
if s.logger != nil {
|
||
s.logger.LogError(requestID, transactionID, apiPath, shujubaoErr, paramsForLog(params))
|
||
}
|
||
return nil, errors.Join(ErrDatasource, shujubaoErr)
|
||
}
|
||
|
||
return shujubaoResp.Data, nil
|
||
}
|