package muzi import ( "bytes" "context" "crypto/aes" "crypto/md5" "encoding/base64" "encoding/hex" "encoding/json" "errors" "fmt" "io" "net/http" "reflect" "sort" "strconv" "time" "tyapi-server/internal/shared/external_logger" ) const defaultRequestTimeout = 60 * time.Second var ( ErrDatasource = errors.New("数据源异常") ErrSystem = errors.New("系统异常") ) // Muzi状态码常量 const ( CodeSuccess = 0 // 成功查询 CodeSystemError = 500 // 系统异常 CodeParamMissing = 601 // 参数不全 CodeInterfaceExpired = 602 // 接口已过期 CodeVerifyFailed = 603 // 接口校验失败 CodeIPNotInWhitelist = 604 // IP不在白名单中 CodeProductNotFound = 701 // 产品编号不存在 CodeUserNotFound = 702 // 用户名不存在 CodeUnauthorizedAPI = 703 // 接口未授权 CodeInsufficientFund = 704 // 商户余额不足 ) // MuziResponse 木子数据接口通用响应 type MuziResponse struct { Code int `json:"code"` Msg string `json:"msg"` Data json.RawMessage `json:"data"` Timestamp int64 `json:"timestamp"` ExecuteTime int64 `json:"executeTime"` } // MuziConfig 木子数据接口配置 type MuziConfig struct { URL string AppID string AppSecret string Timeout time.Duration } // MuziService 木子数据接口服务封装 type MuziService struct { config MuziConfig logger *external_logger.ExternalServiceLogger } // NewMuziService 创建木子数据服务实例 func NewMuziService(url, appID, appSecret string, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *MuziService { if timeout <= 0 { timeout = defaultRequestTimeout } return &MuziService{ config: MuziConfig{ URL: url, AppID: appID, AppSecret: appSecret, Timeout: timeout, }, logger: logger, } } // generateRequestID 生成请求ID func (m *MuziService) generateRequestID() string { timestamp := time.Now().UnixNano() raw := fmt.Sprintf("%d_%s", timestamp, m.config.AppID) sum := md5.Sum([]byte(raw)) return fmt.Sprintf("muzi_%x", sum[:8]) } // CallAPI 调用木子数据接口 func (m *MuziService) CallAPI(ctx context.Context, prodCode string, params map[string]interface{}) (json.RawMessage, error) { requestID := m.generateRequestID() now := time.Now() timestamp := strconv.FormatInt(now.UnixMilli(), 10) flatParams := flattenParams(params) signParts := collectSignatureValues(params) signature := m.GenerateSignature(prodCode, timestamp, signParts...) // 从上下文获取链路ID var transactionID string if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok { transactionID = ctxTransactionID } requestBody := map[string]interface{}{ "appId": m.config.AppID, "prodCode": prodCode, "timestamp": timestamp, "signature": signature, } for key, value := range flatParams { requestBody[key] = value } if m.logger != nil { m.logger.LogRequest(requestID, transactionID, prodCode, m.config.URL, requestBody) } bodyBytes, marshalErr := json.Marshal(requestBody) if marshalErr != nil { err := errors.Join(ErrSystem, marshalErr) if m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, err, requestBody) } return nil, err } req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, m.config.URL, bytes.NewBuffer(bodyBytes)) if reqErr != nil { err := errors.Join(ErrSystem, reqErr) if m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, err, requestBody) } return nil, err } req.Header.Set("Content-Type", "application/json") client := &http.Client{ Timeout: m.config.Timeout, } resp, httpErr := client.Do(req) if httpErr != nil { err := wrapHTTPError(httpErr) if errors.Is(err, ErrDatasource) { err = errors.Join(err, fmt.Errorf("API请求超时: %v", httpErr)) } if m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, err, requestBody) } return nil, err } defer func(body io.ReadCloser) { closeErr := body.Close() if closeErr != nil && m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), requestBody) } }(resp.Body) respBody, readErr := io.ReadAll(resp.Body) if readErr != nil { err := errors.Join(ErrSystem, readErr) if m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, err, requestBody) } return nil, err } if m.logger != nil { m.logger.LogResponse(requestID, transactionID, prodCode, resp.StatusCode, respBody, time.Since(now)) } if resp.StatusCode != http.StatusOK { err := errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", resp.StatusCode)) if m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, err, requestBody) } return nil, err } var muziResp MuziResponse if err := json.Unmarshal(respBody, &muziResp); err != nil { err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %v", err)) if m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, err, requestBody) } return nil, err } if muziResp.Code != CodeSuccess { muziErr := NewMuziError(muziResp.Code, muziResp.Msg) var resultErr error switch muziResp.Code { case CodeSystemError: resultErr = errors.Join(ErrDatasource, muziErr) default: resultErr = errors.Join(ErrSystem, muziErr) } if m.logger != nil { m.logger.LogError(requestID, transactionID, prodCode, muziErr, requestBody) } return nil, resultErr } return muziResp.Data, nil } func wrapHTTPError(err error) error { var timeout bool if err == context.DeadlineExceeded { timeout = true } else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() { timeout = true } else if errStr := err.Error(); errStr == "context deadline exceeded" || errStr == "timeout" || errStr == "Client.Timeout exceeded" || errStr == "net/http: request canceled" { timeout = true } if timeout { return errors.Join(ErrDatasource, err) } return errors.Join(ErrSystem, err) } func pkcs5Padding(src []byte, blockSize int) []byte { padding := blockSize - len(src)%blockSize padtext := bytes.Repeat([]byte{byte(padding)}, padding) return append(src, padtext...) } func flattenParams(params map[string]interface{}) map[string]interface{} { result := make(map[string]interface{}) if params == nil { return result } for key, value := range params { flattenValue(key, value, result) } return result } func flattenValue(prefix string, value interface{}, out map[string]interface{}) { switch val := value.(type) { case map[string]interface{}: for k, v := range val { flattenValue(combinePrefix(prefix, k), v, out) } case map[interface{}]interface{}: for k, v := range val { keyStr := fmt.Sprint(k) flattenValue(combinePrefix(prefix, keyStr), v, out) } case []interface{}: for i, item := range val { nextPrefix := fmt.Sprintf("%s[%d]", prefix, i) flattenValue(nextPrefix, item, out) } case []string: for i, item := range val { nextPrefix := fmt.Sprintf("%s[%d]", prefix, i) flattenValue(nextPrefix, item, out) } case []int: for i, item := range val { nextPrefix := fmt.Sprintf("%s[%d]", prefix, i) flattenValue(nextPrefix, item, out) } case []float64: for i, item := range val { nextPrefix := fmt.Sprintf("%s[%d]", prefix, i) flattenValue(nextPrefix, item, out) } case []bool: for i, item := range val { nextPrefix := fmt.Sprintf("%s[%d]", prefix, i) flattenValue(nextPrefix, item, out) } default: out[prefix] = val } } func combinePrefix(prefix, key string) string { if prefix == "" { return key } return prefix + "." + key } // Encrypt 使用 AES/ECB/PKCS5Padding 对单个字符串进行加密并返回 Base64 结果 func (m *MuziService) Encrypt(value string) (string, error) { if len(m.config.AppSecret) != 32 { return "", fmt.Errorf("AppSecret长度必须为32位") } block, err := aes.NewCipher([]byte(m.config.AppSecret)) if err != nil { return "", fmt.Errorf("初始化加密器失败: %w", err) } padded := pkcs5Padding([]byte(value), block.BlockSize()) encrypted := make([]byte, len(padded)) for bs, be := 0, block.BlockSize(); bs < len(padded); bs, be = bs+block.BlockSize(), be+block.BlockSize() { block.Encrypt(encrypted[bs:be], padded[bs:be]) } return base64.StdEncoding.EncodeToString(encrypted), nil } // GenerateSignature 根据协议生成签名,extraValues 会按顺序追加在待签名字符串之后 func (m *MuziService) GenerateSignature(prodCode, timestamp string, extraValues ...string) string { signStr := m.config.AppID + prodCode + timestamp for _, extra := range extraValues { signStr += extra } hash := md5.Sum([]byte(signStr)) return hex.EncodeToString(hash[:]) } // GenerateTimestamp 生成当前毫秒级时间戳字符串 func (m *MuziService) GenerateTimestamp() string { return strconv.FormatInt(time.Now().UnixMilli(), 10) } // FlattenParams 将嵌套参数展平为一维键值对 func (m *MuziService) FlattenParams(params map[string]interface{}) map[string]interface{} { return flattenParams(params) } func collectSignatureValues(data interface{}) []string { var result []string collectSignatureValuesRecursive(reflect.ValueOf(data), &result) return result } func collectSignatureValuesRecursive(value reflect.Value, result *[]string) { if !value.IsValid() { *result = append(*result, "") return } switch value.Kind() { case reflect.Pointer, reflect.Interface: if value.IsNil() { *result = append(*result, "") return } collectSignatureValuesRecursive(value.Elem(), result) case reflect.Map: keys := value.MapKeys() sort.Slice(keys, func(i, j int) bool { return fmt.Sprint(keys[i].Interface()) < fmt.Sprint(keys[j].Interface()) }) for _, key := range keys { collectSignatureValuesRecursive(value.MapIndex(key), result) } case reflect.Slice, reflect.Array: for i := 0; i < value.Len(); i++ { collectSignatureValuesRecursive(value.Index(i), result) } case reflect.Struct: typeInfo := value.Type() fieldNames := make([]string, 0, value.NumField()) fieldIndices := make(map[string]int, value.NumField()) for i := 0; i < value.NumField(); i++ { field := typeInfo.Field(i) if field.PkgPath != "" { continue } fieldNames = append(fieldNames, field.Name) fieldIndices[field.Name] = i } sort.Strings(fieldNames) for _, name := range fieldNames { collectSignatureValuesRecursive(value.Field(fieldIndices[name]), result) } default: *result = append(*result, fmt.Sprint(value.Interface())) } }