v0.1
This commit is contained in:
194
internal/shared/cache/cache_config_manager.go
vendored
Normal file
194
internal/shared/cache/cache_config_manager.go
vendored
Normal file
@@ -0,0 +1,194 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CacheConfigManager 缓存配置管理器
|
||||
// 提供全局缓存配置管理和表级别的缓存决策
|
||||
type CacheConfigManager struct {
|
||||
config CacheConfig
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// GlobalCacheConfigManager 全局缓存配置管理器实例
|
||||
var GlobalCacheConfigManager *CacheConfigManager
|
||||
|
||||
// InitCacheConfigManager 初始化全局缓存配置管理器
|
||||
func InitCacheConfigManager(config CacheConfig) {
|
||||
GlobalCacheConfigManager = &CacheConfigManager{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GetCacheConfig 获取当前缓存配置
|
||||
func (m *CacheConfigManager) GetCacheConfig() CacheConfig {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.config
|
||||
}
|
||||
|
||||
// UpdateCacheConfig 更新缓存配置
|
||||
func (m *CacheConfigManager) UpdateCacheConfig(config CacheConfig) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
m.config = config
|
||||
}
|
||||
|
||||
// IsTableCacheEnabled 检查表是否启用缓存
|
||||
func (m *CacheConfigManager) IsTableCacheEnabled(tableName string) bool {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
// 检查表是否在禁用列表中
|
||||
for _, disabledTable := range m.config.DisabledTables {
|
||||
if disabledTable == tableName {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 如果配置了启用列表,只对启用列表中的表启用缓存
|
||||
if len(m.config.EnabledTables) > 0 {
|
||||
for _, enabledTable := range m.config.EnabledTables {
|
||||
if enabledTable == tableName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表)
|
||||
return true
|
||||
}
|
||||
|
||||
// IsTableCacheDisabled 检查表是否禁用缓存
|
||||
func (m *CacheConfigManager) IsTableCacheDisabled(tableName string) bool {
|
||||
return !m.IsTableCacheEnabled(tableName)
|
||||
}
|
||||
|
||||
// GetEnabledTables 获取启用缓存的表列表
|
||||
func (m *CacheConfigManager) GetEnabledTables() []string {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.config.EnabledTables
|
||||
}
|
||||
|
||||
// GetDisabledTables 获取禁用缓存的表列表
|
||||
func (m *CacheConfigManager) GetDisabledTables() []string {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.config.DisabledTables
|
||||
}
|
||||
|
||||
// AddEnabledTable 添加启用缓存的表
|
||||
func (m *CacheConfigManager) AddEnabledTable(tableName string) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
for _, table := range m.config.EnabledTables {
|
||||
if table == tableName {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.config.EnabledTables = append(m.config.EnabledTables, tableName)
|
||||
}
|
||||
|
||||
// AddDisabledTable 添加禁用缓存的表
|
||||
func (m *CacheConfigManager) AddDisabledTable(tableName string) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 检查是否已存在
|
||||
for _, table := range m.config.DisabledTables {
|
||||
if table == tableName {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
m.config.DisabledTables = append(m.config.DisabledTables, tableName)
|
||||
}
|
||||
|
||||
// RemoveEnabledTable 移除启用缓存的表
|
||||
func (m *CacheConfigManager) RemoveEnabledTable(tableName string) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var newEnabledTables []string
|
||||
for _, table := range m.config.EnabledTables {
|
||||
if table != tableName {
|
||||
newEnabledTables = append(newEnabledTables, table)
|
||||
}
|
||||
}
|
||||
m.config.EnabledTables = newEnabledTables
|
||||
}
|
||||
|
||||
// RemoveDisabledTable 移除禁用缓存的表
|
||||
func (m *CacheConfigManager) RemoveDisabledTable(tableName string) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
var newDisabledTables []string
|
||||
for _, table := range m.config.DisabledTables {
|
||||
if table != tableName {
|
||||
newDisabledTables = append(newDisabledTables, table)
|
||||
}
|
||||
}
|
||||
m.config.DisabledTables = newDisabledTables
|
||||
}
|
||||
|
||||
// GetTableCacheStatus 获取表的缓存状态信息
|
||||
func (m *CacheConfigManager) GetTableCacheStatus(tableName string) map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
status := map[string]interface{}{
|
||||
"table_name": tableName,
|
||||
"enabled": m.IsTableCacheEnabled(tableName),
|
||||
"disabled": m.IsTableCacheDisabled(tableName),
|
||||
}
|
||||
|
||||
// 检查是否在启用列表中
|
||||
for _, table := range m.config.EnabledTables {
|
||||
if table == tableName {
|
||||
status["in_enabled_list"] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否在禁用列表中
|
||||
for _, table := range m.config.DisabledTables {
|
||||
if table == tableName {
|
||||
status["in_disabled_list"] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// GetAllTableStatus 获取所有表的缓存状态
|
||||
func (m *CacheConfigManager) GetAllTableStatus() map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
result := map[string]interface{}{
|
||||
"enabled_tables": m.config.EnabledTables,
|
||||
"disabled_tables": m.config.DisabledTables,
|
||||
"config": map[string]interface{}{
|
||||
"default_ttl": m.config.DefaultTTL,
|
||||
"table_prefix": m.config.TablePrefix,
|
||||
"max_cache_size": m.config.MaxCacheSize,
|
||||
"cache_complex_sql": m.config.CacheComplexSQL,
|
||||
"enable_stats": m.config.EnableStats,
|
||||
"enable_warmup": m.config.EnableWarmup,
|
||||
"penetration_guard": m.config.PenetrationGuard,
|
||||
"bloom_filter": m.config.BloomFilter,
|
||||
"auto_invalidate": m.config.AutoInvalidate,
|
||||
"invalidate_delay": m.config.InvalidateDelay,
|
||||
},
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
240
internal/shared/cache/gorm_cache_plugin.go
vendored
240
internal/shared/cache/gorm_cache_plugin.go
vendored
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
@@ -21,6 +22,10 @@ type GormCachePlugin struct {
|
||||
cache interfaces.CacheService
|
||||
logger *zap.Logger
|
||||
config CacheConfig
|
||||
|
||||
// 缓存失效去重机制
|
||||
invalidationQueue map[string]*time.Timer
|
||||
queueMutex sync.RWMutex
|
||||
}
|
||||
|
||||
// CacheConfig 缓存配置
|
||||
@@ -58,7 +63,8 @@ func DefaultCacheConfig() CacheConfig {
|
||||
PenetrationGuard: true,
|
||||
BloomFilter: false,
|
||||
AutoInvalidate: true,
|
||||
InvalidateDelay: 100 * time.Millisecond,
|
||||
// 增加延迟失效时间,减少频繁的缓存失效操作
|
||||
InvalidateDelay: 500 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,6 +79,7 @@ func NewGormCachePlugin(cache interfaces.CacheService, logger *zap.Logger, confi
|
||||
cache: cache,
|
||||
logger: logger,
|
||||
config: cfg,
|
||||
invalidationQueue: make(map[string]*time.Timer),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,12 +94,30 @@ func (p *GormCachePlugin) Initialize(db *gorm.DB) error {
|
||||
zap.Duration("default_ttl", p.config.DefaultTTL),
|
||||
zap.Bool("auto_invalidate", p.config.AutoInvalidate),
|
||||
zap.Bool("penetration_guard", p.config.PenetrationGuard),
|
||||
zap.Duration("invalidate_delay", p.config.InvalidateDelay),
|
||||
)
|
||||
|
||||
// 注册回调函数
|
||||
return p.registerCallbacks(db)
|
||||
}
|
||||
|
||||
// Shutdown 关闭插件,清理资源
|
||||
func (p *GormCachePlugin) Shutdown() {
|
||||
p.queueMutex.Lock()
|
||||
defer p.queueMutex.Unlock()
|
||||
|
||||
// 停止所有定时器
|
||||
for table, timer := range p.invalidationQueue {
|
||||
timer.Stop()
|
||||
p.logger.Debug("停止缓存失效定时器", zap.String("table", table))
|
||||
}
|
||||
|
||||
// 清空队列
|
||||
p.invalidationQueue = make(map[string]*time.Timer)
|
||||
|
||||
p.logger.Info("GORM缓存插件已关闭")
|
||||
}
|
||||
|
||||
// registerCallbacks 注册GORM回调
|
||||
func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error {
|
||||
// Query回调 - 查询时检查缓存
|
||||
@@ -117,6 +142,7 @@ func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error {
|
||||
func (p *GormCachePlugin) beforeQuery(db *gorm.DB) {
|
||||
// 检查是否启用缓存
|
||||
if !p.shouldCache(db) {
|
||||
p.logger.Debug("跳过缓存", zap.String("table", db.Statement.Table))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -147,7 +173,15 @@ func (p *GormCachePlugin) beforeQuery(db *gorm.DB) {
|
||||
p.updateStats("hit", db.Statement.Table)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
p.logger.Warn("缓存数据恢复失败,将执行数据库查询",
|
||||
zap.String("cache_key", cacheKey),
|
||||
zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
p.logger.Debug("缓存未命中",
|
||||
zap.String("cache_key", cacheKey),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
// 缓存未命中,设置标记
|
||||
@@ -195,7 +229,10 @@ func (p *GormCachePlugin) afterCreate(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
|
||||
// 只对启用缓存的表执行失效操作
|
||||
if p.shouldInvalidateTable(db.Statement.Table) {
|
||||
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
|
||||
}
|
||||
}
|
||||
|
||||
// afterUpdate 更新后回调
|
||||
@@ -204,7 +241,10 @@ func (p *GormCachePlugin) afterUpdate(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
|
||||
// 只对启用缓存的表执行失效操作
|
||||
if p.shouldInvalidateTable(db.Statement.Table) {
|
||||
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
|
||||
}
|
||||
}
|
||||
|
||||
// afterDelete 删除后回调
|
||||
@@ -213,26 +253,64 @@ func (p *GormCachePlugin) afterDelete(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
|
||||
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
|
||||
// 只对启用缓存的表执行失效操作
|
||||
if p.shouldInvalidateTable(db.Statement.Table) {
|
||||
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 缓存管理方法 ================
|
||||
|
||||
// shouldInvalidateTable 判断是否应该对表执行缓存失效操作
|
||||
func (p *GormCachePlugin) shouldInvalidateTable(table string) bool {
|
||||
// 使用全局缓存配置管理器进行智能决策
|
||||
if GlobalCacheConfigManager != nil {
|
||||
return GlobalCacheConfigManager.IsTableCacheEnabled(table)
|
||||
}
|
||||
|
||||
// 如果全局管理器未初始化,使用本地配置
|
||||
// 检查表是否在禁用列表中
|
||||
for _, disabledTable := range p.config.DisabledTables {
|
||||
if disabledTable == table {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 如果配置了启用列表,只对启用列表中的表执行失效操作
|
||||
if len(p.config.EnabledTables) > 0 {
|
||||
for _, enabledTable := range p.config.EnabledTables {
|
||||
if enabledTable == table {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果没有配置启用列表,默认对所有表执行失效操作(除了禁用列表中的表)
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldCache 判断是否应该缓存
|
||||
func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool {
|
||||
// 检查是否明确禁用缓存
|
||||
// 检查是否手动禁用缓存
|
||||
if value, ok := db.Statement.Get("cache:disabled"); ok && value.(bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否明确启用缓存
|
||||
// 检查是否手动启用缓存
|
||||
if value, ok := db.Statement.Get("cache:enabled"); ok && value.(bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
// 使用全局缓存配置管理器进行智能决策
|
||||
if GlobalCacheConfigManager != nil {
|
||||
return GlobalCacheConfigManager.IsTableCacheEnabled(db.Statement.Table)
|
||||
}
|
||||
|
||||
// 如果全局管理器未初始化,使用本地配置
|
||||
// 检查表是否在禁用列表中
|
||||
for _, table := range p.config.DisabledTables {
|
||||
if table == db.Statement.Table {
|
||||
for _, disabledTable := range p.config.DisabledTables {
|
||||
if disabledTable == db.Statement.Table {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -247,11 +325,7 @@ func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否为复杂查询
|
||||
if !p.config.CacheComplexSQL && p.isComplexQuery(db) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -292,10 +366,27 @@ func (p *GormCachePlugin) generateCacheKey(db *gorm.DB) string {
|
||||
|
||||
// hashSQL 对SQL语句和参数进行hash
|
||||
func (p *GormCachePlugin) hashSQL(sql string, vars []interface{}) string {
|
||||
// 将SQL和参数组合
|
||||
// 修复:改进SQL hash生成,确保唯一性
|
||||
combined := sql
|
||||
for _, v := range vars {
|
||||
combined += fmt.Sprintf(":%v", v)
|
||||
// 使用更精确的格式化,避免类型信息丢失
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
combined += ":" + val
|
||||
case int, int32, int64:
|
||||
combined += fmt.Sprintf(":%d", val)
|
||||
case float32, float64:
|
||||
combined += fmt.Sprintf(":%f", val)
|
||||
case bool:
|
||||
combined += fmt.Sprintf(":%t", val)
|
||||
default:
|
||||
// 对于复杂类型,使用JSON序列化
|
||||
if jsonData, err := json.Marshal(v); err == nil {
|
||||
combined += ":" + string(jsonData)
|
||||
} else {
|
||||
combined += fmt.Sprintf(":%v", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算MD5 hash
|
||||
@@ -328,9 +419,22 @@ func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db *
|
||||
return fmt.Errorf("查询结果为空")
|
||||
}
|
||||
|
||||
// 修复:改进缓存数据保存逻辑
|
||||
var dataToCache interface{}
|
||||
|
||||
// 如果dest是切片,需要特殊处理
|
||||
destValue := reflect.ValueOf(dest)
|
||||
if destValue.Kind() == reflect.Ptr && destValue.Elem().Kind() == reflect.Slice {
|
||||
// 对于切片,直接使用原始数据
|
||||
dataToCache = destValue.Elem().Interface()
|
||||
} else {
|
||||
// 对于单个对象,也直接使用原始数据
|
||||
dataToCache = dest
|
||||
}
|
||||
|
||||
// 构建缓存结果
|
||||
result := CachedResult{
|
||||
Data: dest,
|
||||
Data: dataToCache,
|
||||
RowCount: db.Statement.RowsAffected,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
@@ -340,6 +444,10 @@ func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db *
|
||||
|
||||
// 保存到缓存
|
||||
if err := p.cache.Set(ctx, cacheKey, result, ttl); err != nil {
|
||||
p.logger.Error("保存查询结果到缓存失败",
|
||||
zap.String("cache_key", cacheKey),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("保存到缓存失败: %w", err)
|
||||
}
|
||||
|
||||
@@ -364,25 +472,35 @@ func (p *GormCachePlugin) restoreFromCache(db *gorm.DB, cachedResult *CachedResu
|
||||
return fmt.Errorf("目标对象必须是指针")
|
||||
}
|
||||
|
||||
// 将缓存数据复制到目标
|
||||
// 修复:改进缓存数据恢复逻辑
|
||||
cachedValue := reflect.ValueOf(cachedResult.Data)
|
||||
if !cachedValue.Type().AssignableTo(destValue.Elem().Type()) {
|
||||
|
||||
// 如果类型完全匹配,直接赋值
|
||||
if cachedValue.Type().AssignableTo(destValue.Elem().Type()) {
|
||||
destValue.Elem().Set(cachedValue)
|
||||
} else {
|
||||
// 尝试JSON转换
|
||||
jsonData, err := json.Marshal(cachedResult.Data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("缓存数据类型不匹配")
|
||||
p.logger.Error("序列化缓存数据失败", zap.Error(err))
|
||||
return fmt.Errorf("缓存数据类型转换失败: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, db.Statement.Dest); err != nil {
|
||||
p.logger.Error("反序列化缓存数据失败",
|
||||
zap.String("json_data", string(jsonData)),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("JSON反序列化失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
destValue.Elem().Set(cachedValue)
|
||||
}
|
||||
|
||||
// 设置影响行数
|
||||
db.Statement.RowsAffected = cachedResult.RowCount
|
||||
|
||||
p.logger.Debug("从缓存恢复数据成功",
|
||||
zap.Int64("rows", cachedResult.RowCount),
|
||||
zap.Time("timestamp", cachedResult.Timestamp))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -400,37 +518,79 @@ func (p *GormCachePlugin) getTTL(db *gorm.DB) time.Duration {
|
||||
|
||||
// invalidateTableCache 失效表相关缓存
|
||||
func (p *GormCachePlugin) invalidateTableCache(ctx context.Context, table string) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// 延迟失效(避免并发问题)
|
||||
if p.config.InvalidateDelay > 0 {
|
||||
time.AfterFunc(p.config.InvalidateDelay, func() {
|
||||
p.doInvalidateTableCache(ctx, table)
|
||||
})
|
||||
} else {
|
||||
p.doInvalidateTableCache(ctx, table)
|
||||
// 使用去重机制,避免重复的缓存失效操作
|
||||
p.queueMutex.Lock()
|
||||
defer p.queueMutex.Unlock()
|
||||
|
||||
// 如果已经有相同的失效操作在队列中,取消之前的定时器
|
||||
if timer, exists := p.invalidationQueue[table]; exists {
|
||||
timer.Stop()
|
||||
delete(p.invalidationQueue, table)
|
||||
}
|
||||
|
||||
// 创建独立的上下文,避免受到原始请求上下文的影响
|
||||
// 设置合理的超时时间,避免缓存失效操作阻塞
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
// 创建新的定时器
|
||||
timer := time.AfterFunc(p.config.InvalidateDelay, func() {
|
||||
// 执行缓存失效
|
||||
p.doInvalidateTableCache(cacheCtx, table)
|
||||
|
||||
// 清理定时器引用
|
||||
p.queueMutex.Lock()
|
||||
delete(p.invalidationQueue, table)
|
||||
p.queueMutex.Unlock()
|
||||
|
||||
// 取消上下文
|
||||
cancel()
|
||||
})
|
||||
|
||||
// 将定时器加入队列
|
||||
p.invalidationQueue[table] = timer
|
||||
|
||||
p.logger.Debug("缓存失效操作已加入队列",
|
||||
zap.String("table", table),
|
||||
zap.Duration("delay", p.config.InvalidateDelay),
|
||||
)
|
||||
}
|
||||
|
||||
// doInvalidateTableCache 执行缓存失效
|
||||
func (p *GormCachePlugin) doInvalidateTableCache(ctx context.Context, table string) {
|
||||
pattern := fmt.Sprintf("%s:%s:*", p.config.TablePrefix, table)
|
||||
|
||||
if err := p.cache.DeletePattern(ctx, pattern); err != nil {
|
||||
p.logger.Warn("失效表缓存失败",
|
||||
// 添加重试机制,提高缓存失效的可靠性
|
||||
maxRetries := 3
|
||||
for attempt := 1; attempt <= maxRetries; attempt++ {
|
||||
if err := p.cache.DeletePattern(ctx, pattern); err != nil {
|
||||
if attempt < maxRetries {
|
||||
p.logger.Warn("缓存失效失败,准备重试",
|
||||
zap.String("table", table),
|
||||
zap.String("pattern", pattern),
|
||||
zap.Int("attempt", attempt),
|
||||
zap.Error(err),
|
||||
)
|
||||
// 短暂延迟后重试
|
||||
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
p.logger.Warn("失效表缓存失败",
|
||||
zap.String("table", table),
|
||||
zap.String("pattern", pattern),
|
||||
zap.Int("attempts", maxRetries),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// 成功删除,记录日志并退出
|
||||
p.logger.Debug("表缓存已失效",
|
||||
zap.String("table", table),
|
||||
zap.String("pattern", pattern),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
p.logger.Debug("表缓存已失效",
|
||||
zap.String("table", table),
|
||||
zap.String("pattern", pattern),
|
||||
)
|
||||
}
|
||||
|
||||
// updateStats 更新统计信息
|
||||
|
||||
113
internal/shared/crypto/crypto.go
Normal file
113
internal/shared/crypto/crypto.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PKCS7填充
|
||||
func PKCS7Padding(ciphertext []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(ciphertext)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padtext...)
|
||||
}
|
||||
|
||||
// 去除PKCS7填充
|
||||
func PKCS7UnPadding(origData []byte) ([]byte, error) {
|
||||
length := len(origData)
|
||||
if length == 0 {
|
||||
return nil, errors.New("input data error")
|
||||
}
|
||||
unpadding := int(origData[length-1])
|
||||
if unpadding > length {
|
||||
return nil, errors.New("unpadding size is invalid")
|
||||
}
|
||||
|
||||
// 检查填充字节是否一致
|
||||
for i := 0; i < unpadding; i++ {
|
||||
if origData[length-1-i] != byte(unpadding) {
|
||||
return nil, errors.New("invalid padding")
|
||||
}
|
||||
}
|
||||
|
||||
return origData[:(length - unpadding)], nil
|
||||
}
|
||||
|
||||
// AES CBC模式加密,Base64传入传出
|
||||
func AesEncrypt(plainText []byte, key string) (string, error) {
|
||||
keyBytes, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
plainText = PKCS7Padding(plainText, blockSize)
|
||||
|
||||
cipherText := make([]byte, blockSize+len(plainText))
|
||||
iv := cipherText[:blockSize] // 使用前blockSize字节作为IV
|
||||
_, err = io.ReadFull(rand.Reader, iv)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
mode.CryptBlocks(cipherText[blockSize:], plainText)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// AES CBC模式解密,Base64传入传出
|
||||
func AesDecrypt(cipherTextBase64 string, key string) ([]byte, error) {
|
||||
keyBytes, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blockSize := block.BlockSize()
|
||||
if len(cipherText) < blockSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
iv := cipherText[:blockSize]
|
||||
cipherText = cipherText[blockSize:]
|
||||
|
||||
if len(cipherText)%blockSize != 0 {
|
||||
return nil, errors.New("ciphertext is not a multiple of the block size")
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
mode.CryptBlocks(cipherText, cipherText)
|
||||
|
||||
plainText, err := PKCS7UnPadding(cipherText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plainText, nil
|
||||
}
|
||||
|
||||
// Md5Encrypt 用于对传入的message进行MD5加密
|
||||
func Md5Encrypt(message string) string {
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(message)) // 将字符串转换为字节切片并写入
|
||||
return hex.EncodeToString(hash.Sum(nil)) // 将哈希值转换为16进制字符串并返回
|
||||
}
|
||||
63
internal/shared/crypto/generate.go
Normal file
63
internal/shared/crypto/generate.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
mathrand "math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 生成AES-128密钥的函数,符合市面规范
|
||||
func GenerateSecretKey() (string, error) {
|
||||
key := make([]byte, 16) // 16字节密钥
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func GenerateSecretId() (string, error) {
|
||||
// 创建一个字节数组,用于存储随机数据
|
||||
bytes := make([]byte, 8) // 因为每个字节表示两个16进制字符
|
||||
|
||||
// 读取随机字节到数组中
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将字节数组转换为16进制字符串
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// GenerateTransactionID 生成16位数的交易单号
|
||||
func GenerateTransactionID() string {
|
||||
length := 16
|
||||
// 获取当前时间戳
|
||||
timestamp := time.Now().UnixNano()
|
||||
|
||||
// 转换为字符串
|
||||
timeStr := strconv.FormatInt(timestamp, 10)
|
||||
|
||||
// 生成随机数
|
||||
mathrand.Seed(time.Now().UnixNano())
|
||||
randomPart := strconv.Itoa(mathrand.Intn(1000000))
|
||||
|
||||
// 组合时间戳和随机数
|
||||
combined := timeStr + randomPart
|
||||
|
||||
// 如果长度超出指定值,则截断;如果不够,则填充随机字符
|
||||
if len(combined) >= length {
|
||||
return combined[:length]
|
||||
}
|
||||
|
||||
// 如果长度不够,填充0
|
||||
for len(combined) < length {
|
||||
combined += strconv.Itoa(mathrand.Intn(10)) // 填充随机数
|
||||
}
|
||||
|
||||
return combined
|
||||
}
|
||||
150
internal/shared/crypto/west_crypto.go
Normal file
150
internal/shared/crypto/west_crypto.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
const (
|
||||
KEY_SIZE = 16 // AES-128, 16 bytes
|
||||
)
|
||||
|
||||
// Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding
|
||||
func WestDexEncrypt(data, secretKey string) (string, error) {
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
ciphertext, err := aesEncrypt([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding
|
||||
func WestDexDecrypt(encodedData, secretKey string) ([]byte, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
plaintext, err := aesDecrypt(ciphertext, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// generateAESKey generates a key for AES encryption using a SHA-1 based PRNG
|
||||
func generateAESKey(length int, password []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write(password)
|
||||
state := h.Sum(nil)
|
||||
|
||||
keyBytes := make([]byte, 0, length/8)
|
||||
for len(keyBytes) < length/8 {
|
||||
h := sha1.New()
|
||||
h.Write(state)
|
||||
state = h.Sum(nil)
|
||||
keyBytes = append(keyBytes, state...)
|
||||
}
|
||||
|
||||
return keyBytes[:length/8]
|
||||
}
|
||||
|
||||
// aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding
|
||||
func aesEncrypt(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
|
||||
ciphertext := make([]byte, len(paddedPlaintext))
|
||||
mode := newECBEncrypter(block)
|
||||
mode.CryptBlocks(ciphertext, paddedPlaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding
|
||||
func aesDecrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode := newECBDecrypter(block)
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs5Unpadding(plaintext), nil
|
||||
}
|
||||
|
||||
// pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding
|
||||
func pkcs5Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
// pkcs5Unpadding removes PKCS5 padding from the input
|
||||
func pkcs5Unpadding(src []byte) []byte {
|
||||
length := len(src)
|
||||
unpadding := int(src[length-1])
|
||||
return src[:(length - unpadding)]
|
||||
}
|
||||
|
||||
// ECB mode encryption/decryption
|
||||
type ecb struct {
|
||||
b cipher.Block
|
||||
blockSize int
|
||||
}
|
||||
|
||||
func newECB(b cipher.Block) *ecb {
|
||||
return &ecb{
|
||||
b: b,
|
||||
blockSize: b.BlockSize(),
|
||||
}
|
||||
}
|
||||
|
||||
type ecbEncrypter ecb
|
||||
|
||||
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Encrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
|
||||
type ecbDecrypter ecb
|
||||
|
||||
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Decrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/shared/cache"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
@@ -26,41 +27,112 @@ func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName stri
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 智能缓存决策方法 ================
|
||||
|
||||
// isTableCacheEnabled 检查表是否启用缓存
|
||||
func (r *CachedBaseRepositoryImpl) isTableCacheEnabled() bool {
|
||||
// 使用全局缓存配置管理器
|
||||
if cache.GlobalCacheConfigManager != nil {
|
||||
return cache.GlobalCacheConfigManager.IsTableCacheEnabled(r.tableName)
|
||||
}
|
||||
|
||||
// 如果全局管理器未初始化,默认启用缓存
|
||||
r.logger.Warn("全局缓存配置管理器未初始化,默认启用缓存",
|
||||
zap.String("table", r.tableName))
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldUseCacheForTable 智能判断是否应该对当前表使用缓存
|
||||
func (r *CachedBaseRepositoryImpl) shouldUseCacheForTable() bool {
|
||||
// 检查表是否启用缓存
|
||||
if !r.isTableCacheEnabled() {
|
||||
r.logger.Debug("表未启用缓存,跳过缓存操作",
|
||||
zap.String("table", r.tableName))
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// ================ 智能缓存方法 ================
|
||||
|
||||
// GetWithCache 带缓存的单条查询
|
||||
// GetWithCache 带缓存的单条查询(智能决策)
|
||||
func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl)
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 智能决策:根据表配置决定是否使用缓存
|
||||
if r.shouldUseCacheForTable() {
|
||||
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
|
||||
r.logger.Debug("执行带缓存查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.Duration("ttl", ttl),
|
||||
zap.String("where", where))
|
||||
} else {
|
||||
db = db.Set("cache:disabled", true)
|
||||
r.logger.Debug("执行无缓存查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("where", where))
|
||||
}
|
||||
|
||||
return db.Where(where, args...).First(dest).Error
|
||||
}
|
||||
|
||||
// FindWithCache 带缓存的多条查询
|
||||
// FindWithCache 带缓存的多条查询(智能决策)
|
||||
func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl)
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 智能决策:根据表配置决定是否使用缓存
|
||||
if r.shouldUseCacheForTable() {
|
||||
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
|
||||
r.logger.Debug("执行带缓存批量查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.Duration("ttl", ttl),
|
||||
zap.String("where", where))
|
||||
} else {
|
||||
db = db.Set("cache:disabled", true)
|
||||
r.logger.Debug("执行无缓存批量查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("where", where))
|
||||
}
|
||||
|
||||
return db.Where(where, args...).Find(dest).Error
|
||||
}
|
||||
|
||||
// CountWithCache 带缓存的计数查询
|
||||
// CountWithCache 带缓存的计数查询(智能决策)
|
||||
func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl).
|
||||
Model(entity)
|
||||
db := r.GetDB(ctx).Model(entity)
|
||||
|
||||
// 智能决策:根据表配置决定是否使用缓存
|
||||
if r.shouldUseCacheForTable() {
|
||||
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
|
||||
r.logger.Debug("执行带缓存计数查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.Duration("ttl", ttl),
|
||||
zap.String("where", where))
|
||||
} else {
|
||||
db = db.Set("cache:disabled", true)
|
||||
r.logger.Debug("执行无缓存计数查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("where", where))
|
||||
}
|
||||
|
||||
return db.Where(where, args...).Count(count).Error
|
||||
}
|
||||
|
||||
// ListWithCache 带缓存的列表查询
|
||||
// ListWithCache 带缓存的列表查询(智能决策)
|
||||
func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl)
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 智能决策:根据表配置决定是否使用缓存
|
||||
if r.shouldUseCacheForTable() {
|
||||
db = db.Set("cache:enabled", true).Set("cache:ttl", ttl)
|
||||
r.logger.Debug("执行带缓存列表查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.Duration("ttl", ttl))
|
||||
} else {
|
||||
db = db.Set("cache:disabled", true)
|
||||
r.logger.Debug("执行无缓存列表查询",
|
||||
zap.String("table", r.tableName))
|
||||
}
|
||||
|
||||
// 应用where条件
|
||||
if options.Where != "" {
|
||||
@@ -90,12 +162,12 @@ func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest inter
|
||||
|
||||
// CacheListOptions 缓存列表查询选项
|
||||
type CacheListOptions struct {
|
||||
Where string `json:"where"`
|
||||
Args []interface{} `json:"args"`
|
||||
Order string `json:"order"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Preloads []string `json:"preloads"`
|
||||
Where string `json:"where"`
|
||||
Args []interface{} `json:"args"`
|
||||
Order string `json:"order"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Preloads []string `json:"preloads"`
|
||||
}
|
||||
|
||||
// ================ 缓存控制方法 ================
|
||||
@@ -142,6 +214,10 @@ func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl {
|
||||
|
||||
// SmartGetByID 智能ID查询(自动缓存)
|
||||
func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error {
|
||||
r.logger.Debug("执行智能ID查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("id", id))
|
||||
|
||||
return r.GetWithCache(ctx, dest, 30*time.Minute, "id = ?", id)
|
||||
}
|
||||
|
||||
@@ -151,7 +227,7 @@ func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest int
|
||||
if len(ttl) > 0 {
|
||||
cacheTTL = ttl[0]
|
||||
}
|
||||
|
||||
|
||||
return r.GetWithCache(ctx, dest, cacheTTL, field+" = ?", value)
|
||||
}
|
||||
|
||||
@@ -161,6 +237,12 @@ func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface
|
||||
cacheTTL := r.calculateCacheTTL(options)
|
||||
useCache := r.shouldUseCache(options)
|
||||
|
||||
r.logger.Debug("执行智能列表查询",
|
||||
zap.String("table", r.tableName),
|
||||
zap.Bool("use_cache", useCache),
|
||||
zap.Duration("cache_ttl", cacheTTL))
|
||||
|
||||
// 修复:确保缓存标记在查询前正确设置
|
||||
db := r.GetDB(ctx)
|
||||
if useCache {
|
||||
db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL)
|
||||
@@ -254,14 +336,14 @@ func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions
|
||||
|
||||
// WarmupCommonQueries 预热常用查询
|
||||
func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error {
|
||||
r.logger.Info("开始预热缓存",
|
||||
r.logger.Info("开始预热缓存",
|
||||
zap.String("table", r.tableName),
|
||||
zap.Int("queries", len(queries)),
|
||||
)
|
||||
|
||||
for _, query := range queries {
|
||||
if err := r.executeWarmupQuery(ctx, query); err != nil {
|
||||
r.logger.Warn("缓存预热失败",
|
||||
r.logger.Warn("缓存预热失败",
|
||||
zap.String("query", query.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
@@ -273,11 +355,11 @@ func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, quer
|
||||
|
||||
// WarmupQuery 预热查询定义
|
||||
type WarmupQuery struct {
|
||||
Name string `json:"name"`
|
||||
SQL string `json:"sql"`
|
||||
Args []interface{} `json:"args"`
|
||||
TTL time.Duration `json:"ttl"`
|
||||
Dest interface{} `json:"dest"`
|
||||
Name string `json:"name"`
|
||||
SQL string `json:"sql"`
|
||||
Args []interface{} `json:"args"`
|
||||
TTL time.Duration `json:"ttl"`
|
||||
Dest interface{} `json:"dest"`
|
||||
}
|
||||
|
||||
// executeWarmupQuery 执行预热查询
|
||||
@@ -313,7 +395,7 @@ func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interfa
|
||||
if err := r.CreateEntity(ctx, newEntity); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
// 将新创建的实体复制到dest
|
||||
// 这里需要反射或其他方式复制
|
||||
return nil
|
||||
@@ -333,7 +415,7 @@ func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids []
|
||||
|
||||
// RefreshCache 刷新缓存
|
||||
func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error {
|
||||
r.logger.Info("刷新缓存",
|
||||
r.logger.Info("刷新缓存",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("pattern", pattern),
|
||||
)
|
||||
@@ -348,9 +430,9 @@ func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern str
|
||||
// GetCacheInfo 获取缓存信息
|
||||
func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"table_name": r.tableName,
|
||||
"cache_enabled": true,
|
||||
"default_ttl": "30m",
|
||||
"table_name": r.tableName,
|
||||
"cache_enabled": true,
|
||||
"default_ttl": "30m",
|
||||
"cache_patterns": []string{
|
||||
fmt.Sprintf("gorm_cache:%s:*", r.tableName),
|
||||
},
|
||||
@@ -359,9 +441,9 @@ func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} {
|
||||
|
||||
// LogCacheOperation 记录缓存操作
|
||||
func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) {
|
||||
r.logger.Debug("缓存操作",
|
||||
r.logger.Debug("缓存操作",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("operation", operation),
|
||||
zap.String("details", details),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,18 +2,43 @@ package esign
|
||||
|
||||
import "fmt"
|
||||
|
||||
type EsignContractConfig struct {
|
||||
Name string `json:"name" yaml:"name"`
|
||||
ExpireDays int `json:"expireDays" yaml:"expire_days"`
|
||||
RetryCount int `json:"retryCount" yaml:"retry_count"`
|
||||
}
|
||||
|
||||
type EsignAuthConfig struct {
|
||||
OrgAuthModes []string `json:"orgAuthModes" yaml:"org_auth_modes"`
|
||||
DefaultAuthMode string `json:"defaultAuthMode" yaml:"default_auth_mode"`
|
||||
PsnAuthModes []string `json:"psnAuthModes" yaml:"psn_auth_modes"`
|
||||
WillingnessAuthModes []string `json:"willingnessAuthModes" yaml:"willingness_auth_modes"`
|
||||
RedirectUrl string `json:"redirectUrl" yaml:"redirect_url"`
|
||||
}
|
||||
|
||||
type EsignSignConfig struct {
|
||||
AutoFinish bool `json:"autoFinish" yaml:"auto_finish"`
|
||||
SignFieldStyle int `json:"signFieldStyle" yaml:"sign_field_style"`
|
||||
ClientType string `json:"clientType" yaml:"client_type"`
|
||||
RedirectUrl string `json:"redirectUrl" yaml:"redirect_url"`
|
||||
}
|
||||
|
||||
// Config e签宝服务配置结构体
|
||||
// 包含应用ID、密钥、服务器URL和模板ID等基础配置信息
|
||||
// 新增Contract、Auth、Sign配置
|
||||
type Config struct {
|
||||
AppID string `json:"appId"` // 应用ID
|
||||
AppSecret string `json:"appSecret"` // 应用密钥
|
||||
ServerURL string `json:"serverUrl"` // 服务器URL
|
||||
TemplateID string `json:"templateId"` // 模板ID
|
||||
AppID string `json:"appId" yaml:"app_id"`
|
||||
AppSecret string `json:"appSecret" yaml:"app_secret"`
|
||||
ServerURL string `json:"serverUrl" yaml:"server_url"`
|
||||
TemplateID string `json:"templateId" yaml:"template_id"`
|
||||
Contract *EsignContractConfig `json:"contract" yaml:"contract"`
|
||||
Auth *EsignAuthConfig `json:"auth" yaml:"auth"`
|
||||
Sign *EsignSignConfig `json:"sign" yaml:"sign"`
|
||||
}
|
||||
|
||||
// NewConfig 创建新的配置实例
|
||||
// 提供配置验证和默认值设置
|
||||
func NewConfig(appID, appSecret, serverURL, templateID string) (*Config, error) {
|
||||
func NewConfig(appID, appSecret, serverURL, templateID string, contract *EsignContractConfig, auth *EsignAuthConfig, sign *EsignSignConfig) (*Config, error) {
|
||||
if appID == "" {
|
||||
return nil, fmt.Errorf("应用ID不能为空")
|
||||
}
|
||||
@@ -26,12 +51,15 @@ func NewConfig(appID, appSecret, serverURL, templateID string) (*Config, error)
|
||||
if templateID == "" {
|
||||
return nil, fmt.Errorf("模板ID不能为空")
|
||||
}
|
||||
|
||||
|
||||
return &Config{
|
||||
AppID: appID,
|
||||
AppSecret: appSecret,
|
||||
ServerURL: serverURL,
|
||||
TemplateID: templateID,
|
||||
Contract: contract,
|
||||
Auth: auth,
|
||||
Sign: sign,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -60,7 +88,7 @@ const (
|
||||
AuthModeBank = "PSN_BANK" // 银行卡认证
|
||||
|
||||
// 意愿认证模式
|
||||
WillingnessAuthSMS = "CODE_SMS" // 短信验证码
|
||||
WillingnessAuthSMS = "CODE_SMS" // 短信验证码
|
||||
WillingnessAuthEmail = "CODE_EMAIL" // 邮箱验证码
|
||||
|
||||
// 证件类型常量
|
||||
@@ -69,7 +97,7 @@ const (
|
||||
|
||||
// 签署区样式常量
|
||||
SignFieldStyleNormal = 1 // 普通签章
|
||||
SignFieldStyleSeam = 2 // 骑缝签章
|
||||
SignFieldStyleSeam = 2 // 骑缝签章
|
||||
|
||||
// 签署人类型常量
|
||||
SignerTypePerson = 0 // 个人
|
||||
@@ -80,4 +108,4 @@ const (
|
||||
|
||||
// 客户端类型常量
|
||||
ClientTypeAll = "ALL" // 所有客户端
|
||||
)
|
||||
)
|
||||
|
||||
@@ -13,6 +13,24 @@ func Example() {
|
||||
"your_app_secret",
|
||||
"https://smlopenapi.esign.cn",
|
||||
"your_template_id",
|
||||
&EsignContractConfig{
|
||||
Name: "测试合同",
|
||||
ExpireDays: 30,
|
||||
RetryCount: 3,
|
||||
},
|
||||
&EsignAuthConfig{
|
||||
OrgAuthModes: []string{"ORG"},
|
||||
DefaultAuthMode: "ORG",
|
||||
PsnAuthModes: []string{"PSN"},
|
||||
WillingnessAuthModes: []string{"WILLINGNESS"},
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
},
|
||||
&EsignSignConfig{
|
||||
AutoFinish: true,
|
||||
SignFieldStyle: 1,
|
||||
ClientType: "ALL",
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal("配置创建失败:", err)
|
||||
@@ -122,7 +140,22 @@ func Example() {
|
||||
// ExampleBasicUsage 基础用法示例
|
||||
func ExampleBasicUsage() {
|
||||
// 最简单的用法 - 一行代码完成合同签署
|
||||
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id")
|
||||
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{
|
||||
Name: "测试合同",
|
||||
ExpireDays: 30,
|
||||
RetryCount: 3,
|
||||
}, &EsignAuthConfig{
|
||||
OrgAuthModes: []string{"ORG"},
|
||||
DefaultAuthMode: "ORG",
|
||||
PsnAuthModes: []string{"PSN"},
|
||||
WillingnessAuthModes: []string{"WILLINGNESS"},
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
}, &EsignSignConfig{
|
||||
AutoFinish: true,
|
||||
SignFieldStyle: 1,
|
||||
ClientType: "ALL",
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
})
|
||||
client := NewClient(config)
|
||||
|
||||
// 快速合同签署
|
||||
@@ -143,7 +176,22 @@ func ExampleBasicUsage() {
|
||||
|
||||
// ExampleWithCustomData 自定义数据示例
|
||||
func ExampleWithCustomData() {
|
||||
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id")
|
||||
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{
|
||||
Name: "测试合同",
|
||||
ExpireDays: 30,
|
||||
RetryCount: 3,
|
||||
}, &EsignAuthConfig{
|
||||
OrgAuthModes: []string{"ORG"},
|
||||
DefaultAuthMode: "ORG",
|
||||
PsnAuthModes: []string{"PSN"},
|
||||
WillingnessAuthModes: []string{"WILLINGNESS"},
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
}, &EsignSignConfig{
|
||||
AutoFinish: true,
|
||||
SignFieldStyle: 1,
|
||||
ClientType: "ALL",
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
})
|
||||
client := NewClient(config)
|
||||
|
||||
// 使用自定义模板数据
|
||||
@@ -171,7 +219,22 @@ func ExampleWithCustomData() {
|
||||
|
||||
// ExampleEnterpriseAuth 企业认证示例
|
||||
func ExampleEnterpriseAuth() {
|
||||
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id")
|
||||
config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{
|
||||
Name: "测试合同",
|
||||
ExpireDays: 30,
|
||||
RetryCount: 3,
|
||||
}, &EsignAuthConfig{
|
||||
OrgAuthModes: []string{"ORG"},
|
||||
DefaultAuthMode: "ORG",
|
||||
PsnAuthModes: []string{"PSN"},
|
||||
WillingnessAuthModes: []string{"WILLINGNESS"},
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
}, &EsignSignConfig{
|
||||
AutoFinish: true,
|
||||
SignFieldStyle: 1,
|
||||
ClientType: "ALL",
|
||||
RedirectUrl: "https://www.tianyuanapi.com/certification/callback",
|
||||
})
|
||||
client := NewClient(config)
|
||||
|
||||
// 企业认证
|
||||
|
||||
@@ -34,8 +34,8 @@ func (s *FileOpsService) UpdateConfig(config *Config) {
|
||||
func (s *FileOpsService) DownloadSignedFile(signFlowId string) (*DownloadSignedFileResponse, error) {
|
||||
fmt.Println("开始下载已签署文件及附属材料...")
|
||||
|
||||
// 发送API请求
|
||||
urlPath := fmt.Sprintf("/v3/sign-flow/%s/attachments", signFlowId)
|
||||
// 按照最新e签宝文档,接口路径应为 /v3/sign-flow/{signFlowId}/file-download-url
|
||||
urlPath := fmt.Sprintf("/v3/sign-flow/%s/file-download-url", signFlowId)
|
||||
responseBody, err := s.httpClient.Request("GET", urlPath, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载已签署文件失败: %v", err)
|
||||
|
||||
@@ -44,8 +44,8 @@ func (h *HTTPClient) UpdateConfig(config *Config) {
|
||||
func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error) {
|
||||
// 生成签名所需参数
|
||||
timestamp := getCurrentTimestamp()
|
||||
nonce := generateNonce()
|
||||
date := getCurrentDate()
|
||||
// date := getCurrentDate()
|
||||
date := ""
|
||||
|
||||
// 计算Content-MD5
|
||||
contentMD5 := ""
|
||||
@@ -53,14 +53,14 @@ func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error
|
||||
contentMD5 = getContentMD5(body)
|
||||
}
|
||||
|
||||
// 根据Java示例,Headers为空字符串
|
||||
headers := ""
|
||||
|
||||
// 生成签名
|
||||
// 生成签名(用原始urlPath)
|
||||
signature := generateSignature(h.config.AppSecret, method, "*/*", contentMD5, "application/json", date, headers, urlPath)
|
||||
|
||||
// 创建HTTP请求
|
||||
url := h.config.ServerURL + urlPath
|
||||
// 实际请求url用encode后的urlPath
|
||||
encodedURLPath := encodeURLQueryParams(urlPath)
|
||||
url := h.config.ServerURL + encodedURLPath
|
||||
req, err := http.NewRequest(method, url, bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建HTTP请求失败: %v", err)
|
||||
@@ -69,12 +69,11 @@ func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Content-MD5", contentMD5)
|
||||
req.Header.Set("Date", date)
|
||||
// req.Header.Set("Date", date)
|
||||
req.Header.Set("Accept", "*/*")
|
||||
req.Header.Set("X-Tsign-Open-App-Id", h.config.AppID)
|
||||
req.Header.Set("X-Tsign-Open-Auth-Mode", "Signature")
|
||||
req.Header.Set("X-Tsign-Open-Ca-Timestamp", timestamp)
|
||||
req.Header.Set("X-Tsign-Open-Nonce", nonce)
|
||||
req.Header.Set("X-Tsign-Open-Ca-Signature", signature)
|
||||
|
||||
// 发送请求
|
||||
@@ -197,3 +196,24 @@ func sortURLQueryParams(urlPath string) string {
|
||||
}
|
||||
return basePath
|
||||
}
|
||||
|
||||
// encodeURLQueryParams 对urlPath中的query参数值进行encode
|
||||
func encodeURLQueryParams(urlPath string) string {
|
||||
if !strings.Contains(urlPath, "?") {
|
||||
return urlPath
|
||||
}
|
||||
parts := strings.SplitN(urlPath, "?", 2)
|
||||
basePath := parts[0]
|
||||
queryString := parts[1]
|
||||
values, err := url.ParseQuery(queryString)
|
||||
if err != nil {
|
||||
return urlPath
|
||||
}
|
||||
var encodedPairs []string
|
||||
for key, vals := range values {
|
||||
for _, val := range vals {
|
||||
encodedPairs = append(encodedPairs, key+"="+url.QueryEscape(val))
|
||||
}
|
||||
}
|
||||
return basePath + "?" + strings.Join(encodedPairs, "&")
|
||||
}
|
||||
|
||||
@@ -66,6 +66,9 @@ func (s *OrgAuthService) GetAuthURL(req *OrgAuthRequest) (string, string, string
|
||||
},
|
||||
},
|
||||
ClientType: ClientTypeAll,
|
||||
RedirectConfig: &RedirectConfig{
|
||||
RedirectUrl: s.config.Auth.RedirectUrl,
|
||||
},
|
||||
}
|
||||
|
||||
// 序列化请求数据
|
||||
|
||||
@@ -43,7 +43,7 @@ func (s *SignFlowService) Create(req *CreateSignFlowRequest) (string, error) {
|
||||
Docs: []DocInfo{
|
||||
{
|
||||
FileId: req.FileID,
|
||||
FileName: "天远数据API合作协议.pdf",
|
||||
FileName: s.config.Contract.Name,
|
||||
},
|
||||
},
|
||||
SignFlowConfig: s.buildSignFlowConfig(),
|
||||
@@ -141,9 +141,9 @@ func (s *SignFlowService) buildPartyASigner(fileID string) SignerInfo {
|
||||
AutoSign: true,
|
||||
SignFieldStyle: SignFieldStyleNormal,
|
||||
SignFieldPosition: &SignFieldPosition{
|
||||
PositionPage: "1",
|
||||
PositionPage: "8",
|
||||
PositionX: 200,
|
||||
PositionY: 200,
|
||||
PositionY: 430,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -188,9 +188,9 @@ func (s *SignFlowService) buildPartyBSigner(fileID, signerAccount, signerName, t
|
||||
AutoSign: false,
|
||||
SignFieldStyle: SignFieldStyleNormal,
|
||||
SignFieldPosition: &SignFieldPosition{
|
||||
PositionPage: "1",
|
||||
PositionX: 458,
|
||||
PositionY: 200,
|
||||
PositionPage: "8",
|
||||
PositionX: 450,
|
||||
PositionY: 430,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -201,9 +201,9 @@ func (s *SignFlowService) buildPartyBSigner(fileID, signerAccount, signerName, t
|
||||
// buildSignFlowConfig 构建签署流程配置
|
||||
func (s *SignFlowService) buildSignFlowConfig() SignFlowConfig {
|
||||
return SignFlowConfig{
|
||||
SignFlowTitle: "天远数据API合作协议签署",
|
||||
SignFlowExpireTime: calculateExpireTime(7), // 7天后过期
|
||||
AutoFinish: true, // 所有签署方完成后自动完结
|
||||
SignFlowTitle: s.config.Contract.Name,
|
||||
SignFlowExpireTime: calculateExpireTime(s.config.Contract.ExpireDays),
|
||||
AutoFinish: s.config.Sign.AutoFinish,
|
||||
AuthConfig: &AuthConfig{
|
||||
PsnAvailableAuthModes: []string{AuthModeMobile3},
|
||||
WillingnessAuthModes: []string{WillingnessAuthSMS},
|
||||
@@ -211,5 +211,8 @@ func (s *SignFlowService) buildSignFlowConfig() SignFlowConfig {
|
||||
ContractConfig: &ContractConfig{
|
||||
AllowToRescind: false,
|
||||
},
|
||||
RedirectConfig: &RedirectConfig{
|
||||
RedirectUrl: s.config.Sign.RedirectUrl,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func (s *TemplateService) Fill(components []Component) (*FillTemplate, error) {
|
||||
fmt.Println("开始填写模板生成文件...")
|
||||
|
||||
// 生成带时间戳的文件名
|
||||
fileName := generateFileName("天远数据API合作协议", "pdf")
|
||||
fileName := generateFileName(s.config.Contract.Name, "pdf")
|
||||
|
||||
// 构建请求数据
|
||||
requestData := FillTemplateRequest{
|
||||
|
||||
@@ -5,10 +5,40 @@ import (
|
||||
"crypto/md5"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// appendSignDataString 拼接待签名字符串
|
||||
func appendSignDataString(httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters string) string {
|
||||
if accept == "" {
|
||||
accept = "*/*"
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "application/json; charset=UTF-8"
|
||||
}
|
||||
// 前四项
|
||||
signStr := httpMethod + "\n" + accept + "\n" + contentMD5 + "\n" + contentType + "\n"
|
||||
// 处理 date
|
||||
if date == "" {
|
||||
signStr += "\n"
|
||||
} else {
|
||||
signStr += date + "\n"
|
||||
}
|
||||
// 处理 headers
|
||||
if headers == "" {
|
||||
signStr += pathAndParameters
|
||||
} else {
|
||||
signStr += headers + "\n" + pathAndParameters
|
||||
}
|
||||
return signStr
|
||||
}
|
||||
|
||||
// generateSignature 生成e签宝API请求签名
|
||||
// 使用HMAC-SHA256算法对请求参数进行签名
|
||||
//
|
||||
@@ -24,8 +54,8 @@ import (
|
||||
//
|
||||
// 返回: Base64编码的签名字符串
|
||||
func generateSignature(appSecret, httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters string) string {
|
||||
// 构建待签名字符串,按照e签宝API规范拼接
|
||||
signStr := httpMethod + "\n" + accept + "\n" + contentMD5 + "\n" + contentType + "\n" + date + "\n" + headers + pathAndParameters
|
||||
// 构建待签名字符串,按照e签宝API规范拼接(兼容Python实现细节)
|
||||
signStr := appendSignDataString(httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters)
|
||||
|
||||
// 使用HMAC-SHA256计算签名
|
||||
h := hmac.New(sha256.New, []byte(appSecret))
|
||||
@@ -102,3 +132,66 @@ func generateFileName(baseName, extension string) string {
|
||||
func calculateExpireTime(days int) int64 {
|
||||
return time.Now().AddDate(0, 0, days).UnixMilli()
|
||||
}
|
||||
|
||||
// verifySignature 验证e签宝回调签名
|
||||
func VerifySignature(callbackData interface{}, headers map[string]string, queryParams map[string]string, appSecret string) error {
|
||||
// 1. 获取签名相关参数
|
||||
signature, ok := headers["X-Tsign-Open-Signature"]
|
||||
if !ok {
|
||||
return fmt.Errorf("缺少签名头: X-Tsign-Open-Signature")
|
||||
}
|
||||
|
||||
timestamp, ok := headers["X-Tsign-Open-Timestamp"]
|
||||
if !ok {
|
||||
return fmt.Errorf("缺少时间戳头: X-Tsign-Open-Timestamp")
|
||||
}
|
||||
|
||||
// 2. 构建查询参数字符串
|
||||
var queryKeys []string
|
||||
for key := range queryParams {
|
||||
queryKeys = append(queryKeys, key)
|
||||
}
|
||||
sort.Strings(queryKeys) // 按ASCII码升序排序
|
||||
|
||||
var queryValues []string
|
||||
for _, key := range queryKeys {
|
||||
queryValues = append(queryValues, queryParams[key])
|
||||
}
|
||||
queryString := strings.Join(queryValues, "")
|
||||
|
||||
// 3. 获取请求体数据
|
||||
bodyData, err := getRequestBodyString(callbackData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取请求体数据失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 构建验签数据
|
||||
data := timestamp + queryString + bodyData
|
||||
|
||||
// 5. 计算签名
|
||||
expectedSignature := calculateSignature(data, appSecret)
|
||||
|
||||
// 6. 比较签名
|
||||
if strings.ToLower(expectedSignature) != strings.ToLower(signature) {
|
||||
return fmt.Errorf("签名验证失败")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateSignature 计算HMAC-SHA256签名
|
||||
func calculateSignature(data, secret string) string {
|
||||
h := hmac.New(sha256.New, []byte(secret))
|
||||
h.Write([]byte(data))
|
||||
return strings.ToUpper(hex.EncodeToString(h.Sum(nil)))
|
||||
}
|
||||
|
||||
// getRequestBodyString 获取请求体字符串
|
||||
func getRequestBodyString(callbackData interface{}) (string, error) {
|
||||
// 将map转换为JSON字符串
|
||||
jsonBytes, err := json.Marshal(callbackData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("JSON序列化失败: %w", err)
|
||||
}
|
||||
return string(jsonBytes), nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// HTTPHandler HTTP处理器接口
|
||||
@@ -95,6 +96,11 @@ type RequestValidator interface {
|
||||
|
||||
// 绑定和验证
|
||||
BindAndValidate(c *gin.Context, dto interface{}) error
|
||||
|
||||
// 业务逻辑验证方法
|
||||
GetValidator() *validator.Validate
|
||||
ValidateValue(field interface{}, tag string) error
|
||||
ValidateStruct(s interface{}) error
|
||||
}
|
||||
|
||||
// PaginationMeta 分页元数据
|
||||
|
||||
50
internal/shared/middleware/api_auth.go
Normal file
50
internal/shared/middleware/api_auth.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ApiAuthMiddleware API认证中间件
|
||||
type ApiAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
}
|
||||
|
||||
// NewApiAuthMiddleware 创建API认证中间件
|
||||
func NewApiAuthMiddleware(cfg *config.Config, logger *zap.Logger, responseBuilder interfaces.ResponseBuilder) *ApiAuthMiddleware {
|
||||
return &ApiAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
responseBuilder: responseBuilder,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ApiAuthMiddleware) GetName() string {
|
||||
return "api_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ApiAuthMiddleware) GetPriority() int {
|
||||
return 60 // 中等优先级,在日志之后,业务处理之前
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ApiAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取客户端IP地址,并存入上下文
|
||||
clientIP := c.ClientIP()
|
||||
c.Set("client_ip", clientIP)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ApiAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
72
internal/shared/middleware/domain.go
Normal file
72
internal/shared/middleware/domain.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DomainAuthMiddleware 域名认证中间件
|
||||
type DomainAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewDomainAuthMiddleware 创建域名认证中间件
|
||||
func NewDomainAuthMiddleware(cfg *config.Config, logger *zap.Logger) *DomainAuthMiddleware {
|
||||
return &DomainAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *DomainAuthMiddleware) GetName() string {
|
||||
return "domain_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *DomainAuthMiddleware) GetPriority() int {
|
||||
return 100
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *DomainAuthMiddleware) Handle(domain string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
// 开发环境下跳过外部验证
|
||||
if m.config.App.IsDevelopment() {
|
||||
m.logger.Info("开发环境:跳过域名验证",
|
||||
zap.String("domain", domain))
|
||||
c.Next()
|
||||
}
|
||||
if domain == "" {
|
||||
domain = m.config.API.Domain
|
||||
}
|
||||
host := c.Request.Host
|
||||
|
||||
// 移除端口部分
|
||||
if idx := strings.Index(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
|
||||
if host == domain {
|
||||
// 设置域名匹配标记
|
||||
c.Set("domainMatched", domain)
|
||||
c.Next()
|
||||
} else {
|
||||
// 不匹配域名,跳过当前组处理,继续执行其他路由
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *DomainAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
282
internal/shared/payment/alipay.go
Normal file
282
internal/shared/payment/alipay.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package payment
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/smartwalle/alipay/v3"
|
||||
)
|
||||
|
||||
type AlipayConfig struct {
|
||||
AppID string
|
||||
PrivateKey string
|
||||
AlipayPublicKey string
|
||||
IsProduction bool
|
||||
NotifyUrl string
|
||||
ReturnURL string // 同步回调地址
|
||||
}
|
||||
type AliPayService struct {
|
||||
config AlipayConfig
|
||||
AlipayClient *alipay.Client
|
||||
}
|
||||
|
||||
// NewAliPayService 是一个构造函数,用于初始化 AliPayService
|
||||
func NewAliPayService(config AlipayConfig) *AliPayService {
|
||||
client, err := alipay.New(config.AppID, config.PrivateKey, config.IsProduction)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("创建支付宝客户端失败: %v", err))
|
||||
}
|
||||
|
||||
// 加载支付宝公钥
|
||||
err = client.LoadAliPayPublicKey(config.AlipayPublicKey)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("加载支付宝公钥失败: %v", err))
|
||||
}
|
||||
return &AliPayService{
|
||||
config: config,
|
||||
AlipayClient: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AliPayService) CreateAlipayAppOrder(amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
|
||||
client := a.AlipayClient
|
||||
totalAmount := amount.StringFixed(2) // 保留2位小数
|
||||
// 构造移动支付请求
|
||||
p := alipay.TradeAppPay{
|
||||
Trade: alipay.Trade{
|
||||
Subject: subject,
|
||||
OutTradeNo: outTradeNo,
|
||||
TotalAmount: totalAmount,
|
||||
ProductCode: "QUICK_MSECURITY_PAY", // 移动端支付专用代码
|
||||
NotifyURL: a.config.NotifyUrl, // 异步回调通知地址
|
||||
},
|
||||
}
|
||||
|
||||
// 获取APP支付字符串,这里会签名
|
||||
payStr, err := client.TradeAppPay(p)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建支付宝订单失败: %v", err)
|
||||
}
|
||||
|
||||
return payStr, nil
|
||||
}
|
||||
|
||||
// CreateAlipayH5Order 创建支付宝H5支付订单
|
||||
func (a *AliPayService) CreateAlipayH5Order(amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
|
||||
client := a.AlipayClient
|
||||
totalAmount := amount.StringFixed(2) // 保留2位小数
|
||||
// 构造H5支付请求
|
||||
p := alipay.TradeWapPay{
|
||||
Trade: alipay.Trade{
|
||||
Subject: subject,
|
||||
OutTradeNo: outTradeNo,
|
||||
TotalAmount: totalAmount,
|
||||
ProductCode: "QUICK_WAP_PAY", // H5支付专用产品码
|
||||
NotifyURL: a.config.NotifyUrl, // 异步回调通知地址
|
||||
ReturnURL: a.config.ReturnURL,
|
||||
},
|
||||
}
|
||||
// 获取H5支付请求字符串,这里会签名
|
||||
payUrl, err := client.TradeWapPay(p)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建支付宝H5订单失败: %v", err)
|
||||
}
|
||||
|
||||
return payUrl.String(), nil
|
||||
}
|
||||
|
||||
// CreateAlipayPCOrder 创建支付宝PC端支付订单
|
||||
func (a *AliPayService) CreateAlipayPCOrder(amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
|
||||
client := a.AlipayClient
|
||||
totalAmount := amount.StringFixed(2) // 保留2位小数
|
||||
|
||||
// 构造PC端支付请求
|
||||
p := alipay.TradePagePay{
|
||||
Trade: alipay.Trade{
|
||||
Subject: subject,
|
||||
OutTradeNo: outTradeNo,
|
||||
TotalAmount: totalAmount,
|
||||
ProductCode: "FAST_INSTANT_TRADE_PAY", // PC端支付专用产品码
|
||||
NotifyURL: a.config.NotifyUrl, // 异步回调通知地址
|
||||
ReturnURL: a.config.ReturnURL, // 同步回调地址
|
||||
},
|
||||
}
|
||||
|
||||
// 获取PC端支付URL,这里会签名
|
||||
payUrl, err := client.TradePagePay(p)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建支付宝PC端订单失败: %v", err)
|
||||
}
|
||||
|
||||
return payUrl.String(), nil
|
||||
}
|
||||
|
||||
// CreateAlipayOrder 根据平台类型创建支付宝支付订单
|
||||
func (a *AliPayService) CreateAlipayOrder(ctx context.Context, platform string, amount decimal.Decimal, subject string, outTradeNo string) (string, error) {
|
||||
switch platform {
|
||||
case "app":
|
||||
// 调用App支付的创建方法
|
||||
return a.CreateAlipayAppOrder(amount, subject, outTradeNo)
|
||||
case "h5":
|
||||
// 调用H5支付的创建方法,并传入 returnUrl
|
||||
return a.CreateAlipayH5Order(amount, subject, outTradeNo)
|
||||
case "pc":
|
||||
// 调用PC端支付的创建方法
|
||||
return a.CreateAlipayPCOrder(amount, subject, outTradeNo)
|
||||
default:
|
||||
return "", fmt.Errorf("不支持的支付平台: %s", platform)
|
||||
}
|
||||
}
|
||||
|
||||
// AliRefund 发起支付宝退款
|
||||
func (a *AliPayService) AliRefund(ctx context.Context, outTradeNo string, refundAmount decimal.Decimal) (*alipay.TradeRefundRsp, error) {
|
||||
refund := alipay.TradeRefund{
|
||||
OutTradeNo: outTradeNo,
|
||||
RefundAmount: refundAmount.StringFixed(2), // 保留2位小数
|
||||
OutRequestNo: fmt.Sprintf("%s-refund", outTradeNo),
|
||||
}
|
||||
|
||||
// 发起退款请求
|
||||
refundResp, err := a.AlipayClient.TradeRefund(ctx, refund)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("支付宝退款请求错误:%v", err)
|
||||
}
|
||||
return refundResp, nil
|
||||
}
|
||||
|
||||
// HandleAliPaymentNotification 支付宝支付回调
|
||||
func (a *AliPayService) HandleAliPaymentNotification(r *http.Request) (*alipay.Notification, error) {
|
||||
// 解析表单
|
||||
err := r.ParseForm()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("解析请求表单失败:%v", err)
|
||||
}
|
||||
// 解析并验证通知,DecodeNotification 会自动验证签名
|
||||
notification, err := a.AlipayClient.DecodeNotification(r.Form)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("验证签名失败: %v", err)
|
||||
}
|
||||
return notification, nil
|
||||
}
|
||||
|
||||
func (a *AliPayService) IsAlipayPaymentSuccess(notification *alipay.Notification) bool {
|
||||
return notification.TradeStatus == alipay.TradeStatusSuccess
|
||||
}
|
||||
|
||||
func (a *AliPayService) QueryOrderStatus(ctx context.Context, outTradeNo string) (*alipay.TradeQueryRsp, error) {
|
||||
queryRequest := alipay.TradeQuery{
|
||||
OutTradeNo: outTradeNo,
|
||||
}
|
||||
|
||||
// 发起查询请求
|
||||
resp, err := a.AlipayClient.TradeQuery(ctx, queryRequest)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询支付宝订单失败: %v", err)
|
||||
}
|
||||
|
||||
// 返回交易状态
|
||||
if resp.IsSuccess() {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("查询支付宝订单失败: %v", resp.SubMsg)
|
||||
}
|
||||
|
||||
// 添加全局原子计数器
|
||||
var alipayOrderCounter uint32 = 0
|
||||
|
||||
// GenerateOutTradeNo 生成唯一订单号的函数 - 优化版本
|
||||
func (a *AliPayService) GenerateOutTradeNo() string {
|
||||
|
||||
// 获取当前时间戳(毫秒级)
|
||||
timestamp := time.Now().UnixMilli()
|
||||
timeStr := strconv.FormatInt(timestamp, 10)
|
||||
|
||||
// 原子递增计数器
|
||||
counter := atomic.AddUint32(&alipayOrderCounter, 1)
|
||||
|
||||
// 生成4字节真随机数
|
||||
randomBytes := make([]byte, 4)
|
||||
_, err := rand.Read(randomBytes)
|
||||
if err != nil {
|
||||
// 如果随机数生成失败,回退到使用时间纳秒数据
|
||||
randomBytes = []byte(strconv.FormatInt(time.Now().UnixNano()%1000000, 16))
|
||||
}
|
||||
randomHex := hex.EncodeToString(randomBytes)
|
||||
|
||||
// 组合所有部分: 前缀 + 时间戳 + 计数器 + 随机数
|
||||
orderNo := fmt.Sprintf("%s%06x%s", timeStr[:10], counter%0xFFFFFF, randomHex[:6])
|
||||
|
||||
// 确保长度不超过32字符(大多数支付平台的限制)
|
||||
if len(orderNo) > 32 {
|
||||
orderNo = orderNo[:32]
|
||||
}
|
||||
|
||||
return orderNo
|
||||
}
|
||||
|
||||
// AliTransfer 支付宝单笔转账到支付宝账户(提现功能)
|
||||
func (a *AliPayService) AliTransfer(
|
||||
ctx context.Context,
|
||||
payeeAccount string, // 收款方支付宝账户
|
||||
payeeName string, // 收款方姓名
|
||||
amount decimal.Decimal, // 转账金额
|
||||
remark string, // 转账备注
|
||||
outBizNo string, // 商户转账唯一订单号(可使用GenerateOutTradeNo生成)
|
||||
) (*alipay.FundTransUniTransferRsp, error) {
|
||||
// 参数校验
|
||||
if payeeAccount == "" {
|
||||
return nil, fmt.Errorf("收款账户不能为空")
|
||||
}
|
||||
if amount.LessThanOrEqual(decimal.Zero) {
|
||||
return nil, fmt.Errorf("转账金额必须大于0")
|
||||
}
|
||||
|
||||
// 构造转账请求
|
||||
req := alipay.FundTransUniTransfer{
|
||||
OutBizNo: outBizNo,
|
||||
TransAmount: amount.StringFixed(2), // 保留2位小数
|
||||
ProductCode: "TRANS_ACCOUNT_NO_PWD", // 单笔无密转账到支付宝账户
|
||||
BizScene: "DIRECT_TRANSFER", // 单笔转账
|
||||
OrderTitle: "账户提现", // 转账标题
|
||||
Remark: remark,
|
||||
PayeeInfo: &alipay.PayeeInfo{
|
||||
Identity: payeeAccount,
|
||||
IdentityType: "ALIPAY_LOGON_ID", // 根据账户类型选择:
|
||||
Name: payeeName,
|
||||
// ALIPAY_USER_ID/ALIPAY_LOGON_ID
|
||||
},
|
||||
}
|
||||
|
||||
// 执行转账请求
|
||||
transferRsp, err := a.AlipayClient.FundTransUniTransfer(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("支付宝转账请求失败: %v", err)
|
||||
}
|
||||
|
||||
return transferRsp, nil
|
||||
}
|
||||
func (a *AliPayService) QueryTransferStatus(
|
||||
ctx context.Context,
|
||||
outBizNo string,
|
||||
) (*alipay.FundTransOrderQueryRsp, error) {
|
||||
req := alipay.FundTransOrderQuery{
|
||||
OutBizNo: outBizNo,
|
||||
}
|
||||
response, err := a.AlipayClient.FundTransOrderQuery(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("支付宝接口调用失败: %v", err)
|
||||
}
|
||||
// 处理响应
|
||||
if response.Code.IsFailure() {
|
||||
return nil, fmt.Errorf("支付宝返回错误: %s-%s", response.Code, response.Msg)
|
||||
}
|
||||
return response, nil
|
||||
}
|
||||
246
internal/shared/validator/AUTH_DATE_VALIDATOR.md
Normal file
246
internal/shared/validator/AUTH_DATE_VALIDATOR.md
Normal file
@@ -0,0 +1,246 @@
|
||||
# AuthDate 授权日期验证器
|
||||
|
||||
## 概述
|
||||
|
||||
`authDate` 是一个自定义验证器,用于验证授权日期格式和有效性。该验证器确保日期格式正确,并且日期范围必须包括今天。
|
||||
|
||||
## 验证规则
|
||||
|
||||
### 1. 格式要求
|
||||
- 必须为 `YYYYMMDD-YYYYMMDD` 格式
|
||||
- 两个日期之间用连字符 `-` 分隔
|
||||
- 每个日期必须是8位数字
|
||||
|
||||
### 2. 日期有效性
|
||||
- 开始日期不能晚于结束日期
|
||||
- 日期范围必须包括今天(如果两个日期都是今天也行)
|
||||
- 支持闰年验证
|
||||
- 验证月份和日期的有效性
|
||||
|
||||
### 3. 业务逻辑
|
||||
- 空值由 `required` 标签处理,本验证器返回 `true`
|
||||
- 日期范围必须覆盖今天,确保授权在有效期内
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 在 DTO 中使用
|
||||
|
||||
```go
|
||||
type FLXG0V4BReq struct {
|
||||
Name string `json:"name" validate:"required,name"`
|
||||
IDCard string `json:"id_card" validate:"required,idCard"`
|
||||
AuthDate string `json:"auth_date" validate:"required,authDate"`
|
||||
}
|
||||
```
|
||||
|
||||
### 在结构体中使用
|
||||
|
||||
```go
|
||||
type AuthorizationRequest struct {
|
||||
UserID string `json:"user_id" validate:"required"`
|
||||
AuthDate string `json:"auth_date" validate:"required,authDate"`
|
||||
Scope string `json:"scope" validate:"required"`
|
||||
}
|
||||
```
|
||||
|
||||
## 有效示例
|
||||
|
||||
### ✅ 有效的日期范围
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240101-20240131" // 1月1日到1月31日(如果今天是1月15日)
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240115-20240115" // 今天到今天
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240110-20240120" // 昨天到明天(如果今天是1月15日)
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240101-20240201" // 上个月到下个月(如果今天是1月15日)
|
||||
}
|
||||
```
|
||||
|
||||
### ❌ 无效的日期范围
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240116-20240120" // 明天到后天(不包括今天)
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240101-20240114" // 上个月到昨天(不包括今天)
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240131-20240101" // 开始日期晚于结束日期
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240101-2024013A" // 非数字字符
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "202401-20240131" // 日期长度不对
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "2024010120240131" // 缺少连字符
|
||||
}
|
||||
```
|
||||
|
||||
```json
|
||||
{
|
||||
"auth_date": "20240230-20240301" // 无效日期(2月30日)
|
||||
}
|
||||
```
|
||||
|
||||
## 错误消息
|
||||
|
||||
当验证失败时,会返回中文错误消息:
|
||||
|
||||
```
|
||||
"授权日期格式不正确,必须是YYYYMMDD-YYYYMMDD格式,且日期范围必须包括今天"
|
||||
```
|
||||
|
||||
## 测试用例
|
||||
|
||||
验证器包含完整的测试用例,覆盖以下场景:
|
||||
|
||||
### 有效场景
|
||||
- 今天到今天
|
||||
- 昨天到今天
|
||||
- 今天到明天
|
||||
- 上周到今天
|
||||
- 今天到下周
|
||||
- 昨天到明天
|
||||
|
||||
### 无效场景
|
||||
- 明天到后天(不包括今天)
|
||||
- 上周到昨天(不包括今天)
|
||||
- 格式错误(缺少连字符、多个连字符、长度不对、非数字)
|
||||
- 无效日期(2月30日、13月等)
|
||||
- 开始日期晚于结束日期
|
||||
|
||||
## 实现细节
|
||||
|
||||
### 核心验证逻辑
|
||||
|
||||
```go
|
||||
func validateAuthDate(fl validator.FieldLevel) bool {
|
||||
authDate := fl.Field().String()
|
||||
if authDate == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 1. 检查格式:YYYYMMDD-YYYYMMDD
|
||||
parts := strings.Split(authDate, "-")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 2. 解析日期
|
||||
startDate, err := parseYYYYMMDD(parts[0])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
endDate, err := parseYYYYMMDD(parts[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 3. 检查日期顺序
|
||||
if startDate.After(endDate) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 4. 检查是否包括今天
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
return !startDate.After(today) && !endDate.Before(today)
|
||||
}
|
||||
```
|
||||
|
||||
### 日期解析
|
||||
|
||||
```go
|
||||
func parseYYYYMMDD(dateStr string) (time.Time, error) {
|
||||
if len(dateStr) != 8 {
|
||||
return time.Time{}, fmt.Errorf("日期格式错误")
|
||||
}
|
||||
|
||||
year, _ := strconv.Atoi(dateStr[:4])
|
||||
month, _ := strconv.Atoi(dateStr[4:6])
|
||||
day, _ := strconv.Atoi(dateStr[6:8])
|
||||
|
||||
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// 验证日期有效性
|
||||
expectedDateStr := date.Format("20060102")
|
||||
if expectedDateStr != dateStr {
|
||||
return time.Time{}, fmt.Errorf("无效日期")
|
||||
}
|
||||
|
||||
return date, nil
|
||||
}
|
||||
```
|
||||
|
||||
## 注册方式
|
||||
|
||||
验证器已在 `RegisterCustomValidators` 函数中自动注册:
|
||||
|
||||
```go
|
||||
func RegisterCustomValidators(validate *validator.Validate) {
|
||||
// ... 其他验证器
|
||||
validate.RegisterValidation("auth_date", validateAuthDate)
|
||||
}
|
||||
```
|
||||
|
||||
翻译也已自动注册:
|
||||
|
||||
```go
|
||||
validate.RegisterTranslation("auth_date", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("auth_date", "{0}格式不正确,必须是YYYYMMDD-YYYYMMDD格式,且日期范围必须包括今天", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("auth_date", getFieldDisplayName(fe.Field()))
|
||||
return t
|
||||
})
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **时区处理**:验证器使用 UTC 时区进行日期比较
|
||||
2. **空值处理**:空字符串由 `required` 标签处理,本验证器返回 `true`
|
||||
3. **日期精度**:只比较日期部分,忽略时间部分
|
||||
4. **闰年支持**:自动处理闰年验证
|
||||
5. **错误消息**:提供中文错误消息,便于用户理解
|
||||
|
||||
## 运行测试
|
||||
|
||||
```bash
|
||||
# 运行所有 authDate 相关测试
|
||||
go test ./internal/shared/validator -v -run TestValidateAuthDate
|
||||
|
||||
# 运行所有验证器测试
|
||||
go test ./internal/shared/validator -v
|
||||
```
|
||||
180
internal/shared/validator/auth_date_test.go
Normal file
180
internal/shared/validator/auth_date_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
func TestValidateAuthDate(t *testing.T) {
|
||||
validate := validator.New()
|
||||
validate.RegisterValidation("auth_date", validateAuthDate)
|
||||
|
||||
today := time.Now().Format("20060102")
|
||||
yesterday := time.Now().AddDate(0, 0, -1).Format("20060102")
|
||||
tomorrow := time.Now().AddDate(0, 0, 1).Format("20060102")
|
||||
lastWeek := time.Now().AddDate(0, 0, -7).Format("20060102")
|
||||
nextWeek := time.Now().AddDate(0, 0, 7).Format("20060102")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authDate string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "今天到今天 - 有效",
|
||||
authDate: today + "-" + today,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "昨天到今天 - 有效",
|
||||
authDate: yesterday + "-" + today,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "今天到明天 - 有效",
|
||||
authDate: today + "-" + tomorrow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "上周到今天 - 有效",
|
||||
authDate: lastWeek + "-" + today,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "今天到下周 - 有效",
|
||||
authDate: today + "-" + nextWeek,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "昨天到明天 - 有效",
|
||||
authDate: yesterday + "-" + tomorrow,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "明天到后天 - 无效(不包括今天)",
|
||||
authDate: tomorrow + "-" + time.Now().AddDate(0, 0, 2).Format("20060102"),
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "上周到昨天 - 无效(不包括今天)",
|
||||
authDate: lastWeek + "-" + yesterday,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "格式错误 - 缺少连字符",
|
||||
authDate: "2024010120240131",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "格式错误 - 多个连字符",
|
||||
authDate: "20240101-20240131-20240201",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "格式错误 - 日期长度不对",
|
||||
authDate: "202401-20240131",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "格式错误 - 非数字",
|
||||
authDate: "20240101-2024013A",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无效日期 - 2月30日",
|
||||
authDate: "20240230-20240301",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无效日期 - 13月",
|
||||
authDate: "20241301-20241331",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "开始日期晚于结束日期",
|
||||
authDate: "20240131-20240101",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "空字符串 - 由required处理",
|
||||
authDate: "",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validate.Var(tt.authDate, "auth_date")
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("validateAuthDate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseYYYYMMDD(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dateStr string
|
||||
want time.Time
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "有效日期",
|
||||
dateStr: "20240101",
|
||||
want: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "闰年2月29日",
|
||||
dateStr: "20240229",
|
||||
want: time.Date(2024, 2, 29, 0, 0, 0, 0, time.UTC),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "非闰年2月29日",
|
||||
dateStr: "20230229",
|
||||
want: time.Time{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "长度错误",
|
||||
dateStr: "202401",
|
||||
want: time.Time{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "非数字",
|
||||
dateStr: "2024010A",
|
||||
want: time.Time{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无效月份",
|
||||
dateStr: "20241301",
|
||||
want: time.Time{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "无效日期",
|
||||
dateStr: "20240230",
|
||||
want: time.Time{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := parseYYYYMMDD(tt.dateStr)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseYYYYMMDD() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.wantErr && !got.Equal(tt.want) {
|
||||
t.Errorf("parseYYYYMMDD() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,12 @@
|
||||
package validator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
@@ -38,6 +42,18 @@ func RegisterCustomValidators(validate *validator.Validate) {
|
||||
|
||||
// URL验证器
|
||||
validate.RegisterValidation("url", validateURL)
|
||||
|
||||
// 企业邮箱验证器
|
||||
validate.RegisterValidation("enterprise_email", validateEnterpriseEmail)
|
||||
|
||||
// 企业地址验证器
|
||||
validate.RegisterValidation("enterprise_address", validateEnterpriseAddress)
|
||||
|
||||
// IP地址验证器
|
||||
validate.RegisterValidation("ip", validateIP)
|
||||
|
||||
// 授权日期验证器
|
||||
validate.RegisterValidation("auth_date", validateAuthDate)
|
||||
}
|
||||
|
||||
// validatePhone 手机号验证
|
||||
@@ -111,4 +127,113 @@ func validateURL(fl validator.FieldLevel) bool {
|
||||
urlStr := fl.Field().String()
|
||||
_, err := url.ParseRequestURI(urlStr)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// validateEnterpriseEmail 企业邮箱验证
|
||||
func validateEnterpriseEmail(fl validator.FieldLevel) bool {
|
||||
email := fl.Field().String()
|
||||
// 邮箱格式验证:用户名@域名.顶级域名
|
||||
matched, _ := regexp.MatchString(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`, email)
|
||||
return matched
|
||||
}
|
||||
|
||||
// validateEnterpriseAddress 企业地址验证
|
||||
func validateEnterpriseAddress(fl validator.FieldLevel) bool {
|
||||
address := fl.Field().String()
|
||||
// 地址长度验证:2-200字符,不能只包含空格
|
||||
if len(strings.TrimSpace(address)) < 2 || len(address) > 200 {
|
||||
return false
|
||||
}
|
||||
// 地址不能只包含特殊字符
|
||||
matched, _ := regexp.MatchString(`^[^\s]+.*[^\s]+$`, strings.TrimSpace(address))
|
||||
return matched
|
||||
}
|
||||
|
||||
// validateIP IP地址验证(支持IPv4)
|
||||
func validateIP(fl validator.FieldLevel) bool {
|
||||
ip := fl.Field().String()
|
||||
// 使用正则表达式验证IPv4格式
|
||||
pattern := `^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$`
|
||||
matched, _ := regexp.MatchString(pattern, ip)
|
||||
return matched
|
||||
}
|
||||
|
||||
// validateAuthDate 授权日期验证器
|
||||
// 格式:YYYYMMDD-YYYYMMDD,之前的日期范围必须包括今天
|
||||
func validateAuthDate(fl validator.FieldLevel) bool {
|
||||
authDate := fl.Field().String()
|
||||
if authDate == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 检查格式:YYYYMMDD-YYYYMMDD
|
||||
parts := strings.Split(authDate, "-")
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
startDateStr := parts[0]
|
||||
endDateStr := parts[1]
|
||||
|
||||
// 检查日期格式是否为8位数字
|
||||
if len(startDateStr) != 8 || len(endDateStr) != 8 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 解析开始日期
|
||||
startDate, err := parseYYYYMMDD(startDateStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 解析结束日期
|
||||
endDate, err := parseYYYYMMDD(endDateStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查开始日期不能晚于结束日期
|
||||
if startDate.After(endDate) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取今天的日期(去掉时间部分)
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
|
||||
// 检查日期范围是否包括今天
|
||||
// 如果两个日期都是今天也行
|
||||
return !startDate.After(today) && !endDate.Before(today)
|
||||
}
|
||||
|
||||
// parseYYYYMMDD 解析YYYYMMDD格式的日期字符串
|
||||
func parseYYYYMMDD(dateStr string) (time.Time, error) {
|
||||
if len(dateStr) != 8 {
|
||||
return time.Time{}, fmt.Errorf("日期格式错误")
|
||||
}
|
||||
|
||||
year, err := strconv.Atoi(dateStr[:4])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
month, err := strconv.Atoi(dateStr[4:6])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
day, err := strconv.Atoi(dateStr[6:8])
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
// 验证日期有效性
|
||||
date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
// 检查解析后的日期是否与输入一致(防止无效日期如20230230)
|
||||
expectedDateStr := date.Format("20060102")
|
||||
if expectedDateStr != dateStr {
|
||||
return time.Time{}, fmt.Errorf("无效日期")
|
||||
}
|
||||
|
||||
return date, nil
|
||||
}
|
||||
@@ -162,6 +162,38 @@ func registerCustomFieldTranslations(validate *validator.Validate, trans ut.Tran
|
||||
t, _ := ut.T("url", getFieldDisplayName(fe.Field()))
|
||||
return t
|
||||
})
|
||||
|
||||
// 企业邮箱翻译
|
||||
validate.RegisterTranslation("enterprise_email", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("enterprise_email", "{0}必须是有效的企业邮箱地址", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("enterprise_email", getFieldDisplayName(fe.Field()))
|
||||
return t
|
||||
})
|
||||
|
||||
// 企业地址翻译
|
||||
validate.RegisterTranslation("enterprise_address", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("enterprise_address", "{0}长度必须在2-200字符之间,且不能只包含空格", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("enterprise_address", getFieldDisplayName(fe.Field()))
|
||||
return t
|
||||
})
|
||||
|
||||
// IP地址翻译
|
||||
validate.RegisterTranslation("ip", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("ip", "{0}必须是有效的IPv4地址格式", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("ip", getFieldDisplayName(fe.Field()))
|
||||
return t
|
||||
})
|
||||
|
||||
// 授权日期翻译
|
||||
validate.RegisterTranslation("auth_date", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("auth_date", "{0}格式不正确,必须是YYYYMMDD-YYYYMMDD格式,且日期范围必须包括今天", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("auth_date", getFieldDisplayName(fe.Field()))
|
||||
return t
|
||||
})
|
||||
}
|
||||
|
||||
// getFieldDisplayName 获取字段显示名称(中文)
|
||||
@@ -176,6 +208,9 @@ func getFieldDisplayName(field string) string {
|
||||
"code": "验证码",
|
||||
"username": "用户名",
|
||||
"email": "邮箱",
|
||||
"enterprise_email": "企业邮箱",
|
||||
"enterprise_address": "企业地址",
|
||||
"ip_address": "IP地址",
|
||||
"display_name": "显示名称",
|
||||
"scene": "使用场景",
|
||||
"Password": "密码",
|
||||
@@ -244,6 +279,10 @@ func getFieldDisplayName(field string) string {
|
||||
"ID": "ID",
|
||||
"ids": "ID列表",
|
||||
"IDs": "ID列表",
|
||||
"auth_date": "授权日期",
|
||||
"AuthDate": "授权日期",
|
||||
"id_card": "身份证号",
|
||||
"IDCard": "身份证号",
|
||||
}
|
||||
|
||||
if displayName, exists := fieldNames[field]; exists {
|
||||
|
||||
@@ -65,7 +65,7 @@ func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -213,4 +213,4 @@ func (v *RequestValidator) ValidateValue(field interface{}, tag string) error {
|
||||
// ValidateStruct 验证结构体(用于业务逻辑)
|
||||
func (v *RequestValidator) ValidateStruct(s interface{}) error {
|
||||
return v.validator.Struct(s)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user