359 lines
9.2 KiB
Go
359 lines
9.2 KiB
Go
package crypto
|
||
|
||
import (
|
||
"context"
|
||
"crypto/hmac"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"errors"
|
||
"fmt"
|
||
"sort"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"tyapi-server/internal/shared/interfaces"
|
||
)
|
||
|
||
const (
|
||
// SignatureTimestampTolerance 签名时间戳容差(秒),防止重放攻击
|
||
SignatureTimestampTolerance = 300 // 5分钟
|
||
)
|
||
|
||
// GenerateSignature 生成HMAC-SHA256签名
|
||
// params: 需要签名的参数map
|
||
// secretKey: 签名密钥
|
||
// timestamp: 时间戳(秒)
|
||
// nonce: 随机字符串
|
||
func GenerateSignature(params map[string]string, secretKey string, timestamp int64, nonce string) string {
|
||
// 1. 构建待签名字符串:按key排序,拼接成 key1=value1&key2=value2 格式
|
||
var keys []string
|
||
for k := range params {
|
||
if k != "signature" { // 排除签名字段本身
|
||
keys = append(keys, k)
|
||
}
|
||
}
|
||
sort.Strings(keys)
|
||
|
||
var parts []string
|
||
for _, k := range keys {
|
||
parts = append(parts, fmt.Sprintf("%s=%s", k, params[k]))
|
||
}
|
||
|
||
// 2. 添加时间戳和随机数
|
||
parts = append(parts, fmt.Sprintf("timestamp=%d", timestamp))
|
||
parts = append(parts, fmt.Sprintf("nonce=%s", nonce))
|
||
|
||
// 3. 拼接成待签名字符串
|
||
signString := strings.Join(parts, "&")
|
||
|
||
// 4. 使用HMAC-SHA256计算签名
|
||
mac := hmac.New(sha256.New, []byte(secretKey))
|
||
mac.Write([]byte(signString))
|
||
signature := mac.Sum(nil)
|
||
|
||
// 5. 返回hex编码的签名
|
||
return hex.EncodeToString(signature)
|
||
}
|
||
|
||
// VerifySignature 验证HMAC-SHA256签名
|
||
// params: 请求参数map(包含signature字段)
|
||
// secretKey: 签名密钥
|
||
// timestamp: 时间戳(秒)
|
||
// nonce: 随机字符串
|
||
func VerifySignature(params map[string]string, secretKey string, timestamp int64, nonce string) error {
|
||
// 1. 检查签名字段是否存在
|
||
signature, exists := params["signature"]
|
||
if !exists || signature == "" {
|
||
return errors.New("签名字段缺失")
|
||
}
|
||
|
||
// 2. 验证时间戳(防止重放攻击)
|
||
now := time.Now().Unix()
|
||
if timestamp <= 0 {
|
||
return errors.New("时间戳无效")
|
||
}
|
||
if abs(now-timestamp) > SignatureTimestampTolerance {
|
||
return fmt.Errorf("请求已过期,时间戳超出容差范围(当前时间:%d,请求时间:%d)", now, timestamp)
|
||
}
|
||
|
||
// 3. 重新计算签名
|
||
expectedSignature := GenerateSignature(params, secretKey, timestamp, nonce)
|
||
|
||
// 4. 将hex字符串转换为字节数组进行比较
|
||
signatureBytes, err := hex.DecodeString(signature)
|
||
if err != nil {
|
||
return fmt.Errorf("签名格式错误: %w", err)
|
||
}
|
||
expectedBytes, err := hex.DecodeString(expectedSignature)
|
||
if err != nil {
|
||
return fmt.Errorf("签名计算错误: %w", err)
|
||
}
|
||
|
||
// 5. 使用常量时间比较防止时序攻击
|
||
if !hmac.Equal(signatureBytes, expectedBytes) {
|
||
return errors.New("签名验证失败")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// VerifySignatureWithNonceCheck 验证HMAC-SHA256签名并检查nonce唯一性(防止重放攻击)
|
||
// params: 请求参数map(包含signature字段)
|
||
// secretKey: 签名密钥
|
||
// timestamp: 时间戳(秒)
|
||
// nonce: 随机字符串
|
||
// cache: 缓存服务,用于存储已使用的nonce
|
||
// cacheKeyPrefix: 缓存键前缀
|
||
func VerifySignatureWithNonceCheck(
|
||
ctx context.Context,
|
||
params map[string]string,
|
||
secretKey string,
|
||
timestamp int64,
|
||
nonce string,
|
||
cache interfaces.CacheService,
|
||
cacheKeyPrefix string,
|
||
) error {
|
||
// 1. 先进行基础签名验证
|
||
if err := VerifySignature(params, secretKey, timestamp, nonce); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 2. 检查nonce是否已被使用(防止重放攻击)
|
||
// 使用请求指纹:phone+timestamp+nonce 作为唯一标识
|
||
phone := params["phone"]
|
||
if phone == "" {
|
||
return errors.New("手机号不能为空")
|
||
}
|
||
|
||
// 构建nonce唯一性检查的缓存键
|
||
nonceKey := fmt.Sprintf("%s:nonce:%s:%d:%s", cacheKeyPrefix, phone, timestamp, nonce)
|
||
|
||
// 检查nonce是否已被使用
|
||
exists, err := cache.Exists(ctx, nonceKey)
|
||
if err != nil {
|
||
// 缓存查询失败,记录错误但继续验证(避免缓存故障导致服务不可用)
|
||
return fmt.Errorf("检查nonce唯一性失败: %w", err)
|
||
}
|
||
if exists {
|
||
return errors.New("请求已被使用,请勿重复提交")
|
||
}
|
||
|
||
// 3. 将nonce标记为已使用,TTL设置为时间戳容差+1分钟(确保在容差范围内不会重复使用)
|
||
ttl := time.Duration(SignatureTimestampTolerance+60) * time.Second
|
||
if err := cache.Set(ctx, nonceKey, true, ttl); err != nil {
|
||
// 记录错误但不影响验证流程(避免缓存故障导致服务不可用)
|
||
return fmt.Errorf("标记nonce已使用失败: %w", err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 自定义编码字符集(不使用标准Base64字符集,增加破解难度)
|
||
// 使用自定义字符集:数字+大写字母(排除易混淆的I和O)+小写字母(排除易混淆的i和l)+特殊字符
|
||
const customEncodeCharset = "0123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||
|
||
// EncodeRequest 使用自定义编码方案编码请求参数
|
||
// 编码方式:类似Base64,但使用自定义字符集,并加入简单的混淆
|
||
func EncodeRequest(data string) string {
|
||
// 1. 将字符串转换为字节数组
|
||
bytes := []byte(data)
|
||
|
||
// 2. 使用自定义Base64变种编码
|
||
encoded := customBase64Encode(bytes)
|
||
|
||
// 3. 添加简单的字符混淆(字符偏移)
|
||
confused := applyCharShift(encoded, 7) // 偏移7个位置
|
||
|
||
return confused
|
||
}
|
||
|
||
// DecodeRequest 解码请求参数
|
||
func DecodeRequest(encodedData string) (string, error) {
|
||
// 1. 先还原字符混淆
|
||
unconfused := reverseCharShift(encodedData, 7)
|
||
|
||
// 2. 使用自定义Base64变种解码
|
||
decoded, err := customBase64Decode(unconfused)
|
||
if err != nil {
|
||
return "", fmt.Errorf("解码失败: %w", err)
|
||
}
|
||
|
||
return string(decoded), nil
|
||
}
|
||
|
||
// customBase64Encode 自定义Base64编码(使用自定义字符集)
|
||
func customBase64Encode(data []byte) string {
|
||
if len(data) == 0 {
|
||
return ""
|
||
}
|
||
|
||
var result []byte
|
||
charset := []byte(customEncodeCharset)
|
||
|
||
// 将3个字节(24位)编码为4个字符
|
||
for i := 0; i < len(data); i += 3 {
|
||
// 读取3个字节
|
||
var b1, b2, b3 byte
|
||
b1 = data[i]
|
||
if i+1 < len(data) {
|
||
b2 = data[i+1]
|
||
}
|
||
if i+2 < len(data) {
|
||
b3 = data[i+2]
|
||
}
|
||
|
||
// 组合成24位
|
||
combined := uint32(b1)<<16 | uint32(b2)<<8 | uint32(b3)
|
||
|
||
// 分成4个6位段
|
||
result = append(result, charset[(combined>>18)&0x3F])
|
||
result = append(result, charset[(combined>>12)&0x3F])
|
||
|
||
if i+1 < len(data) {
|
||
result = append(result, charset[(combined>>6)&0x3F])
|
||
} else {
|
||
result = append(result, '=') // 填充字符
|
||
}
|
||
|
||
if i+2 < len(data) {
|
||
result = append(result, charset[combined&0x3F])
|
||
} else {
|
||
result = append(result, '=') // 填充字符
|
||
}
|
||
}
|
||
|
||
return string(result)
|
||
}
|
||
|
||
// customBase64Decode 自定义Base64解码
|
||
func customBase64Decode(encoded string) ([]byte, error) {
|
||
if len(encoded) == 0 {
|
||
return []byte{}, nil
|
||
}
|
||
|
||
charset := []byte(customEncodeCharset)
|
||
charsetMap := make(map[byte]int)
|
||
for i, c := range charset {
|
||
charsetMap[c] = i
|
||
}
|
||
|
||
var result []byte
|
||
data := []byte(encoded)
|
||
|
||
// 将4个字符解码为3个字节
|
||
for i := 0; i < len(data); i += 4 {
|
||
if i+3 >= len(data) {
|
||
return nil, fmt.Errorf("编码数据长度不正确")
|
||
}
|
||
|
||
// 获取4个字符的索引
|
||
var idx [4]int
|
||
for j := 0; j < 4; j++ {
|
||
if data[i+j] == '=' {
|
||
idx[j] = 0 // 填充字符
|
||
} else {
|
||
val, ok := charsetMap[data[i+j]]
|
||
if !ok {
|
||
return nil, fmt.Errorf("无效的编码字符: %c", data[i+j])
|
||
}
|
||
idx[j] = val
|
||
}
|
||
}
|
||
|
||
// 组合成24位
|
||
combined := uint32(idx[0])<<18 | uint32(idx[1])<<12 | uint32(idx[2])<<6 | uint32(idx[3])
|
||
|
||
// 提取3个字节
|
||
result = append(result, byte((combined>>16)&0xFF))
|
||
if data[i+2] != '=' {
|
||
result = append(result, byte((combined>>8)&0xFF))
|
||
}
|
||
if data[i+3] != '=' {
|
||
result = append(result, byte(combined&0xFF))
|
||
}
|
||
}
|
||
|
||
return result, nil
|
||
}
|
||
|
||
// applyCharShift 应用字符偏移混淆
|
||
func applyCharShift(data string, shift int) string {
|
||
charset := customEncodeCharset
|
||
charsetLen := len(charset)
|
||
result := make([]byte, len(data))
|
||
|
||
for i, c := range []byte(data) {
|
||
if c == '=' {
|
||
result[i] = c // 填充字符不变
|
||
continue
|
||
}
|
||
|
||
// 查找字符在字符集中的位置
|
||
idx := -1
|
||
for j, ch := range []byte(charset) {
|
||
if ch == c {
|
||
idx = j
|
||
break
|
||
}
|
||
}
|
||
|
||
if idx == -1 {
|
||
result[i] = c // 不在字符集中,保持不变
|
||
} else {
|
||
// 应用偏移
|
||
newIdx := (idx + shift) % charsetLen
|
||
result[i] = charset[newIdx]
|
||
}
|
||
}
|
||
|
||
return string(result)
|
||
}
|
||
|
||
// reverseCharShift 还原字符偏移混淆
|
||
func reverseCharShift(data string, shift int) string {
|
||
charset := customEncodeCharset
|
||
charsetLen := len(charset)
|
||
result := make([]byte, len(data))
|
||
|
||
for i, c := range []byte(data) {
|
||
if c == '=' {
|
||
result[i] = c // 填充字符不变
|
||
continue
|
||
}
|
||
|
||
// 查找字符在字符集中的位置
|
||
idx := -1
|
||
for j, ch := range []byte(charset) {
|
||
if ch == c {
|
||
idx = j
|
||
break
|
||
}
|
||
}
|
||
|
||
if idx == -1 {
|
||
result[i] = c // 不在字符集中,保持不变
|
||
} else {
|
||
// 还原偏移
|
||
newIdx := (idx - shift + charsetLen) % charsetLen
|
||
result[i] = charset[newIdx]
|
||
}
|
||
}
|
||
|
||
return string(result)
|
||
}
|
||
|
||
// abs 计算绝对值
|
||
func abs(x int64) int64 {
|
||
if x < 0 {
|
||
return -x
|
||
}
|
||
return x
|
||
}
|
||
|
||
// ParseTimestamp 从字符串解析时间戳
|
||
func ParseTimestamp(ts string) (int64, error) {
|
||
return strconv.ParseInt(ts, 10, 64)
|
||
}
|
||
|