This commit is contained in:
2025-07-28 01:46:39 +08:00
parent b03129667a
commit 357639462a
219 changed files with 21634 additions and 8138 deletions

View File

@@ -0,0 +1,194 @@
package cache
import (
"sync"
)
// CacheConfigManager 缓存配置管理器
// 提供全局缓存配置管理和表级别的缓存决策
type CacheConfigManager struct {
config CacheConfig
mutex sync.RWMutex
}
// GlobalCacheConfigManager 全局缓存配置管理器实例
var GlobalCacheConfigManager *CacheConfigManager
// InitCacheConfigManager 初始化全局缓存配置管理器
func InitCacheConfigManager(config CacheConfig) {
GlobalCacheConfigManager = &CacheConfigManager{
config: config,
}
}
// GetCacheConfig 获取当前缓存配置
func (m *CacheConfigManager) GetCacheConfig() CacheConfig {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.config
}
// UpdateCacheConfig 更新缓存配置
func (m *CacheConfigManager) UpdateCacheConfig(config CacheConfig) {
m.mutex.Lock()
defer m.mutex.Unlock()
m.config = config
}
// IsTableCacheEnabled 检查表是否启用缓存
func (m *CacheConfigManager) IsTableCacheEnabled(tableName string) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
// 检查表是否在禁用列表中
for _, disabledTable := range m.config.DisabledTables {
if disabledTable == tableName {
return false
}
}
// 如果配置了启用列表,只对启用列表中的表启用缓存
if len(m.config.EnabledTables) > 0 {
for _, enabledTable := range m.config.EnabledTables {
if enabledTable == tableName {
return true
}
}
return false
}
// 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表)
return true
}
// IsTableCacheDisabled 检查表是否禁用缓存
func (m *CacheConfigManager) IsTableCacheDisabled(tableName string) bool {
return !m.IsTableCacheEnabled(tableName)
}
// GetEnabledTables 获取启用缓存的表列表
func (m *CacheConfigManager) GetEnabledTables() []string {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.config.EnabledTables
}
// GetDisabledTables 获取禁用缓存的表列表
func (m *CacheConfigManager) GetDisabledTables() []string {
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.config.DisabledTables
}
// AddEnabledTable 添加启用缓存的表
func (m *CacheConfigManager) AddEnabledTable(tableName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
// 检查是否已存在
for _, table := range m.config.EnabledTables {
if table == tableName {
return
}
}
m.config.EnabledTables = append(m.config.EnabledTables, tableName)
}
// AddDisabledTable 添加禁用缓存的表
func (m *CacheConfigManager) AddDisabledTable(tableName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
// 检查是否已存在
for _, table := range m.config.DisabledTables {
if table == tableName {
return
}
}
m.config.DisabledTables = append(m.config.DisabledTables, tableName)
}
// RemoveEnabledTable 移除启用缓存的表
func (m *CacheConfigManager) RemoveEnabledTable(tableName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
var newEnabledTables []string
for _, table := range m.config.EnabledTables {
if table != tableName {
newEnabledTables = append(newEnabledTables, table)
}
}
m.config.EnabledTables = newEnabledTables
}
// RemoveDisabledTable 移除禁用缓存的表
func (m *CacheConfigManager) RemoveDisabledTable(tableName string) {
m.mutex.Lock()
defer m.mutex.Unlock()
var newDisabledTables []string
for _, table := range m.config.DisabledTables {
if table != tableName {
newDisabledTables = append(newDisabledTables, table)
}
}
m.config.DisabledTables = newDisabledTables
}
// GetTableCacheStatus 获取表的缓存状态信息
func (m *CacheConfigManager) GetTableCacheStatus(tableName string) map[string]interface{} {
m.mutex.RLock()
defer m.mutex.RUnlock()
status := map[string]interface{}{
"table_name": tableName,
"enabled": m.IsTableCacheEnabled(tableName),
"disabled": m.IsTableCacheDisabled(tableName),
}
// 检查是否在启用列表中
for _, table := range m.config.EnabledTables {
if table == tableName {
status["in_enabled_list"] = true
break
}
}
// 检查是否在禁用列表中
for _, table := range m.config.DisabledTables {
if table == tableName {
status["in_disabled_list"] = true
break
}
}
return status
}
// GetAllTableStatus 获取所有表的缓存状态
func (m *CacheConfigManager) GetAllTableStatus() map[string]interface{} {
m.mutex.RLock()
defer m.mutex.RUnlock()
result := map[string]interface{}{
"enabled_tables": m.config.EnabledTables,
"disabled_tables": m.config.DisabledTables,
"config": map[string]interface{}{
"default_ttl": m.config.DefaultTTL,
"table_prefix": m.config.TablePrefix,
"max_cache_size": m.config.MaxCacheSize,
"cache_complex_sql": m.config.CacheComplexSQL,
"enable_stats": m.config.EnableStats,
"enable_warmup": m.config.EnableWarmup,
"penetration_guard": m.config.PenetrationGuard,
"bloom_filter": m.config.BloomFilter,
"auto_invalidate": m.config.AutoInvalidate,
"invalidate_delay": m.config.InvalidateDelay,
},
}
return result
}

View File

@@ -8,6 +8,7 @@ import (
"fmt"
"reflect"
"strings"
"sync"
"time"
"go.uber.org/zap"
@@ -21,6 +22,10 @@ type GormCachePlugin struct {
cache interfaces.CacheService
logger *zap.Logger
config CacheConfig
// 缓存失效去重机制
invalidationQueue map[string]*time.Timer
queueMutex sync.RWMutex
}
// CacheConfig 缓存配置
@@ -58,7 +63,8 @@ func DefaultCacheConfig() CacheConfig {
PenetrationGuard: true,
BloomFilter: false,
AutoInvalidate: true,
InvalidateDelay: 100 * time.Millisecond,
// 增加延迟失效时间,减少频繁的缓存失效操作
InvalidateDelay: 500 * time.Millisecond,
}
}
@@ -73,6 +79,7 @@ func NewGormCachePlugin(cache interfaces.CacheService, logger *zap.Logger, confi
cache: cache,
logger: logger,
config: cfg,
invalidationQueue: make(map[string]*time.Timer),
}
}
@@ -87,12 +94,30 @@ func (p *GormCachePlugin) Initialize(db *gorm.DB) error {
zap.Duration("default_ttl", p.config.DefaultTTL),
zap.Bool("auto_invalidate", p.config.AutoInvalidate),
zap.Bool("penetration_guard", p.config.PenetrationGuard),
zap.Duration("invalidate_delay", p.config.InvalidateDelay),
)
// 注册回调函数
return p.registerCallbacks(db)
}
// Shutdown 关闭插件,清理资源
func (p *GormCachePlugin) Shutdown() {
p.queueMutex.Lock()
defer p.queueMutex.Unlock()
// 停止所有定时器
for table, timer := range p.invalidationQueue {
timer.Stop()
p.logger.Debug("停止缓存失效定时器", zap.String("table", table))
}
// 清空队列
p.invalidationQueue = make(map[string]*time.Timer)
p.logger.Info("GORM缓存插件已关闭")
}
// registerCallbacks 注册GORM回调
func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error {
// Query回调 - 查询时检查缓存
@@ -117,6 +142,7 @@ func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error {
func (p *GormCachePlugin) beforeQuery(db *gorm.DB) {
// 检查是否启用缓存
if !p.shouldCache(db) {
p.logger.Debug("跳过缓存", zap.String("table", db.Statement.Table))
return
}
@@ -147,7 +173,15 @@ func (p *GormCachePlugin) beforeQuery(db *gorm.DB) {
p.updateStats("hit", db.Statement.Table)
}
return
} else {
p.logger.Warn("缓存数据恢复失败,将执行数据库查询",
zap.String("cache_key", cacheKey),
zap.Error(err))
}
} else {
p.logger.Debug("缓存未命中",
zap.String("cache_key", cacheKey),
zap.Error(err))
}
// 缓存未命中,设置标记
@@ -195,7 +229,10 @@ func (p *GormCachePlugin) afterCreate(db *gorm.DB) {
return
}
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
// 只对启用缓存的表执行失效操作
if p.shouldInvalidateTable(db.Statement.Table) {
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
}
// afterUpdate 更新后回调
@@ -204,7 +241,10 @@ func (p *GormCachePlugin) afterUpdate(db *gorm.DB) {
return
}
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
// 只对启用缓存的表执行失效操作
if p.shouldInvalidateTable(db.Statement.Table) {
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
}
// afterDelete 删除后回调
@@ -213,26 +253,64 @@ func (p *GormCachePlugin) afterDelete(db *gorm.DB) {
return
}
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
// 只对启用缓存的表执行失效操作
if p.shouldInvalidateTable(db.Statement.Table) {
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
}
// ================ 缓存管理方法 ================
// shouldInvalidateTable 判断是否应该对表执行缓存失效操作
func (p *GormCachePlugin) shouldInvalidateTable(table string) bool {
// 使用全局缓存配置管理器进行智能决策
if GlobalCacheConfigManager != nil {
return GlobalCacheConfigManager.IsTableCacheEnabled(table)
}
// 如果全局管理器未初始化,使用本地配置
// 检查表是否在禁用列表中
for _, disabledTable := range p.config.DisabledTables {
if disabledTable == table {
return false
}
}
// 如果配置了启用列表,只对启用列表中的表执行失效操作
if len(p.config.EnabledTables) > 0 {
for _, enabledTable := range p.config.EnabledTables {
if enabledTable == table {
return true
}
}
return false
}
// 如果没有配置启用列表,默认对所有表执行失效操作(除了禁用列表中的表)
return true
}
// shouldCache 判断是否应该缓存
func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool {
// 检查是否明确禁用缓存
// 检查是否手动禁用缓存
if value, ok := db.Statement.Get("cache:disabled"); ok && value.(bool) {
return false
}
// 检查是否明确启用缓存
// 检查是否手动启用缓存
if value, ok := db.Statement.Get("cache:enabled"); ok && value.(bool) {
return true
}
// 使用全局缓存配置管理器进行智能决策
if GlobalCacheConfigManager != nil {
return GlobalCacheConfigManager.IsTableCacheEnabled(db.Statement.Table)
}
// 如果全局管理器未初始化,使用本地配置
// 检查表是否在禁用列表中
for _, table := range p.config.DisabledTables {
if table == db.Statement.Table {
for _, disabledTable := range p.config.DisabledTables {
if disabledTable == db.Statement.Table {
return false
}
}
@@ -247,11 +325,7 @@ func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool {
return false
}
// 检查是否为复杂查询
if !p.config.CacheComplexSQL && p.isComplexQuery(db) {
return false
}
// 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表)
return true
}
@@ -292,10 +366,27 @@ func (p *GormCachePlugin) generateCacheKey(db *gorm.DB) string {
// hashSQL 对SQL语句和参数进行hash
func (p *GormCachePlugin) hashSQL(sql string, vars []interface{}) string {
// 将SQL和参数组合
// 修复改进SQL hash生成确保唯一性
combined := sql
for _, v := range vars {
combined += fmt.Sprintf(":%v", v)
// 使用更精确的格式化,避免类型信息丢失
switch val := v.(type) {
case string:
combined += ":" + val
case int, int32, int64:
combined += fmt.Sprintf(":%d", val)
case float32, float64:
combined += fmt.Sprintf(":%f", val)
case bool:
combined += fmt.Sprintf(":%t", val)
default:
// 对于复杂类型使用JSON序列化
if jsonData, err := json.Marshal(v); err == nil {
combined += ":" + string(jsonData)
} else {
combined += fmt.Sprintf(":%v", v)
}
}
}
// 计算MD5 hash
@@ -328,9 +419,22 @@ func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db *
return fmt.Errorf("查询结果为空")
}
// 修复:改进缓存数据保存逻辑
var dataToCache interface{}
// 如果dest是切片需要特殊处理
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr && destValue.Elem().Kind() == reflect.Slice {
// 对于切片,直接使用原始数据
dataToCache = destValue.Elem().Interface()
} else {
// 对于单个对象,也直接使用原始数据
dataToCache = dest
}
// 构建缓存结果
result := CachedResult{
Data: dest,
Data: dataToCache,
RowCount: db.Statement.RowsAffected,
Timestamp: time.Now(),
}
@@ -340,6 +444,10 @@ func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db *
// 保存到缓存
if err := p.cache.Set(ctx, cacheKey, result, ttl); err != nil {
p.logger.Error("保存查询结果到缓存失败",
zap.String("cache_key", cacheKey),
zap.Error(err),
)
return fmt.Errorf("保存到缓存失败: %w", err)
}
@@ -364,25 +472,35 @@ func (p *GormCachePlugin) restoreFromCache(db *gorm.DB, cachedResult *CachedResu
return fmt.Errorf("目标对象必须是指针")
}
// 将缓存数据复制到目标
// 修复:改进缓存数据恢复逻辑
cachedValue := reflect.ValueOf(cachedResult.Data)
if !cachedValue.Type().AssignableTo(destValue.Elem().Type()) {
// 如果类型完全匹配,直接赋值
if cachedValue.Type().AssignableTo(destValue.Elem().Type()) {
destValue.Elem().Set(cachedValue)
} else {
// 尝试JSON转换
jsonData, err := json.Marshal(cachedResult.Data)
if err != nil {
return fmt.Errorf("缓存数据类型不匹配")
p.logger.Error("序列化缓存数据失败", zap.Error(err))
return fmt.Errorf("缓存数据类型转换失败: %w", err)
}
if err := json.Unmarshal(jsonData, db.Statement.Dest); err != nil {
p.logger.Error("反序列化缓存数据失败",
zap.String("json_data", string(jsonData)),
zap.Error(err))
return fmt.Errorf("JSON反序列化失败: %w", err)
}
} else {
destValue.Elem().Set(cachedValue)
}
// 设置影响行数
db.Statement.RowsAffected = cachedResult.RowCount
p.logger.Debug("从缓存恢复数据成功",
zap.Int64("rows", cachedResult.RowCount),
zap.Time("timestamp", cachedResult.Timestamp))
return nil
}
@@ -400,37 +518,79 @@ func (p *GormCachePlugin) getTTL(db *gorm.DB) time.Duration {
// invalidateTableCache 失效表相关缓存
func (p *GormCachePlugin) invalidateTableCache(ctx context.Context, table string) {
if ctx == nil {
ctx = context.Background()
}
// 延迟失效(避免并发问题)
if p.config.InvalidateDelay > 0 {
time.AfterFunc(p.config.InvalidateDelay, func() {
p.doInvalidateTableCache(ctx, table)
})
} else {
p.doInvalidateTableCache(ctx, table)
// 使用去重机制,避免重复的缓存失效操作
p.queueMutex.Lock()
defer p.queueMutex.Unlock()
// 如果已经有相同的失效操作在队列中,取消之前的定时器
if timer, exists := p.invalidationQueue[table]; exists {
timer.Stop()
delete(p.invalidationQueue, table)
}
// 创建独立的上下文,避免受到原始请求上下文的影响
// 设置合理的超时时间,避免缓存失效操作阻塞
cacheCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// 创建新的定时器
timer := time.AfterFunc(p.config.InvalidateDelay, func() {
// 执行缓存失效
p.doInvalidateTableCache(cacheCtx, table)
// 清理定时器引用
p.queueMutex.Lock()
delete(p.invalidationQueue, table)
p.queueMutex.Unlock()
// 取消上下文
cancel()
})
// 将定时器加入队列
p.invalidationQueue[table] = timer
p.logger.Debug("缓存失效操作已加入队列",
zap.String("table", table),
zap.Duration("delay", p.config.InvalidateDelay),
)
}
// doInvalidateTableCache 执行缓存失效
func (p *GormCachePlugin) doInvalidateTableCache(ctx context.Context, table string) {
pattern := fmt.Sprintf("%s:%s:*", p.config.TablePrefix, table)
if err := p.cache.DeletePattern(ctx, pattern); err != nil {
p.logger.Warn("失效表缓存失败",
// 添加重试机制,提高缓存失效的可靠性
maxRetries := 3
for attempt := 1; attempt <= maxRetries; attempt++ {
if err := p.cache.DeletePattern(ctx, pattern); err != nil {
if attempt < maxRetries {
p.logger.Warn("缓存失效失败,准备重试",
zap.String("table", table),
zap.String("pattern", pattern),
zap.Int("attempt", attempt),
zap.Error(err),
)
// 短暂延迟后重试
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
p.logger.Warn("失效表缓存失败",
zap.String("table", table),
zap.String("pattern", pattern),
zap.Int("attempts", maxRetries),
zap.Error(err),
)
return
}
// 成功删除,记录日志并退出
p.logger.Debug("表缓存已失效",
zap.String("table", table),
zap.String("pattern", pattern),
zap.Error(err),
)
return
}
p.logger.Debug("表缓存已失效",
zap.String("table", table),
zap.String("pattern", pattern),
)
}
// updateStats 更新统计信息

View File

@@ -0,0 +1,113 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"errors"
"io"
)
// PKCS7填充
func PKCS7Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
}
// 去除PKCS7填充
func PKCS7UnPadding(origData []byte) ([]byte, error) {
length := len(origData)
if length == 0 {
return nil, errors.New("input data error")
}
unpadding := int(origData[length-1])
if unpadding > length {
return nil, errors.New("unpadding size is invalid")
}
// 检查填充字节是否一致
for i := 0; i < unpadding; i++ {
if origData[length-1-i] != byte(unpadding) {
return nil, errors.New("invalid padding")
}
}
return origData[:(length - unpadding)], nil
}
// AES CBC模式加密Base64传入传出
func AesEncrypt(plainText []byte, key string) (string, error) {
keyBytes, err := hex.DecodeString(key)
if err != nil {
return "", err
}
block, err := aes.NewCipher(keyBytes)
if err != nil {
return "", err
}
blockSize := block.BlockSize()
plainText = PKCS7Padding(plainText, blockSize)
cipherText := make([]byte, blockSize+len(plainText))
iv := cipherText[:blockSize] // 使用前blockSize字节作为IV
_, err = io.ReadFull(rand.Reader, iv)
if err != nil {
return "", err
}
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(cipherText[blockSize:], plainText)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// AES CBC模式解密Base64传入传出
func AesDecrypt(cipherTextBase64 string, key string) ([]byte, error) {
keyBytes, err := hex.DecodeString(key)
if err != nil {
return nil, err
}
cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(keyBytes)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
if len(cipherText) < blockSize {
return nil, errors.New("ciphertext too short")
}
iv := cipherText[:blockSize]
cipherText = cipherText[blockSize:]
if len(cipherText)%blockSize != 0 {
return nil, errors.New("ciphertext is not a multiple of the block size")
}
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(cipherText, cipherText)
plainText, err := PKCS7UnPadding(cipherText)
if err != nil {
return nil, err
}
return plainText, nil
}
// Md5Encrypt 用于对传入的message进行MD5加密
func Md5Encrypt(message string) string {
hash := md5.New()
hash.Write([]byte(message)) // 将字符串转换为字节切片并写入
return hex.EncodeToString(hash.Sum(nil)) // 将哈希值转换为16进制字符串并返回
}

View File

@@ -0,0 +1,63 @@
package crypto
import (
"crypto/rand"
"encoding/hex"
"io"
mathrand "math/rand"
"strconv"
"time"
)
// 生成AES-128密钥的函数符合市面规范
func GenerateSecretKey() (string, error) {
key := make([]byte, 16) // 16字节密钥
_, err := io.ReadFull(rand.Reader, key)
if err != nil {
return "", err
}
return hex.EncodeToString(key), nil
}
func GenerateSecretId() (string, error) {
// 创建一个字节数组,用于存储随机数据
bytes := make([]byte, 8) // 因为每个字节表示两个16进制字符
// 读取随机字节到数组中
_, err := rand.Read(bytes)
if err != nil {
return "", err
}
// 将字节数组转换为16进制字符串
return hex.EncodeToString(bytes), nil
}
// GenerateTransactionID 生成16位数的交易单号
func GenerateTransactionID() string {
length := 16
// 获取当前时间戳
timestamp := time.Now().UnixNano()
// 转换为字符串
timeStr := strconv.FormatInt(timestamp, 10)
// 生成随机数
mathrand.Seed(time.Now().UnixNano())
randomPart := strconv.Itoa(mathrand.Intn(1000000))
// 组合时间戳和随机数
combined := timeStr + randomPart
// 如果长度超出指定值,则截断;如果不够,则填充随机字符
if len(combined) >= length {
return combined[:length]
}
// 如果长度不够填充0
for len(combined) < length {
combined += strconv.Itoa(mathrand.Intn(10)) // 填充随机数
}
return combined
}

View File

@@ -0,0 +1,150 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"encoding/base64"
)
const (
KEY_SIZE = 16 // AES-128, 16 bytes
)
// Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding
func WestDexEncrypt(data, secretKey string) (string, error) {
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
ciphertext, err := aesEncrypt([]byte(data), key)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding
func WestDexDecrypt(encodedData, secretKey string) ([]byte, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
if err != nil {
return nil, err
}
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
plaintext, err := aesDecrypt(ciphertext, key)
if err != nil {
return nil, err
}
return plaintext, nil
}
// generateAESKey generates a key for AES encryption using a SHA-1 based PRNG
func generateAESKey(length int, password []byte) []byte {
h := sha1.New()
h.Write(password)
state := h.Sum(nil)
keyBytes := make([]byte, 0, length/8)
for len(keyBytes) < length/8 {
h := sha1.New()
h.Write(state)
state = h.Sum(nil)
keyBytes = append(keyBytes, state...)
}
return keyBytes[:length/8]
}
// aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding
func aesEncrypt(plaintext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
ciphertext := make([]byte, len(paddedPlaintext))
mode := newECBEncrypter(block)
mode.CryptBlocks(ciphertext, paddedPlaintext)
return ciphertext, nil
}
// aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding
func aesDecrypt(ciphertext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
plaintext := make([]byte, len(ciphertext))
mode := newECBDecrypter(block)
mode.CryptBlocks(plaintext, ciphertext)
return pkcs5Unpadding(plaintext), nil
}
// pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding
func pkcs5Padding(src []byte, blockSize int) []byte {
padding := blockSize - len(src)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
// pkcs5Unpadding removes PKCS5 padding from the input
func pkcs5Unpadding(src []byte) []byte {
length := len(src)
unpadding := int(src[length-1])
return src[:(length - unpadding)]
}
// ECB mode encryption/decryption
type ecb struct {
b cipher.Block
blockSize int
}
func newECB(b cipher.Block) *ecb {
return &ecb{
b: b,
blockSize: b.BlockSize(),
}
}
type ecbEncrypter ecb
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b))
}
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Encrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
type ecbDecrypter ecb
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b))
}
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Decrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}

View File

@@ -8,6 +8,7 @@ import (
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/shared/cache"
"tyapi-server/internal/shared/interfaces"
)
@@ -26,41 +27,112 @@ func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName stri
}
}
// ================ 智能缓存决策方法 ================
// isTableCacheEnabled 检查表是否启用缓存
func (r *CachedBaseRepositoryImpl) isTableCacheEnabled() bool {
// 使用全局缓存配置管理器
if cache.GlobalCacheConfigManager != nil {
return cache.GlobalCacheConfigManager.IsTableCacheEnabled(r.tableName)
}
// 如果全局管理器未初始化,默认启用缓存
r.logger.Warn("全局缓存配置管理器未初始化,默认启用缓存",
zap.String("table", r.tableName))
return true
}
// shouldUseCacheForTable 智能判断是否应该对当前表使用缓存
func (r *CachedBaseRepositoryImpl) shouldUseCacheForTable() bool {
// 检查表是否启用缓存
if !r.isTableCacheEnabled() {
r.logger.Debug("表未启用缓存,跳过缓存操作",
zap.String("table", r.tableName))
return false
}
return true
}
// ================ 智能缓存方法 ================
// GetWithCache 带缓存的单条查询
// GetWithCache 带缓存的单条查询(智能决策)
func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
db := r.GetDB(ctx).
Set("cache:enabled", true).
Set("cache:ttl", ttl)
db := r.GetDB(ctx)
// 智能决策:根据表配置决定是否使用缓存
if r.shouldUseCacheForTable() {
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
r.logger.Debug("执行带缓存查询",
zap.String("table", r.tableName),
zap.Duration("ttl", ttl),
zap.String("where", where))
} else {
db = db.Set("cache:disabled", true)
r.logger.Debug("执行无缓存查询",
zap.String("table", r.tableName),
zap.String("where", where))
}
return db.Where(where, args...).First(dest).Error
}
// FindWithCache 带缓存的多条查询
// FindWithCache 带缓存的多条查询(智能决策)
func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
db := r.GetDB(ctx).
Set("cache:enabled", true).
Set("cache:ttl", ttl)
db := r.GetDB(ctx)
// 智能决策:根据表配置决定是否使用缓存
if r.shouldUseCacheForTable() {
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
r.logger.Debug("执行带缓存批量查询",
zap.String("table", r.tableName),
zap.Duration("ttl", ttl),
zap.String("where", where))
} else {
db = db.Set("cache:disabled", true)
r.logger.Debug("执行无缓存批量查询",
zap.String("table", r.tableName),
zap.String("where", where))
}
return db.Where(where, args...).Find(dest).Error
}
// CountWithCache 带缓存的计数查询
// CountWithCache 带缓存的计数查询(智能决策)
func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error {
db := r.GetDB(ctx).
Set("cache:enabled", true).
Set("cache:ttl", ttl).
Model(entity)
db := r.GetDB(ctx).Model(entity)
// 智能决策:根据表配置决定是否使用缓存
if r.shouldUseCacheForTable() {
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
r.logger.Debug("执行带缓存计数查询",
zap.String("table", r.tableName),
zap.Duration("ttl", ttl),
zap.String("where", where))
} else {
db = db.Set("cache:disabled", true)
r.logger.Debug("执行无缓存计数查询",
zap.String("table", r.tableName),
zap.String("where", where))
}
return db.Where(where, args...).Count(count).Error
}
// ListWithCache 带缓存的列表查询
// ListWithCache 带缓存的列表查询(智能决策)
func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error {
db := r.GetDB(ctx).
Set("cache:enabled", true).
Set("cache:ttl", ttl)
db := r.GetDB(ctx)
// 智能决策:根据表配置决定是否使用缓存
if r.shouldUseCacheForTable() {
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
r.logger.Debug("执行带缓存列表查询",
zap.String("table", r.tableName),
zap.Duration("ttl", ttl))
} else {
db = db.Set("cache:disabled", true)
r.logger.Debug("执行无缓存列表查询",
zap.String("table", r.tableName))
}
// 应用where条件
if options.Where != "" {
@@ -90,12 +162,12 @@ func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest inter
// CacheListOptions 缓存列表查询选项
type CacheListOptions struct {
Where string `json:"where"`
Args []interface{} `json:"args"`
Order string `json:"order"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Preloads []string `json:"preloads"`
Where string `json:"where"`
Args []interface{} `json:"args"`
Order string `json:"order"`
Limit int `json:"limit"`
Offset int `json:"offset"`
Preloads []string `json:"preloads"`
}
// ================ 缓存控制方法 ================
@@ -142,6 +214,10 @@ func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl {
// SmartGetByID 智能ID查询自动缓存
func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error {
r.logger.Debug("执行智能ID查询",
zap.String("table", r.tableName),
zap.String("id", id))
return r.GetWithCache(ctx, dest, 30*time.Minute, "id = ?", id)
}
@@ -151,7 +227,7 @@ func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest int
if len(ttl) > 0 {
cacheTTL = ttl[0]
}
return r.GetWithCache(ctx, dest, cacheTTL, field+" = ?", value)
}
@@ -161,6 +237,12 @@ func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface
cacheTTL := r.calculateCacheTTL(options)
useCache := r.shouldUseCache(options)
r.logger.Debug("执行智能列表查询",
zap.String("table", r.tableName),
zap.Bool("use_cache", useCache),
zap.Duration("cache_ttl", cacheTTL))
// 修复:确保缓存标记在查询前正确设置
db := r.GetDB(ctx)
if useCache {
db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL)
@@ -254,14 +336,14 @@ func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions
// WarmupCommonQueries 预热常用查询
func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error {
r.logger.Info("开始预热缓存",
r.logger.Info("开始预热缓存",
zap.String("table", r.tableName),
zap.Int("queries", len(queries)),
)
for _, query := range queries {
if err := r.executeWarmupQuery(ctx, query); err != nil {
r.logger.Warn("缓存预热失败",
r.logger.Warn("缓存预热失败",
zap.String("query", query.Name),
zap.Error(err),
)
@@ -273,11 +355,11 @@ func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, quer
// WarmupQuery 预热查询定义
type WarmupQuery struct {
Name string `json:"name"`
SQL string `json:"sql"`
Args []interface{} `json:"args"`
TTL time.Duration `json:"ttl"`
Dest interface{} `json:"dest"`
Name string `json:"name"`
SQL string `json:"sql"`
Args []interface{} `json:"args"`
TTL time.Duration `json:"ttl"`
Dest interface{} `json:"dest"`
}
// executeWarmupQuery 执行预热查询
@@ -313,7 +395,7 @@ func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interfa
if err := r.CreateEntity(ctx, newEntity); err != nil {
return err
}
// 将新创建的实体复制到dest
// 这里需要反射或其他方式复制
return nil
@@ -333,7 +415,7 @@ func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids []
// RefreshCache 刷新缓存
func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error {
r.logger.Info("刷新缓存",
r.logger.Info("刷新缓存",
zap.String("table", r.tableName),
zap.String("pattern", pattern),
)
@@ -348,9 +430,9 @@ func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern str
// GetCacheInfo 获取缓存信息
func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} {
return map[string]interface{}{
"table_name": r.tableName,
"cache_enabled": true,
"default_ttl": "30m",
"table_name": r.tableName,
"cache_enabled": true,
"default_ttl": "30m",
"cache_patterns": []string{
fmt.Sprintf("gorm_cache:%s:*", r.tableName),
},
@@ -359,9 +441,9 @@ func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} {
// LogCacheOperation 记录缓存操作
func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) {
r.logger.Debug("缓存操作",
r.logger.Debug("缓存操作",
zap.String("table", r.tableName),
zap.String("operation", operation),
zap.String("details", details),
)
}
}

View File

@@ -2,18 +2,43 @@ package esign
import "fmt"
type EsignContractConfig struct {
Name string `json:"name" yaml:"name"`
ExpireDays int `json:"expireDays" yaml:"expire_days"`
RetryCount int `json:"retryCount" yaml:"retry_count"`
}
type EsignAuthConfig struct {
OrgAuthModes []string `json:"orgAuthModes" yaml:"org_auth_modes"`
DefaultAuthMode string `json:"defaultAuthMode" yaml:"default_auth_mode"`
PsnAuthModes []string `json:"psnAuthModes" yaml:"psn_auth_modes"`
WillingnessAuthModes []string `json:"willingnessAuthModes" yaml:"willingness_auth_modes"`
RedirectUrl string `json:"redirectUrl" yaml:"redirect_url"`
}
type EsignSignConfig struct {
AutoFinish bool `json:"autoFinish" yaml:"auto_finish"`
SignFieldStyle int `json:"signFieldStyle" yaml:"sign_field_style"`
ClientType string `json:"clientType" yaml:"client_type"`
RedirectUrl string `json:"redirectUrl" yaml:"redirect_url"`
}
// Config e签宝服务配置结构体
// 包含应用ID、密钥、服务器URL和模板ID等基础配置信息
// 新增Contract、Auth、Sign配置
type Config struct {
AppID string `json:"appId"` // 应用ID
AppSecret string `json:"appSecret"` // 应用密钥
ServerURL string `json:"serverUrl"` // 服务器URL
TemplateID string `json:"templateId"` // 模板ID
AppID string `json:"appId" yaml:"app_id"`
AppSecret string `json:"appSecret" yaml:"app_secret"`
ServerURL string `json:"serverUrl" yaml:"server_url"`
TemplateID string `json:"templateId" yaml:"template_id"`
Contract *EsignContractConfig `json:"contract" yaml:"contract"`
Auth *EsignAuthConfig `json:"auth" yaml:"auth"`
Sign *EsignSignConfig `json:"sign" yaml:"sign"`
}
// NewConfig 创建新的配置实例
// 提供配置验证和默认值设置
func NewConfig(appID, appSecret, serverURL, templateID string) (*Config, error) {
func NewConfig(appID, appSecret, serverURL, templateID string, contract *EsignContractConfig, auth *EsignAuthConfig, sign *EsignSignConfig) (*Config, error) {
if appID == "" {
return nil, fmt.Errorf("应用ID不能为空")
}
@@ -26,12 +51,15 @@ func NewConfig(appID, appSecret, serverURL, templateID string) (*Config, error)
if templateID == "" {
return nil, fmt.Errorf("模板ID不能为空")
}
return &Config{
AppID: appID,
AppSecret: appSecret,
ServerURL: serverURL,
TemplateID: templateID,
Contract: contract,
Auth: auth,
Sign: sign,
}, nil
}
@@ -60,7 +88,7 @@ const (
AuthModeBank = "PSN_BANK" // 银行卡认证
// 意愿认证模式
WillingnessAuthSMS = "CODE_SMS" // 短信验证码
WillingnessAuthSMS = "CODE_SMS" // 短信验证码
WillingnessAuthEmail = "CODE_EMAIL" // 邮箱验证码
// 证件类型常量
@@ -69,7 +97,7 @@ const (
// 签署区样式常量
SignFieldStyleNormal = 1 // 普通签章
SignFieldStyleSeam = 2 // 骑缝签章
SignFieldStyleSeam = 2 // 骑缝签章
// 签署人类型常量
SignerTypePerson = 0 // 个人
@@ -80,4 +108,4 @@ const (
// 客户端类型常量
ClientTypeAll = "ALL" // 所有客户端
)
)

View File

@@ -13,6 +13,24 @@ func Example() {
"your_app_secret",
"https://smlopenapi.esign.cn",
"your_template_id",
&EsignContractConfig{
Name: "测试合同",
ExpireDays: 30,
RetryCount: 3,
},
&EsignAuthConfig{
OrgAuthModes: []string{"ORG"},
DefaultAuthMode: "ORG",
PsnAuthModes: []string{"PSN"},
WillingnessAuthModes: []string{"WILLINGNESS"},
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
},
&EsignSignConfig{
AutoFinish: true,
SignFieldStyle: 1,
ClientType: "ALL",
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
},
)
if err != nil {
log.Fatal("配置创建失败:", err)
@@ -122,7 +140,22 @@ func Example() {
// ExampleBasicUsage 基础用法示例
func ExampleBasicUsage() {
// 最简单的用法 - 一行代码完成合同签署
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id")
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{
Name: "测试合同",
ExpireDays: 30,
RetryCount: 3,
}, &EsignAuthConfig{
OrgAuthModes: []string{"ORG"},
DefaultAuthMode: "ORG",
PsnAuthModes: []string{"PSN"},
WillingnessAuthModes: []string{"WILLINGNESS"},
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
}, &EsignSignConfig{
AutoFinish: true,
SignFieldStyle: 1,
ClientType: "ALL",
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
})
client := NewClient(config)
// 快速合同签署
@@ -143,7 +176,22 @@ func ExampleBasicUsage() {
// ExampleWithCustomData 自定义数据示例
func ExampleWithCustomData() {
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id")
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{
Name: "测试合同",
ExpireDays: 30,
RetryCount: 3,
}, &EsignAuthConfig{
OrgAuthModes: []string{"ORG"},
DefaultAuthMode: "ORG",
PsnAuthModes: []string{"PSN"},
WillingnessAuthModes: []string{"WILLINGNESS"},
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
}, &EsignSignConfig{
AutoFinish: true,
SignFieldStyle: 1,
ClientType: "ALL",
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
})
client := NewClient(config)
// 使用自定义模板数据
@@ -171,7 +219,22 @@ func ExampleWithCustomData() {
// ExampleEnterpriseAuth 企业认证示例
func ExampleEnterpriseAuth() {
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id")
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{
Name: "测试合同",
ExpireDays: 30,
RetryCount: 3,
}, &EsignAuthConfig{
OrgAuthModes: []string{"ORG"},
DefaultAuthMode: "ORG",
PsnAuthModes: []string{"PSN"},
WillingnessAuthModes: []string{"WILLINGNESS"},
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
}, &EsignSignConfig{
AutoFinish: true,
SignFieldStyle: 1,
ClientType: "ALL",
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
})
client := NewClient(config)
// 企业认证

View File

@@ -34,8 +34,8 @@ func (s *FileOpsService) UpdateConfig(config *Config) {
func (s *FileOpsService) DownloadSignedFile(signFlowId string) (*DownloadSignedFileResponse, error) {
fmt.Println("开始下载已签署文件及附属材料...")
// 发送API请求
urlPath := fmt.Sprintf("/v3/sign-flow/%s/attachments", signFlowId)
// 按照最新e签宝文档接口路径应为 /v3/sign-flow/{signFlowId}/file-download-url
urlPath := fmt.Sprintf("/v3/sign-flow/%s/file-download-url", signFlowId)
responseBody, err := s.httpClient.Request("GET", urlPath, nil)
if err != nil {
return nil, fmt.Errorf("下载已签署文件失败: %v", err)

View File

@@ -44,8 +44,8 @@ func (h *HTTPClient) UpdateConfig(config *Config) {
func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error) {
// 生成签名所需参数
timestamp := getCurrentTimestamp()
nonce := generateNonce()
date := getCurrentDate()
// date := getCurrentDate()
date := ""
// 计算Content-MD5
contentMD5 := ""
@@ -53,14 +53,14 @@ func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error
contentMD5 = getContentMD5(body)
}
// 根据Java示例Headers为空字符串
headers := ""
// 生成签名
// 生成签名用原始urlPath
signature := generateSignature(h.config.AppSecret, method, "*/*", contentMD5, "application/json", date, headers, urlPath)
// 创建HTTP请求
url := h.config.ServerURL + urlPath
// 实际请求url用encode后的urlPath
encodedURLPath := encodeURLQueryParams(urlPath)
url := h.config.ServerURL + encodedURLPath
req, err := http.NewRequest(method, url, bytes.NewBuffer(body))
if err != nil {
return nil, fmt.Errorf("创建HTTP请求失败: %v", err)
@@ -69,12 +69,11 @@ func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Content-MD5", contentMD5)
req.Header.Set("Date", date)
// req.Header.Set("Date", date)
req.Header.Set("Accept", "*/*")
req.Header.Set("X-Tsign-Open-App-Id", h.config.AppID)
req.Header.Set("X-Tsign-Open-Auth-Mode", "Signature")
req.Header.Set("X-Tsign-Open-Ca-Timestamp", timestamp)
req.Header.Set("X-Tsign-Open-Nonce", nonce)
req.Header.Set("X-Tsign-Open-Ca-Signature", signature)
// 发送请求
@@ -197,3 +196,24 @@ func sortURLQueryParams(urlPath string) string {
}
return basePath
}
// encodeURLQueryParams 对urlPath中的query参数值进行encode
func encodeURLQueryParams(urlPath string) string {
if !strings.Contains(urlPath, "?") {
return urlPath
}
parts := strings.SplitN(urlPath, "?", 2)
basePath := parts[0]
queryString := parts[1]
values, err := url.ParseQuery(queryString)
if err != nil {
return urlPath
}
var encodedPairs []string
for key, vals := range values {
for _, val := range vals {
encodedPairs = append(encodedPairs, key+"="+url.QueryEscape(val))
}
}
return basePath + "?" + strings.Join(encodedPairs, "&")
}

View File

@@ -66,6 +66,9 @@ func (s *OrgAuthService) GetAuthURL(req *OrgAuthRequest) (string, string, string
},
},
ClientType: ClientTypeAll,
RedirectConfig: &RedirectConfig{
RedirectUrl: s.config.Auth.RedirectUrl,
},
}
// 序列化请求数据

View File

@@ -43,7 +43,7 @@ func (s *SignFlowService) Create(req *CreateSignFlowRequest) (string, error) {
Docs: []DocInfo{
{
FileId: req.FileID,
FileName: "天远数据API合作协议.pdf",
FileName: s.config.Contract.Name,
},
},
SignFlowConfig: s.buildSignFlowConfig(),
@@ -141,9 +141,9 @@ func (s *SignFlowService) buildPartyASigner(fileID string) SignerInfo {
AutoSign: true,
SignFieldStyle: SignFieldStyleNormal,
SignFieldPosition: &SignFieldPosition{
PositionPage: "1",
PositionPage: "8",
PositionX: 200,
PositionY: 200,
PositionY: 430,
},
},
},
@@ -188,9 +188,9 @@ func (s *SignFlowService) buildPartyBSigner(fileID, signerAccount, signerName, t
AutoSign: false,
SignFieldStyle: SignFieldStyleNormal,
SignFieldPosition: &SignFieldPosition{
PositionPage: "1",
PositionX: 458,
PositionY: 200,
PositionPage: "8",
PositionX: 450,
PositionY: 430,
},
},
},
@@ -201,9 +201,9 @@ func (s *SignFlowService) buildPartyBSigner(fileID, signerAccount, signerName, t
// buildSignFlowConfig 构建签署流程配置
func (s *SignFlowService) buildSignFlowConfig() SignFlowConfig {
return SignFlowConfig{
SignFlowTitle: "天远数据API合作协议签署",
SignFlowExpireTime: calculateExpireTime(7), // 7天后过期
AutoFinish: true, // 所有签署方完成后自动完结
SignFlowTitle: s.config.Contract.Name,
SignFlowExpireTime: calculateExpireTime(s.config.Contract.ExpireDays),
AutoFinish: s.config.Sign.AutoFinish,
AuthConfig: &AuthConfig{
PsnAvailableAuthModes: []string{AuthModeMobile3},
WillingnessAuthModes: []string{WillingnessAuthSMS},
@@ -211,5 +211,8 @@ func (s *SignFlowService) buildSignFlowConfig() SignFlowConfig {
ContractConfig: &ContractConfig{
AllowToRescind: false,
},
RedirectConfig: &RedirectConfig{
RedirectUrl: s.config.Sign.RedirectUrl,
},
}
}

View File

@@ -36,7 +36,7 @@ func (s *TemplateService) Fill(components []Component) (*FillTemplate, error) {
fmt.Println("开始填写模板生成文件...")
// 生成带时间戳的文件名
fileName := generateFileName("天远数据API合作协议", "pdf")
fileName := generateFileName(s.config.Contract.Name, "pdf")
// 构建请求数据
requestData := FillTemplateRequest{

View File

@@ -5,10 +5,40 @@ import (
"crypto/md5"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
"time"
)
// appendSignDataString 拼接待签名字符串
func appendSignDataString(httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters string) string {
if accept == "" {
accept = "*/*"
}
if contentType == "" {
contentType = "application/json; charset=UTF-8"
}
// 前四项
signStr := httpMethod + "\n" + accept + "\n" + contentMD5 + "\n" + contentType + "\n"
// 处理 date
if date == "" {
signStr += "\n"
} else {
signStr += date + "\n"
}
// 处理 headers
if headers == "" {
signStr += pathAndParameters
} else {
signStr += headers + "\n" + pathAndParameters
}
return signStr
}
// generateSignature 生成e签宝API请求签名
// 使用HMAC-SHA256算法对请求参数进行签名
//
@@ -24,8 +54,8 @@ import (
//
// 返回: Base64编码的签名字符串
func generateSignature(appSecret, httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters string) string {
// 构建待签名字符串按照e签宝API规范拼接
signStr := httpMethod + "\n" + accept + "\n" + contentMD5 + "\n" + contentType + "\n" + date + "\n" + headers + pathAndParameters
// 构建待签名字符串按照e签宝API规范拼接兼容Python实现细节
signStr := appendSignDataString(httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters)
// 使用HMAC-SHA256计算签名
h := hmac.New(sha256.New, []byte(appSecret))
@@ -102,3 +132,66 @@ func generateFileName(baseName, extension string) string {
func calculateExpireTime(days int) int64 {
return time.Now().AddDate(0, 0, days).UnixMilli()
}
// verifySignature 验证e签宝回调签名
func VerifySignature(callbackData interface{}, headers map[string]string, queryParams map[string]string, appSecret string) error {
// 1. 获取签名相关参数
signature, ok := headers["X-Tsign-Open-Signature"]
if !ok {
return fmt.Errorf("缺少签名头: X-Tsign-Open-Signature")
}
timestamp, ok := headers["X-Tsign-Open-Timestamp"]
if !ok {
return fmt.Errorf("缺少时间戳头: X-Tsign-Open-Timestamp")
}
// 2. 构建查询参数字符串
var queryKeys []string
for key := range queryParams {
queryKeys = append(queryKeys, key)
}
sort.Strings(queryKeys) // 按ASCII码升序排序
var queryValues []string
for _, key := range queryKeys {
queryValues = append(queryValues, queryParams[key])
}
queryString := strings.Join(queryValues, "")
// 3. 获取请求体数据
bodyData, err := getRequestBodyString(callbackData)
if err != nil {
return fmt.Errorf("获取请求体数据失败: %w", err)
}
// 4. 构建验签数据
data := timestamp + queryString + bodyData
// 5. 计算签名
expectedSignature := calculateSignature(data, appSecret)
// 6. 比较签名
if strings.ToLower(expectedSignature) != strings.ToLower(signature) {
return fmt.Errorf("签名验证失败")
}
return nil
}
// calculateSignature 计算HMAC-SHA256签名
func calculateSignature(data, secret string) string {
h := hmac.New(sha256.New, []byte(secret))
h.Write([]byte(data))
return strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
}
// getRequestBodyString 获取请求体字符串
func getRequestBodyString(callbackData interface{}) (string, error) {
// 将map转换为JSON字符串
jsonBytes, err := json.Marshal(callbackData)
if err != nil {
return "", fmt.Errorf("JSON序列化失败: %w", err)
}
return string(jsonBytes), nil
}

View File

@@ -5,6 +5,7 @@ import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
// HTTPHandler HTTP处理器接口
@@ -95,6 +96,11 @@ type RequestValidator interface {
// 绑定和验证
BindAndValidate(c *gin.Context, dto interface{}) error
// 业务逻辑验证方法
GetValidator() *validator.Validate
ValidateValue(field interface{}, tag string) error
ValidateStruct(s interface{}) error
}
// PaginationMeta 分页元数据

View File

@@ -0,0 +1,50 @@
package middleware
import (
"tyapi-server/internal/config"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ApiAuthMiddleware API认证中间件
type ApiAuthMiddleware struct {
config *config.Config
logger *zap.Logger
responseBuilder interfaces.ResponseBuilder
}
// NewApiAuthMiddleware 创建API认证中间件
func NewApiAuthMiddleware(cfg *config.Config, logger *zap.Logger, responseBuilder interfaces.ResponseBuilder) *ApiAuthMiddleware {
return &ApiAuthMiddleware{
config: cfg,
logger: logger,
responseBuilder: responseBuilder,
}
}
// GetName 返回中间件名称
func (m *ApiAuthMiddleware) GetName() string {
return "api_auth"
}
// GetPriority 返回中间件优先级
func (m *ApiAuthMiddleware) GetPriority() int {
return 60 // 中等优先级,在日志之后,业务处理之前
}
// Handle 返回中间件处理函数
func (m *ApiAuthMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取客户端IP地址并存入上下文
clientIP := c.ClientIP()
c.Set("client_ip", clientIP)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *ApiAuthMiddleware) IsGlobal() bool {
return false
}

View File

@@ -0,0 +1,72 @@
package middleware
import (
"strings"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// DomainAuthMiddleware 域名认证中间件
type DomainAuthMiddleware struct {
config *config.Config
logger *zap.Logger
}
// NewDomainAuthMiddleware 创建域名认证中间件
func NewDomainAuthMiddleware(cfg *config.Config, logger *zap.Logger) *DomainAuthMiddleware {
return &DomainAuthMiddleware{
config: cfg,
logger: logger,
}
}
// GetName 返回中间件名称
func (m *DomainAuthMiddleware) GetName() string {
return "domain_auth"
}
// GetPriority 返回中间件优先级
func (m *DomainAuthMiddleware) GetPriority() int {
return 100
}
// Handle 返回中间件处理函数
func (m *DomainAuthMiddleware) Handle(domain string) gin.HandlerFunc {
return func(c *gin.Context) {
// 开发环境下跳过外部验证
if m.config.App.IsDevelopment() {
m.logger.Info("开发环境:跳过域名验证",
zap.String("domain", domain))
c.Next()
}
if domain == "" {
domain = m.config.API.Domain
}
host := c.Request.Host
// 移除端口部分
if idx := strings.Index(host, ":"); idx != -1 {
host = host[:idx]
}
if host == domain {
// 设置域名匹配标记
c.Set("domainMatched", domain)
c.Next()
} else {
// 不匹配域名,跳过当前组处理,继续执行其他路由
c.Abort()
return
}
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *DomainAuthMiddleware) IsGlobal() bool {
return false
}

View File

@@ -0,0 +1,282 @@
package payment
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net/http"
"strconv"
"sync/atomic"
"time"
"github.com/shopspring/decimal"
"github.com/smartwalle/alipay/v3"
)
type AlipayConfig struct {
AppID string
PrivateKey string
AlipayPublicKey string
IsProduction bool
NotifyUrl string
ReturnURL string // 同步回调地址
}
type AliPayService struct {
config AlipayConfig
AlipayClient *alipay.Client
}
// NewAliPayService 是一个构造函数,用于初始化 AliPayService
func NewAliPayService(config AlipayConfig) *AliPayService {
client, err := alipay.New(config.AppID, config.PrivateKey, config.IsProduction)
if err != nil {
panic(fmt.Sprintf("创建支付宝客户端失败: %v", err))
}
// 加载支付宝公钥
err = client.LoadAliPayPublicKey(config.AlipayPublicKey)
if err != nil {
panic(fmt.Sprintf("加载支付宝公钥失败: %v", err))
}
return &AliPayService{
config: config,
AlipayClient: client,
}
}
func (a *AliPayService) CreateAlipayAppOrder(amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
client := a.AlipayClient
totalAmount := amount.StringFixed(2) // 保留2位小数
// 构造移动支付请求
p := alipay.TradeAppPay{
Trade: alipay.Trade{
Subject: subject,
OutTradeNo: outTradeNo,
TotalAmount: totalAmount,
ProductCode: "QUICK_MSECURITY_PAY", // 移动端支付专用代码
NotifyURL: a.config.NotifyUrl, // 异步回调通知地址
},
}
// 获取APP支付字符串这里会签名
payStr, err := client.TradeAppPay(p)
if err != nil {
return "", fmt.Errorf("创建支付宝订单失败: %v", err)
}
return payStr, nil
}
// CreateAlipayH5Order 创建支付宝H5支付订单
func (a *AliPayService) CreateAlipayH5Order(amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
client := a.AlipayClient
totalAmount := amount.StringFixed(2) // 保留2位小数
// 构造H5支付请求
p := alipay.TradeWapPay{
Trade: alipay.Trade{
Subject: subject,
OutTradeNo: outTradeNo,
TotalAmount: totalAmount,
ProductCode: "QUICK_WAP_PAY", // H5支付专用产品码
NotifyURL: a.config.NotifyUrl, // 异步回调通知地址
ReturnURL: a.config.ReturnURL,
},
}
// 获取H5支付请求字符串这里会签名
payUrl, err := client.TradeWapPay(p)
if err != nil {
return "", fmt.Errorf("创建支付宝H5订单失败: %v", err)
}
return payUrl.String(), nil
}
// CreateAlipayPCOrder 创建支付宝PC端支付订单
func (a *AliPayService) CreateAlipayPCOrder(amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
client := a.AlipayClient
totalAmount := amount.StringFixed(2) // 保留2位小数
// 构造PC端支付请求
p := alipay.TradePagePay{
Trade: alipay.Trade{
Subject: subject,
OutTradeNo: outTradeNo,
TotalAmount: totalAmount,
ProductCode: "FAST_INSTANT_TRADE_PAY", // PC端支付专用产品码
NotifyURL: a.config.NotifyUrl, // 异步回调通知地址
ReturnURL: a.config.ReturnURL, // 同步回调地址
},
}
// 获取PC端支付URL这里会签名
payUrl, err := client.TradePagePay(p)
if err != nil {
return "", fmt.Errorf("创建支付宝PC端订单失败: %v", err)
}
return payUrl.String(), nil
}
// CreateAlipayOrder 根据平台类型创建支付宝支付订单
func (a *AliPayService) CreateAlipayOrder(ctx context.Context, platform string, amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
switch platform {
case "app":
// 调用App支付的创建方法
return a.CreateAlipayAppOrder(amount, subject, outTradeNo)
case "h5":
// 调用H5支付的创建方法并传入 returnUrl
return a.CreateAlipayH5Order(amount, subject, outTradeNo)
case "pc":
// 调用PC端支付的创建方法
return a.CreateAlipayPCOrder(amount, subject, outTradeNo)
default:
return "", fmt.Errorf("不支持的支付平台: %s", platform)
}
}
// AliRefund 发起支付宝退款
func (a *AliPayService) AliRefund(ctx context.Context, outTradeNo string, refundAmount decimal.Decimal) (*alipay.TradeRefundRsp, error) {
refund := alipay.TradeRefund{
OutTradeNo: outTradeNo,
RefundAmount: refundAmount.StringFixed(2), // 保留2位小数
OutRequestNo: fmt.Sprintf("%s-refund", outTradeNo),
}
// 发起退款请求
refundResp, err := a.AlipayClient.TradeRefund(ctx, refund)
if err != nil {
return nil, fmt.Errorf("支付宝退款请求错误:%v", err)
}
return refundResp, nil
}
// HandleAliPaymentNotification 支付宝支付回调
func (a *AliPayService) HandleAliPaymentNotification(r *http.Request) (*alipay.Notification, error) {
// 解析表单
err := r.ParseForm()
if err != nil {
return nil, fmt.Errorf("解析请求表单失败:%v", err)
}
// 解析并验证通知DecodeNotification 会自动验证签名
notification, err := a.AlipayClient.DecodeNotification(r.Form)
if err != nil {
return nil, fmt.Errorf("验证签名失败: %v", err)
}
return notification, nil
}
func (a *AliPayService) IsAlipayPaymentSuccess(notification *alipay.Notification) bool {
return notification.TradeStatus == alipay.TradeStatusSuccess
}
func (a *AliPayService) QueryOrderStatus(ctx context.Context, outTradeNo string) (*alipay.TradeQueryRsp, error) {
queryRequest := alipay.TradeQuery{
OutTradeNo: outTradeNo,
}
// 发起查询请求
resp, err := a.AlipayClient.TradeQuery(ctx, queryRequest)
if err != nil {
return nil, fmt.Errorf("查询支付宝订单失败: %v", err)
}
// 返回交易状态
if resp.IsSuccess() {
return resp, nil
}
return nil, fmt.Errorf("查询支付宝订单失败: %v", resp.SubMsg)
}
// 添加全局原子计数器
var alipayOrderCounter uint32 = 0
// GenerateOutTradeNo 生成唯一订单号的函数 - 优化版本
func (a *AliPayService) GenerateOutTradeNo() string {
// 获取当前时间戳(毫秒级)
timestamp := time.Now().UnixMilli()
timeStr := strconv.FormatInt(timestamp, 10)
// 原子递增计数器
counter := atomic.AddUint32(&alipayOrderCounter, 1)
// 生成4字节真随机数
randomBytes := make([]byte, 4)
_, err := rand.Read(randomBytes)
if err != nil {
// 如果随机数生成失败,回退到使用时间纳秒数据
randomBytes = []byte(strconv.FormatInt(time.Now().UnixNano()%1000000, 16))
}
randomHex := hex.EncodeToString(randomBytes)
// 组合所有部分: 前缀 + 时间戳 + 计数器 + 随机数
orderNo := fmt.Sprintf("%s%06x%s", timeStr[:10], counter%0xFFFFFF, randomHex[:6])
// 确保长度不超过32字符大多数支付平台的限制
if len(orderNo) > 32 {
orderNo = orderNo[:32]
}
return orderNo
}
// AliTransfer 支付宝单笔转账到支付宝账户(提现功能)
func (a *AliPayService) AliTransfer(
ctx context.Context,
payeeAccount string, // 收款方支付宝账户
payeeName string, // 收款方姓名
amount decimal.Decimal, // 转账金额
remark string, // 转账备注
outBizNo string, // 商户转账唯一订单号可使用GenerateOutTradeNo生成
) (*alipay.FundTransUniTransferRsp, error) {
// 参数校验
if payeeAccount == "" {
return nil, fmt.Errorf("收款账户不能为空")
}
if amount.LessThanOrEqual(decimal.Zero) {
return nil, fmt.Errorf("转账金额必须大于0")
}
// 构造转账请求
req := alipay.FundTransUniTransfer{
OutBizNo: outBizNo,
TransAmount: amount.StringFixed(2), // 保留2位小数
ProductCode: "TRANS_ACCOUNT_NO_PWD", // 单笔无密转账到支付宝账户
BizScene: "DIRECT_TRANSFER", // 单笔转账
OrderTitle: "账户提现", // 转账标题
Remark: remark,
PayeeInfo: &alipay.PayeeInfo{
Identity: payeeAccount,
IdentityType: "ALIPAY_LOGON_ID", // 根据账户类型选择:
Name: payeeName,
// ALIPAY_USER_ID/ALIPAY_LOGON_ID
},
}
// 执行转账请求
transferRsp, err := a.AlipayClient.FundTransUniTransfer(ctx, req)
if err != nil {
return nil, fmt.Errorf("支付宝转账请求失败: %v", err)
}
return transferRsp, nil
}
func (a *AliPayService) QueryTransferStatus(
ctx context.Context,
outBizNo string,
) (*alipay.FundTransOrderQueryRsp, error) {
req := alipay.FundTransOrderQuery{
OutBizNo: outBizNo,
}
response, err := a.AlipayClient.FundTransOrderQuery(ctx, req)
if err != nil {
return nil, fmt.Errorf("支付宝接口调用失败: %v", err)
}
// 处理响应
if response.Code.IsFailure() {
return nil, fmt.Errorf("支付宝返回错误: %s-%s", response.Code, response.Msg)
}
return response, nil
}

View File

@@ -0,0 +1,246 @@
# AuthDate 授权日期验证器
## 概述
`authDate` 是一个自定义验证器,用于验证授权日期格式和有效性。该验证器确保日期格式正确,并且日期范围必须包括今天。
## 验证规则
### 1. 格式要求
- 必须为 `YYYYMMDD-YYYYMMDD` 格式
- 两个日期之间用连字符 `-` 分隔
- 每个日期必须是8位数字
### 2. 日期有效性
- 开始日期不能晚于结束日期
- 日期范围必须包括今天(如果两个日期都是今天也行)
- 支持闰年验证
- 验证月份和日期的有效性
### 3. 业务逻辑
- 空值由 `required` 标签处理,本验证器返回 `true`
- 日期范围必须覆盖今天,确保授权在有效期内
## 使用示例
### 在 DTO 中使用
```go
type FLXG0V4BReq struct {
Name string `json:"name" validate:"required,name"`
IDCard string `json:"id_card" validate:"required,idCard"`
AuthDate string `json:"auth_date" validate:"required,authDate"`
}
```
### 在结构体中使用
```go
type AuthorizationRequest struct {
UserID string `json:"user_id" validate:"required"`
AuthDate string `json:"auth_date" validate:"required,authDate"`
Scope string `json:"scope" validate:"required"`
}
```
## 有效示例
### ✅ 有效的日期范围
```json
{
"auth_date": "20240101-20240131" // 1月1日到1月31日如果今天是1月15日
}
```
```json
{
"auth_date": "20240115-20240115" // 今天到今天
}
```
```json
{
"auth_date": "20240110-20240120" // 昨天到明天如果今天是1月15日
}
```
```json
{
"auth_date": "20240101-20240201" // 上个月到下个月如果今天是1月15日
}
```
### ❌ 无效的日期范围
```json
{
"auth_date": "20240116-20240120" // 明天到后天(不包括今天)
}
```
```json
{
"auth_date": "20240101-20240114" // 上个月到昨天(不包括今天)
}
```
```json
{
"auth_date": "20240131-20240101" // 开始日期晚于结束日期
}
```
```json
{
"auth_date": "20240101-2024013A" // 非数字字符
}
```
```json
{
"auth_date": "202401-20240131" // 日期长度不对
}
```
```json
{
"auth_date": "2024010120240131" // 缺少连字符
}
```
```json
{
"auth_date": "20240230-20240301" // 无效日期2月30日
}
```
## 错误消息
当验证失败时,会返回中文错误消息:
```
"授权日期格式不正确必须是YYYYMMDD-YYYYMMDD格式且日期范围必须包括今天"
```
## 测试用例
验证器包含完整的测试用例,覆盖以下场景:
### 有效场景
- 今天到今天
- 昨天到今天
- 今天到明天
- 上周到今天
- 今天到下周
- 昨天到明天
### 无效场景
- 明天到后天(不包括今天)
- 上周到昨天(不包括今天)
- 格式错误(缺少连字符、多个连字符、长度不对、非数字)
- 无效日期2月30日、13月等
- 开始日期晚于结束日期
## 实现细节
### 核心验证逻辑
```go
func validateAuthDate(fl validator.FieldLevel) bool {
authDate := fl.Field().String()
if authDate == "" {
return true // 空值由required标签处理
}
// 1. 检查格式YYYYMMDD-YYYYMMDD
parts := strings.Split(authDate, "-")
if len(parts) != 2 {
return false
}
// 2. 解析日期
startDate, err := parseYYYYMMDD(parts[0])
if err != nil {
return false
}
endDate, err := parseYYYYMMDD(parts[1])
if err != nil {
return false
}
// 3. 检查日期顺序
if startDate.After(endDate) {
return false
}
// 4. 检查是否包括今天
today := time.Now().Truncate(24 * time.Hour)
return !startDate.After(today) && !endDate.Before(today)
}
```
### 日期解析
```go
func parseYYYYMMDD(dateStr string) (time.Time, error) {
if len(dateStr) != 8 {
return time.Time{}, fmt.Errorf("日期格式错误")
}
year, _ := strconv.Atoi(dateStr[:4])
month, _ := strconv.Atoi(dateStr[4:6])
day, _ := strconv.Atoi(dateStr[6:8])
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
// 验证日期有效性
expectedDateStr := date.Format("20060102")
if expectedDateStr != dateStr {
return time.Time{}, fmt.Errorf("无效日期")
}
return date, nil
}
```
## 注册方式
验证器已在 `RegisterCustomValidators` 函数中自动注册:
```go
func RegisterCustomValidators(validate *validator.Validate) {
// ... 其他验证器
validate.RegisterValidation("auth_date", validateAuthDate)
}
```
翻译也已自动注册:
```go
validate.RegisterTranslation("auth_date", trans, func(ut ut.Translator) error {
return ut.Add("auth_date", "{0}格式不正确必须是YYYYMMDD-YYYYMMDD格式且日期范围必须包括今天", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("auth_date", getFieldDisplayName(fe.Field()))
return t
})
```
## 注意事项
1. **时区处理**:验证器使用 UTC 时区进行日期比较
2. **空值处理**:空字符串由 `required` 标签处理,本验证器返回 `true`
3. **日期精度**:只比较日期部分,忽略时间部分
4. **闰年支持**:自动处理闰年验证
5. **错误消息**:提供中文错误消息,便于用户理解
## 运行测试
```bash
# 运行所有 authDate 相关测试
go test ./internal/shared/validator -v -run TestValidateAuthDate
# 运行所有验证器测试
go test ./internal/shared/validator -v
```

View File

@@ -0,0 +1,180 @@
package validator
import (
"testing"
"time"
"github.com/go-playground/validator/v10"
)
func TestValidateAuthDate(t *testing.T) {
validate := validator.New()
validate.RegisterValidation("auth_date", validateAuthDate)
today := time.Now().Format("20060102")
yesterday := time.Now().AddDate(0, 0, -1).Format("20060102")
tomorrow := time.Now().AddDate(0, 0, 1).Format("20060102")
lastWeek := time.Now().AddDate(0, 0, -7).Format("20060102")
nextWeek := time.Now().AddDate(0, 0, 7).Format("20060102")
tests := []struct {
name string
authDate string
wantErr bool
}{
{
name: "今天到今天 - 有效",
authDate: today + "-" + today,
wantErr: false,
},
{
name: "昨天到今天 - 有效",
authDate: yesterday + "-" + today,
wantErr: false,
},
{
name: "今天到明天 - 有效",
authDate: today + "-" + tomorrow,
wantErr: false,
},
{
name: "上周到今天 - 有效",
authDate: lastWeek + "-" + today,
wantErr: false,
},
{
name: "今天到下周 - 有效",
authDate: today + "-" + nextWeek,
wantErr: false,
},
{
name: "昨天到明天 - 有效",
authDate: yesterday + "-" + tomorrow,
wantErr: false,
},
{
name: "明天到后天 - 无效(不包括今天)",
authDate: tomorrow + "-" + time.Now().AddDate(0, 0, 2).Format("20060102"),
wantErr: true,
},
{
name: "上周到昨天 - 无效(不包括今天)",
authDate: lastWeek + "-" + yesterday,
wantErr: true,
},
{
name: "格式错误 - 缺少连字符",
authDate: "2024010120240131",
wantErr: true,
},
{
name: "格式错误 - 多个连字符",
authDate: "20240101-20240131-20240201",
wantErr: true,
},
{
name: "格式错误 - 日期长度不对",
authDate: "202401-20240131",
wantErr: true,
},
{
name: "格式错误 - 非数字",
authDate: "20240101-2024013A",
wantErr: true,
},
{
name: "无效日期 - 2月30日",
authDate: "20240230-20240301",
wantErr: true,
},
{
name: "无效日期 - 13月",
authDate: "20241301-20241331",
wantErr: true,
},
{
name: "开始日期晚于结束日期",
authDate: "20240131-20240101",
wantErr: true,
},
{
name: "空字符串 - 由required处理",
authDate: "",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := validate.Var(tt.authDate, "auth_date")
if (err != nil) != tt.wantErr {
t.Errorf("validateAuthDate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestParseYYYYMMDD(t *testing.T) {
tests := []struct {
name string
dateStr string
want time.Time
wantErr bool
}{
{
name: "有效日期",
dateStr: "20240101",
want: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
wantErr: false,
},
{
name: "闰年2月29日",
dateStr: "20240229",
want: time.Date(2024, 2, 29, 0, 0, 0, 0, time.UTC),
wantErr: false,
},
{
name: "非闰年2月29日",
dateStr: "20230229",
want: time.Time{},
wantErr: true,
},
{
name: "长度错误",
dateStr: "202401",
want: time.Time{},
wantErr: true,
},
{
name: "非数字",
dateStr: "2024010A",
want: time.Time{},
wantErr: true,
},
{
name: "无效月份",
dateStr: "20241301",
want: time.Time{},
wantErr: true,
},
{
name: "无效日期",
dateStr: "20240230",
want: time.Time{},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseYYYYMMDD(tt.dateStr)
if (err != nil) != tt.wantErr {
t.Errorf("parseYYYYMMDD() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && !got.Equal(tt.want) {
t.Errorf("parseYYYYMMDD() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -1,8 +1,12 @@
package validator
import (
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"time"
"github.com/go-playground/validator/v10"
)
@@ -38,6 +42,18 @@ func RegisterCustomValidators(validate *validator.Validate) {
// URL验证器
validate.RegisterValidation("url", validateURL)
// 企业邮箱验证器
validate.RegisterValidation("enterprise_email", validateEnterpriseEmail)
// 企业地址验证器
validate.RegisterValidation("enterprise_address", validateEnterpriseAddress)
// IP地址验证器
validate.RegisterValidation("ip", validateIP)
// 授权日期验证器
validate.RegisterValidation("auth_date", validateAuthDate)
}
// validatePhone 手机号验证
@@ -111,4 +127,113 @@ func validateURL(fl validator.FieldLevel) bool {
urlStr := fl.Field().String()
_, err := url.ParseRequestURI(urlStr)
return err == nil
}
// validateEnterpriseEmail 企业邮箱验证
func validateEnterpriseEmail(fl validator.FieldLevel) bool {
email := fl.Field().String()
// 邮箱格式验证:用户名@域名.顶级域名
matched, _ := regexp.MatchString(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`, email)
return matched
}
// validateEnterpriseAddress 企业地址验证
func validateEnterpriseAddress(fl validator.FieldLevel) bool {
address := fl.Field().String()
// 地址长度验证2-200字符不能只包含空格
if len(strings.TrimSpace(address)) < 2 || len(address) > 200 {
return false
}
// 地址不能只包含特殊字符
matched, _ := regexp.MatchString(`^[^\s]+.*[^\s]+$`, strings.TrimSpace(address))
return matched
}
// validateIP IP地址验证支持IPv4
func validateIP(fl validator.FieldLevel) bool {
ip := fl.Field().String()
// 使用正则表达式验证IPv4格式
pattern := `^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`
matched, _ := regexp.MatchString(pattern, ip)
return matched
}
// validateAuthDate 授权日期验证器
// 格式YYYYMMDD-YYYYMMDD之前的日期范围必须包括今天
func validateAuthDate(fl validator.FieldLevel) bool {
authDate := fl.Field().String()
if authDate == "" {
return true // 空值由required标签处理
}
// 检查格式YYYYMMDD-YYYYMMDD
parts := strings.Split(authDate, "-")
if len(parts) != 2 {
return false
}
startDateStr := parts[0]
endDateStr := parts[1]
// 检查日期格式是否为8位数字
if len(startDateStr) != 8 || len(endDateStr) != 8 {
return false
}
// 解析开始日期
startDate, err := parseYYYYMMDD(startDateStr)
if err != nil {
return false
}
// 解析结束日期
endDate, err := parseYYYYMMDD(endDateStr)
if err != nil {
return false
}
// 检查开始日期不能晚于结束日期
if startDate.After(endDate) {
return false
}
// 获取今天的日期(去掉时间部分)
today := time.Now().Truncate(24 * time.Hour)
// 检查日期范围是否包括今天
// 如果两个日期都是今天也行
return !startDate.After(today) && !endDate.Before(today)
}
// parseYYYYMMDD 解析YYYYMMDD格式的日期字符串
func parseYYYYMMDD(dateStr string) (time.Time, error) {
if len(dateStr) != 8 {
return time.Time{}, fmt.Errorf("日期格式错误")
}
year, err := strconv.Atoi(dateStr[:4])
if err != nil {
return time.Time{}, err
}
month, err := strconv.Atoi(dateStr[4:6])
if err != nil {
return time.Time{}, err
}
day, err := strconv.Atoi(dateStr[6:8])
if err != nil {
return time.Time{}, err
}
// 验证日期有效性
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
// 检查解析后的日期是否与输入一致防止无效日期如20230230
expectedDateStr := date.Format("20060102")
if expectedDateStr != dateStr {
return time.Time{}, fmt.Errorf("无效日期")
}
return date, nil
}

View File

@@ -162,6 +162,38 @@ func registerCustomFieldTranslations(validate *validator.Validate, trans ut.Tran
t, _ := ut.T("url", getFieldDisplayName(fe.Field()))
return t
})
// 企业邮箱翻译
validate.RegisterTranslation("enterprise_email", trans, func(ut ut.Translator) error {
return ut.Add("enterprise_email", "{0}必须是有效的企业邮箱地址", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("enterprise_email", getFieldDisplayName(fe.Field()))
return t
})
// 企业地址翻译
validate.RegisterTranslation("enterprise_address", trans, func(ut ut.Translator) error {
return ut.Add("enterprise_address", "{0}长度必须在2-200字符之间且不能只包含空格", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("enterprise_address", getFieldDisplayName(fe.Field()))
return t
})
// IP地址翻译
validate.RegisterTranslation("ip", trans, func(ut ut.Translator) error {
return ut.Add("ip", "{0}必须是有效的IPv4地址格式", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("ip", getFieldDisplayName(fe.Field()))
return t
})
// 授权日期翻译
validate.RegisterTranslation("auth_date", trans, func(ut ut.Translator) error {
return ut.Add("auth_date", "{0}格式不正确必须是YYYYMMDD-YYYYMMDD格式且日期范围必须包括今天", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("auth_date", getFieldDisplayName(fe.Field()))
return t
})
}
// getFieldDisplayName 获取字段显示名称(中文)
@@ -176,6 +208,9 @@ func getFieldDisplayName(field string) string {
"code": "验证码",
"username": "用户名",
"email": "邮箱",
"enterprise_email": "企业邮箱",
"enterprise_address": "企业地址",
"ip_address": "IP地址",
"display_name": "显示名称",
"scene": "使用场景",
"Password": "密码",
@@ -244,6 +279,10 @@ func getFieldDisplayName(field string) string {
"ID": "ID",
"ids": "ID列表",
"IDs": "ID列表",
"auth_date": "授权日期",
"AuthDate": "授权日期",
"id_card": "身份证号",
"IDCard": "身份证号",
}
if displayName, exists := fieldNames[field]; exists {

View File

@@ -65,7 +65,7 @@ func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error
}
return err
}
return nil
}
@@ -80,7 +80,7 @@ func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error
}
return err
}
return nil
}
@@ -213,4 +213,4 @@ func (v *RequestValidator) ValidateValue(field interface{}, tag string) error {
// ValidateStruct 验证结构体(用于业务逻辑)
func (v *RequestValidator) ValidateStruct(s interface{}) error {
return v.validator.Struct(s)
}
}