417 lines
10 KiB
Go
417 lines
10 KiB
Go
package tianyuanapi
|
||
|
||
import (
|
||
"bytes"
|
||
"crypto/aes"
|
||
"crypto/cipher"
|
||
"crypto/rand"
|
||
"encoding/base64"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"time"
|
||
)
|
||
|
||
// Sentinel errors for the error codes reported by the API. Call returns these
// values directly (via GetErrorByCode) when the API replies with a non-zero
// code, so they can be compared with == or errors.Is.
var (
	// ErrQueryEmpty reports an empty query result (code 1000).
	ErrQueryEmpty = errors.New("查询为空")

	// ErrSystem reports a server-side interface exception (code 1001).
	ErrSystem = errors.New("接口异常")

	// ErrDecryptFail reports that decryption failed (code 1002).
	ErrDecryptFail = errors.New("解密失败")

	// ErrRequestParam reports a malformed request parameter structure (code 1003).
	ErrRequestParam = errors.New("请求参数结构不正确")

	// ErrInvalidParam reports failed parameter validation (code 1003).
	ErrInvalidParam = errors.New("参数校验不正确")

	// ErrInvalidIP reports a call from an unauthorized IP (code 1004).
	ErrInvalidIP = errors.New("未经授权的IP")

	// ErrMissingAccessId reports a missing Access-Id header (code 1005).
	ErrMissingAccessId = errors.New("缺少Access-Id")

	// ErrInvalidAccessId reports an unauthorized Access-Id (code 1006).
	ErrInvalidAccessId = errors.New("未经授权的AccessId")

	// ErrFrozenAccount reports a frozen account (code 1007).
	ErrFrozenAccount = errors.New("账户已冻结")

	// ErrArrears reports an insufficient account balance (code 1007).
	ErrArrears = errors.New("账户余额不足,无法请求")

	// ErrProductNotFound reports an unknown product (code 1008).
	ErrProductNotFound = errors.New("产品不存在")

	// ErrProductDisabled reports a disabled product (code 1008).
	ErrProductDisabled = errors.New("产品已停用")

	// ErrNotSubscribed reports a product the account has not subscribed to (code 1008).
	ErrNotSubscribed = errors.New("未订阅此产品")

	// ErrBusiness reports a business-level failure (code 2001).
	ErrBusiness = errors.New("业务失败")
)

// ErrorCodeMap maps every sentinel error to its numeric API code. Several
// sentinels share a code (1003, 1007, 1008); GetErrorByCode resolves such a
// shared code to one representative sentinel.
var ErrorCodeMap = map[error]int{
	// General
	ErrQueryEmpty:  1000,
	ErrSystem:      1001,
	ErrDecryptFail: 1002,

	// Request structure / validation
	ErrRequestParam: 1003,
	ErrInvalidParam: 1003,

	// Authorization
	ErrInvalidIP:       1004,
	ErrMissingAccessId: 1005,
	ErrInvalidAccessId: 1006,

	// Account state
	ErrFrozenAccount: 1007,
	ErrArrears:       1007,

	// Product availability
	ErrProductNotFound: 1008,
	ErrProductDisabled: 1008,
	ErrNotSubscribed:   1008,

	// Business
	ErrBusiness: 2001,
}
|
||
|
||
// ApiCallOptions holds per-call options that Call forwards to the server
// under the "options" key of the request body.
type ApiCallOptions struct {
	// Json asks the server to return JSON-formatted data; it is omitted from
	// the encoded options when false.
	Json bool `json:"json,omitempty"`
}
|
||
|
||
// Client is a Tianyuan API client. Build one with NewClient.
type Client struct {
	// accessID is sent as the Access-Id header; key is the hex-encoded AES
	// key used to encrypt requests and decrypt responses.
	accessID, key string

	// baseURL prefixes every endpoint as "<baseURL>/api/v1/<interface>".
	baseURL string

	// timeout mirrors the configured timeout; the deadline itself is
	// enforced by client.Timeout.
	timeout time.Duration
	client  *http.Client
}
|
||
|
||
// Config configures NewClient.
type Config struct {
	// AccessID identifies the caller (sent as the Access-Id header); required.
	AccessID string

	// Key is the AES key encoded as a hexadecimal string; required.
	Key string

	// BaseURL is the API root; NewClient substitutes "http://127.0.0.1:8080"
	// when it is empty.
	BaseURL string

	// Timeout bounds each HTTP request; NewClient substitutes 60s when it is
	// zero.
	Timeout time.Duration
}
|
||
|
||
// Request 请求参数
|
||
type Request struct {
|
||
InterfaceName string `json:"interfaceName"` // 接口名称
|
||
Params map[string]interface{} `json:"params"` // 请求参数
|
||
Timeout int `json:"timeout"` // 超时时间(毫秒)
|
||
Options *ApiCallOptions `json:"options"` // 调用选项
|
||
}
|
||
|
||
// ApiResponse is the raw JSON envelope returned by the HTTP API.
type ApiResponse struct {
	Code    int    `json:"code"`
	Message string `json:"message"`

	// TransactionID is the server-side serial number of the call.
	TransactionID string `json:"transaction_id"`

	// Data is the encrypted payload: base64 of IV || AES-CBC ciphertext.
	Data string `json:"data"`
}
|
||
|
||
// Response is the result returned by Client.Call.
type Response struct {
	// Code and Message come from the API envelope, or describe the local
	// failure stage when no envelope could be obtained.
	Code    int    `json:"code"`
	Message string `json:"message"`

	// Success is true when the API reported code 0.
	Success bool `json:"success"`

	// TransactionID is the server-side serial number of the call.
	TransactionID string `json:"transaction_id"`

	// Data is the decrypted, JSON-decoded payload, or nil when none was
	// obtained.
	Data map[string]interface{} `json:"data"`

	// Timeout is, despite its name, the elapsed request time in milliseconds.
	Timeout int64 `json:"timeout"`

	// Error explains why the call did not complete normally; omitted from
	// JSON when empty.
	Error string `json:"error,omitempty"`
}
|
||
|
||
// NewClient 创建新的客户端实例
|
||
func NewClient(config Config) (*Client, error) {
|
||
// 参数校验
|
||
if config.AccessID == "" {
|
||
return nil, fmt.Errorf("accessID不能为空")
|
||
}
|
||
if config.Key == "" {
|
||
return nil, fmt.Errorf("key不能为空")
|
||
}
|
||
if config.BaseURL == "" {
|
||
config.BaseURL = "http://127.0.0.1:8080"
|
||
}
|
||
if config.Timeout == 0 {
|
||
config.Timeout = 60 * time.Second
|
||
}
|
||
|
||
// 验证密钥格式
|
||
if _, err := hex.DecodeString(config.Key); err != nil {
|
||
return nil, fmt.Errorf("无效的密钥格式,必须是16进制字符串: %v", err)
|
||
}
|
||
|
||
return &Client{
|
||
accessID: config.AccessID,
|
||
key: config.Key,
|
||
baseURL: config.BaseURL,
|
||
timeout: config.Timeout,
|
||
client: &http.Client{
|
||
Timeout: config.Timeout,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// Call 调用API接口
|
||
func (c *Client) Call(req Request) (*Response, error) {
|
||
startTime := time.Now()
|
||
|
||
// 参数校验
|
||
if err := c.validateRequest(req); err != nil {
|
||
return nil, fmt.Errorf("请求参数校验失败: %v", err)
|
||
}
|
||
|
||
// 加密参数
|
||
jsonData, err := json.Marshal(req.Params)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("参数序列化失败: %v", err)
|
||
}
|
||
|
||
encryptedData, err := c.encrypt(string(jsonData))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("数据加密失败: %v", err)
|
||
}
|
||
|
||
// 构建请求体
|
||
requestBody := map[string]interface{}{
|
||
"data": encryptedData,
|
||
}
|
||
|
||
// 添加选项
|
||
if req.Options != nil {
|
||
requestBody["options"] = req.Options
|
||
} else {
|
||
// 默认选项
|
||
defaultOptions := &ApiCallOptions{
|
||
Json: true,
|
||
}
|
||
requestBody["options"] = defaultOptions
|
||
}
|
||
|
||
requestBodyBytes, err := json.Marshal(requestBody)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("请求体序列化失败: %v", err)
|
||
}
|
||
|
||
// 创建HTTP请求
|
||
url := fmt.Sprintf("%s/api/v1/%s", c.baseURL, req.InterfaceName)
|
||
|
||
httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(requestBodyBytes))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建HTTP请求失败: %v", err)
|
||
}
|
||
|
||
// 设置请求头
|
||
httpReq.Header.Set("Content-Type", "application/json")
|
||
httpReq.Header.Set("Access-Id", c.accessID)
|
||
httpReq.Header.Set("User-Agent", "TianyuanAPI-Go-SDK/1.0.0")
|
||
|
||
// 发送请求
|
||
resp, err := c.client.Do(httpReq)
|
||
if err != nil {
|
||
endTime := time.Now()
|
||
requestTime := endTime.Sub(startTime).Milliseconds()
|
||
return &Response{
|
||
Success: false,
|
||
Message: "请求失败",
|
||
Error: err.Error(),
|
||
Timeout: requestTime,
|
||
}, nil
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
// 读取响应
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
endTime := time.Now()
|
||
requestTime := endTime.Sub(startTime).Milliseconds()
|
||
return &Response{
|
||
Success: false,
|
||
Message: "读取响应失败",
|
||
Error: err.Error(),
|
||
Timeout: requestTime,
|
||
}, nil
|
||
}
|
||
|
||
// 解析HTTP API响应
|
||
var apiResp ApiResponse
|
||
if err := json.Unmarshal(body, &apiResp); err != nil {
|
||
endTime := time.Now()
|
||
requestTime := endTime.Sub(startTime).Milliseconds()
|
||
return &Response{
|
||
Success: false,
|
||
Message: "响应解析失败",
|
||
Error: err.Error(),
|
||
Timeout: requestTime,
|
||
}, nil
|
||
}
|
||
|
||
// 计算请求耗时
|
||
endTime := time.Now()
|
||
requestTime := endTime.Sub(startTime).Milliseconds()
|
||
|
||
// 构建Call方法的响应
|
||
response := &Response{
|
||
Code: apiResp.Code,
|
||
Message: apiResp.Message,
|
||
Success: apiResp.Code == 0,
|
||
TransactionID: apiResp.TransactionID,
|
||
Timeout: requestTime,
|
||
}
|
||
|
||
// 如果有加密数据,尝试解密
|
||
if apiResp.Data != "" {
|
||
decryptedData, err := c.decrypt(apiResp.Data)
|
||
if err == nil {
|
||
var decryptedMap map[string]interface{}
|
||
if json.Unmarshal([]byte(decryptedData), &decryptedMap) == nil {
|
||
response.Data = decryptedMap
|
||
}
|
||
}
|
||
}
|
||
|
||
// 根据响应码返回对应的错误
|
||
if apiResp.Code != 0 {
|
||
err := GetErrorByCode(apiResp.Code)
|
||
return nil, err
|
||
}
|
||
|
||
return response, nil
|
||
}
|
||
|
||
// CallInterface 简化接口调用方法
|
||
func (c *Client) CallInterface(interfaceName string, params map[string]interface{}, options ...*ApiCallOptions) (*Response, error) {
|
||
var opts *ApiCallOptions
|
||
if len(options) > 0 {
|
||
opts = options[0]
|
||
}
|
||
|
||
req := Request{
|
||
InterfaceName: interfaceName,
|
||
Params: params,
|
||
Timeout: 60000,
|
||
Options: opts,
|
||
}
|
||
|
||
return c.Call(req)
|
||
}
|
||
|
||
// validateRequest 校验请求参数
|
||
func (c *Client) validateRequest(req Request) error {
|
||
if req.InterfaceName == "" {
|
||
return fmt.Errorf("interfaceName不能为空")
|
||
}
|
||
if req.Params == nil {
|
||
return fmt.Errorf("params不能为空")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// encrypt AES CBC加密
|
||
func (c *Client) encrypt(plainText string) (string, error) {
|
||
keyBytes, err := hex.DecodeString(c.key)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
block, err := aes.NewCipher(keyBytes)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 生成随机IV
|
||
iv := make([]byte, aes.BlockSize)
|
||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||
return "", err
|
||
}
|
||
|
||
// 填充数据
|
||
paddedData := c.pkcs7Pad([]byte(plainText), aes.BlockSize)
|
||
|
||
// 加密
|
||
ciphertext := make([]byte, len(iv)+len(paddedData))
|
||
copy(ciphertext, iv)
|
||
|
||
mode := cipher.NewCBCEncrypter(block, iv)
|
||
mode.CryptBlocks(ciphertext[len(iv):], paddedData)
|
||
|
||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||
}
|
||
|
||
// decrypt AES CBC解密
|
||
func (c *Client) decrypt(encryptedText string) (string, error) {
|
||
keyBytes, err := hex.DecodeString(c.key)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
ciphertext, err := base64.StdEncoding.DecodeString(encryptedText)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
block, err := aes.NewCipher(keyBytes)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
if len(ciphertext) < aes.BlockSize {
|
||
return "", fmt.Errorf("密文太短")
|
||
}
|
||
|
||
iv := ciphertext[:aes.BlockSize]
|
||
ciphertext = ciphertext[aes.BlockSize:]
|
||
|
||
if len(ciphertext)%aes.BlockSize != 0 {
|
||
return "", fmt.Errorf("密文长度不是块大小的倍数")
|
||
}
|
||
|
||
mode := cipher.NewCBCDecrypter(block, iv)
|
||
mode.CryptBlocks(ciphertext, ciphertext)
|
||
|
||
// 去除填充
|
||
unpaddedData, err := c.pkcs7Unpad(ciphertext)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
|
||
return string(unpaddedData), nil
|
||
}
|
||
|
||
// pkcs7Pad PKCS7填充
|
||
func (c *Client) pkcs7Pad(data []byte, blockSize int) []byte {
|
||
padding := blockSize - len(data)%blockSize
|
||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||
return append(data, padtext...)
|
||
}
|
||
|
||
// pkcs7Unpad PKCS7去除填充
|
||
func (c *Client) pkcs7Unpad(data []byte) ([]byte, error) {
|
||
length := len(data)
|
||
if length == 0 {
|
||
return nil, fmt.Errorf("数据为空")
|
||
}
|
||
unpadding := int(data[length-1])
|
||
if unpadding > length {
|
||
return nil, fmt.Errorf("无效的填充")
|
||
}
|
||
return data[:length-unpadding], nil
|
||
}
|
||
|
||
// GetErrorByCode 根据错误码获取错误
|
||
func GetErrorByCode(code int) error {
|
||
// 对于有多个错误对应同一错误码的情况,返回第一个
|
||
switch code {
|
||
case 1000:
|
||
return ErrQueryEmpty
|
||
case 1001:
|
||
return ErrSystem
|
||
case 1002:
|
||
return ErrDecryptFail
|
||
case 1003:
|
||
return ErrRequestParam
|
||
case 1004:
|
||
return ErrInvalidIP
|
||
case 1005:
|
||
return ErrMissingAccessId
|
||
case 1006:
|
||
return ErrInvalidAccessId
|
||
case 1007:
|
||
return ErrFrozenAccount
|
||
case 1008:
|
||
return ErrProductNotFound
|
||
case 2001:
|
||
return ErrBusiness
|
||
default:
|
||
return fmt.Errorf("未知错误码: %d", code)
|
||
}
|
||
}
|
||
|
||
// GetCodeByError 根据错误获取错误码
|
||
func GetCodeByError(err error) int {
|
||
if code, exists := ErrorCodeMap[err]; exists {
|
||
return code
|
||
}
|
||
return -1
|
||
}
|