v0.1
This commit is contained in:
		
							
								
								
									
										194
									
								
								internal/shared/cache/cache_config_manager.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										194
									
								
								internal/shared/cache/cache_config_manager.go
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,194 @@ | ||||
| package cache | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| // CacheConfigManager 缓存配置管理器 | ||||
| // 提供全局缓存配置管理和表级别的缓存决策 | ||||
| type CacheConfigManager struct { | ||||
| 	config CacheConfig | ||||
| 	mutex  sync.RWMutex | ||||
| } | ||||
|  | ||||
| // GlobalCacheConfigManager 全局缓存配置管理器实例 | ||||
| var GlobalCacheConfigManager *CacheConfigManager | ||||
|  | ||||
| // InitCacheConfigManager 初始化全局缓存配置管理器 | ||||
| func InitCacheConfigManager(config CacheConfig) { | ||||
| 	GlobalCacheConfigManager = &CacheConfigManager{ | ||||
| 		config: config, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // GetCacheConfig 获取当前缓存配置 | ||||
| func (m *CacheConfigManager) GetCacheConfig() CacheConfig { | ||||
| 	m.mutex.RLock() | ||||
| 	defer m.mutex.RUnlock() | ||||
| 	return m.config | ||||
| } | ||||
|  | ||||
| // UpdateCacheConfig 更新缓存配置 | ||||
| func (m *CacheConfigManager) UpdateCacheConfig(config CacheConfig) { | ||||
| 	m.mutex.Lock() | ||||
| 	defer m.mutex.Unlock() | ||||
| 	m.config = config | ||||
| } | ||||
|  | ||||
| // IsTableCacheEnabled 检查表是否启用缓存 | ||||
| func (m *CacheConfigManager) IsTableCacheEnabled(tableName string) bool { | ||||
| 	m.mutex.RLock() | ||||
| 	defer m.mutex.RUnlock() | ||||
|  | ||||
| 	// 检查表是否在禁用列表中 | ||||
| 	for _, disabledTable := range m.config.DisabledTables { | ||||
| 		if disabledTable == tableName { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 如果配置了启用列表,只对启用列表中的表启用缓存 | ||||
| 	if len(m.config.EnabledTables) > 0 { | ||||
| 		for _, enabledTable := range m.config.EnabledTables { | ||||
| 			if enabledTable == tableName { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // IsTableCacheDisabled 检查表是否禁用缓存 | ||||
| func (m *CacheConfigManager) IsTableCacheDisabled(tableName string) bool { | ||||
| 	return !m.IsTableCacheEnabled(tableName) | ||||
| } | ||||
|  | ||||
| // GetEnabledTables 获取启用缓存的表列表 | ||||
| func (m *CacheConfigManager) GetEnabledTables() []string { | ||||
| 	m.mutex.RLock() | ||||
| 	defer m.mutex.RUnlock() | ||||
| 	return m.config.EnabledTables | ||||
| } | ||||
|  | ||||
| // GetDisabledTables 获取禁用缓存的表列表 | ||||
| func (m *CacheConfigManager) GetDisabledTables() []string { | ||||
| 	m.mutex.RLock() | ||||
| 	defer m.mutex.RUnlock() | ||||
| 	return m.config.DisabledTables | ||||
| } | ||||
|  | ||||
| // AddEnabledTable 添加启用缓存的表 | ||||
| func (m *CacheConfigManager) AddEnabledTable(tableName string) { | ||||
| 	m.mutex.Lock() | ||||
| 	defer m.mutex.Unlock() | ||||
| 	 | ||||
| 	// 检查是否已存在 | ||||
| 	for _, table := range m.config.EnabledTables { | ||||
| 		if table == tableName { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	 | ||||
| 	m.config.EnabledTables = append(m.config.EnabledTables, tableName) | ||||
| } | ||||
|  | ||||
| // AddDisabledTable 添加禁用缓存的表 | ||||
| func (m *CacheConfigManager) AddDisabledTable(tableName string) { | ||||
| 	m.mutex.Lock() | ||||
| 	defer m.mutex.Unlock() | ||||
| 	 | ||||
| 	// 检查是否已存在 | ||||
| 	for _, table := range m.config.DisabledTables { | ||||
| 		if table == tableName { | ||||
| 			return | ||||
| 		} | ||||
| 	} | ||||
| 	 | ||||
| 	m.config.DisabledTables = append(m.config.DisabledTables, tableName) | ||||
| } | ||||
|  | ||||
| // RemoveEnabledTable 移除启用缓存的表 | ||||
| func (m *CacheConfigManager) RemoveEnabledTable(tableName string) { | ||||
| 	m.mutex.Lock() | ||||
| 	defer m.mutex.Unlock() | ||||
| 	 | ||||
| 	var newEnabledTables []string | ||||
| 	for _, table := range m.config.EnabledTables { | ||||
| 		if table != tableName { | ||||
| 			newEnabledTables = append(newEnabledTables, table) | ||||
| 		} | ||||
| 	} | ||||
| 	m.config.EnabledTables = newEnabledTables | ||||
| } | ||||
|  | ||||
| // RemoveDisabledTable 移除禁用缓存的表 | ||||
| func (m *CacheConfigManager) RemoveDisabledTable(tableName string) { | ||||
| 	m.mutex.Lock() | ||||
| 	defer m.mutex.Unlock() | ||||
| 	 | ||||
| 	var newDisabledTables []string | ||||
| 	for _, table := range m.config.DisabledTables { | ||||
| 		if table != tableName { | ||||
| 			newDisabledTables = append(newDisabledTables, table) | ||||
| 		} | ||||
| 	} | ||||
| 	m.config.DisabledTables = newDisabledTables | ||||
| } | ||||
|  | ||||
| // GetTableCacheStatus 获取表的缓存状态信息 | ||||
| func (m *CacheConfigManager) GetTableCacheStatus(tableName string) map[string]interface{} { | ||||
| 	m.mutex.RLock() | ||||
| 	defer m.mutex.RUnlock() | ||||
| 	 | ||||
| 	status := map[string]interface{}{ | ||||
| 		"table_name": tableName, | ||||
| 		"enabled":    m.IsTableCacheEnabled(tableName), | ||||
| 		"disabled":   m.IsTableCacheDisabled(tableName), | ||||
| 	} | ||||
| 	 | ||||
| 	// 检查是否在启用列表中 | ||||
| 	for _, table := range m.config.EnabledTables { | ||||
| 		if table == tableName { | ||||
| 			status["in_enabled_list"] = true | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	 | ||||
| 	// 检查是否在禁用列表中 | ||||
| 	for _, table := range m.config.DisabledTables { | ||||
| 		if table == tableName { | ||||
| 			status["in_disabled_list"] = true | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
| 	 | ||||
| 	return status | ||||
| } | ||||
|  | ||||
| // GetAllTableStatus 获取所有表的缓存状态 | ||||
| func (m *CacheConfigManager) GetAllTableStatus() map[string]interface{} { | ||||
| 	m.mutex.RLock() | ||||
| 	defer m.mutex.RUnlock() | ||||
| 	 | ||||
| 	result := map[string]interface{}{ | ||||
| 		"enabled_tables":  m.config.EnabledTables, | ||||
| 		"disabled_tables": m.config.DisabledTables, | ||||
| 		"config": map[string]interface{}{ | ||||
| 			"default_ttl":        m.config.DefaultTTL, | ||||
| 			"table_prefix":       m.config.TablePrefix, | ||||
| 			"max_cache_size":     m.config.MaxCacheSize, | ||||
| 			"cache_complex_sql":  m.config.CacheComplexSQL, | ||||
| 			"enable_stats":       m.config.EnableStats, | ||||
| 			"enable_warmup":      m.config.EnableWarmup, | ||||
| 			"penetration_guard":  m.config.PenetrationGuard, | ||||
| 			"bloom_filter":       m.config.BloomFilter, | ||||
| 			"auto_invalidate":    m.config.AutoInvalidate, | ||||
| 			"invalidate_delay":   m.config.InvalidateDelay, | ||||
| 		}, | ||||
| 	} | ||||
| 	 | ||||
| 	return result | ||||
| }  | ||||
							
								
								
									
										240
									
								
								internal/shared/cache/gorm_cache_plugin.go
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										240
									
								
								internal/shared/cache/gorm_cache_plugin.go
									
									
									
									
										vendored
									
									
								
							| @@ -8,6 +8,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.uber.org/zap" | ||||
| @@ -21,6 +22,10 @@ type GormCachePlugin struct { | ||||
| 	cache  interfaces.CacheService | ||||
| 	logger *zap.Logger | ||||
| 	config CacheConfig | ||||
| 	 | ||||
| 	// 缓存失效去重机制 | ||||
| 	invalidationQueue map[string]*time.Timer | ||||
| 	queueMutex        sync.RWMutex | ||||
| } | ||||
|  | ||||
| // CacheConfig 缓存配置 | ||||
| @@ -58,7 +63,8 @@ func DefaultCacheConfig() CacheConfig { | ||||
| 		PenetrationGuard: true, | ||||
| 		BloomFilter:      false, | ||||
| 		AutoInvalidate:   true, | ||||
| 		InvalidateDelay:  100 * time.Millisecond, | ||||
| 		// 增加延迟失效时间,减少频繁的缓存失效操作 | ||||
| 		InvalidateDelay:  500 * time.Millisecond, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -73,6 +79,7 @@ func NewGormCachePlugin(cache interfaces.CacheService, logger *zap.Logger, confi | ||||
| 		cache:  cache, | ||||
| 		logger: logger, | ||||
| 		config: cfg, | ||||
| 		invalidationQueue: make(map[string]*time.Timer), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| @@ -87,12 +94,30 @@ func (p *GormCachePlugin) Initialize(db *gorm.DB) error { | ||||
| 		zap.Duration("default_ttl", p.config.DefaultTTL), | ||||
| 		zap.Bool("auto_invalidate", p.config.AutoInvalidate), | ||||
| 		zap.Bool("penetration_guard", p.config.PenetrationGuard), | ||||
| 		zap.Duration("invalidate_delay", p.config.InvalidateDelay), | ||||
| 	) | ||||
|  | ||||
| 	// 注册回调函数 | ||||
| 	return p.registerCallbacks(db) | ||||
| } | ||||
|  | ||||
| // Shutdown 关闭插件,清理资源 | ||||
| func (p *GormCachePlugin) Shutdown() { | ||||
| 	p.queueMutex.Lock() | ||||
| 	defer p.queueMutex.Unlock() | ||||
| 	 | ||||
| 	// 停止所有定时器 | ||||
| 	for table, timer := range p.invalidationQueue { | ||||
| 		timer.Stop() | ||||
| 		p.logger.Debug("停止缓存失效定时器", zap.String("table", table)) | ||||
| 	} | ||||
| 	 | ||||
| 	// 清空队列 | ||||
| 	p.invalidationQueue = make(map[string]*time.Timer) | ||||
| 	 | ||||
| 	p.logger.Info("GORM缓存插件已关闭") | ||||
| } | ||||
|  | ||||
| // registerCallbacks 注册GORM回调 | ||||
| func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error { | ||||
| 	// Query回调 - 查询时检查缓存 | ||||
| @@ -117,6 +142,7 @@ func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error { | ||||
| func (p *GormCachePlugin) beforeQuery(db *gorm.DB) { | ||||
| 	// 检查是否启用缓存 | ||||
| 	if !p.shouldCache(db) { | ||||
| 		p.logger.Debug("跳过缓存", zap.String("table", db.Statement.Table)) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| @@ -147,7 +173,15 @@ func (p *GormCachePlugin) beforeQuery(db *gorm.DB) { | ||||
| 				p.updateStats("hit", db.Statement.Table) | ||||
| 			} | ||||
| 			return | ||||
| 		} else { | ||||
| 			p.logger.Warn("缓存数据恢复失败,将执行数据库查询",  | ||||
| 				zap.String("cache_key", cacheKey),  | ||||
| 				zap.Error(err)) | ||||
| 		} | ||||
| 	} else { | ||||
| 		p.logger.Debug("缓存未命中",  | ||||
| 			zap.String("cache_key", cacheKey),  | ||||
| 			zap.Error(err)) | ||||
| 	} | ||||
|  | ||||
| 	// 缓存未命中,设置标记 | ||||
| @@ -195,7 +229,10 @@ func (p *GormCachePlugin) afterCreate(db *gorm.DB) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	p.invalidateTableCache(db.Statement.Context, db.Statement.Table) | ||||
| 	// 只对启用缓存的表执行失效操作 | ||||
| 	if p.shouldInvalidateTable(db.Statement.Table) { | ||||
| 		p.invalidateTableCache(db.Statement.Context, db.Statement.Table) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // afterUpdate 更新后回调 | ||||
| @@ -204,7 +241,10 @@ func (p *GormCachePlugin) afterUpdate(db *gorm.DB) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	p.invalidateTableCache(db.Statement.Context, db.Statement.Table) | ||||
| 	// 只对启用缓存的表执行失效操作 | ||||
| 	if p.shouldInvalidateTable(db.Statement.Table) { | ||||
| 		p.invalidateTableCache(db.Statement.Context, db.Statement.Table) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // afterDelete 删除后回调 | ||||
| @@ -213,26 +253,64 @@ func (p *GormCachePlugin) afterDelete(db *gorm.DB) { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	p.invalidateTableCache(db.Statement.Context, db.Statement.Table) | ||||
| 	// 只对启用缓存的表执行失效操作 | ||||
| 	if p.shouldInvalidateTable(db.Statement.Table) { | ||||
| 		p.invalidateTableCache(db.Statement.Context, db.Statement.Table) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ================ 缓存管理方法 ================ | ||||
|  | ||||
| // shouldInvalidateTable 判断是否应该对表执行缓存失效操作 | ||||
| func (p *GormCachePlugin) shouldInvalidateTable(table string) bool { | ||||
| 	// 使用全局缓存配置管理器进行智能决策 | ||||
| 	if GlobalCacheConfigManager != nil { | ||||
| 		return GlobalCacheConfigManager.IsTableCacheEnabled(table) | ||||
| 	} | ||||
|  | ||||
| 	// 如果全局管理器未初始化,使用本地配置 | ||||
| 	// 检查表是否在禁用列表中 | ||||
| 	for _, disabledTable := range p.config.DisabledTables { | ||||
| 		if disabledTable == table { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 如果配置了启用列表,只对启用列表中的表执行失效操作 | ||||
| 	if len(p.config.EnabledTables) > 0 { | ||||
| 		for _, enabledTable := range p.config.EnabledTables { | ||||
| 			if enabledTable == table { | ||||
| 				return true | ||||
| 			} | ||||
| 		} | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 如果没有配置启用列表,默认对所有表执行失效操作(除了禁用列表中的表) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // shouldCache 判断是否应该缓存 | ||||
| func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool { | ||||
| 	// 检查是否明确禁用缓存 | ||||
| 	// 检查是否手动禁用缓存 | ||||
| 	if value, ok := db.Statement.Get("cache:disabled"); ok && value.(bool) { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 检查是否明确启用缓存 | ||||
| 	// 检查是否手动启用缓存 | ||||
| 	if value, ok := db.Statement.Get("cache:enabled"); ok && value.(bool) { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// 使用全局缓存配置管理器进行智能决策 | ||||
| 	if GlobalCacheConfigManager != nil { | ||||
| 		return GlobalCacheConfigManager.IsTableCacheEnabled(db.Statement.Table) | ||||
| 	} | ||||
|  | ||||
| 	// 如果全局管理器未初始化,使用本地配置 | ||||
| 	// 检查表是否在禁用列表中 | ||||
| 	for _, table := range p.config.DisabledTables { | ||||
| 		if table == db.Statement.Table { | ||||
| 	for _, disabledTable := range p.config.DisabledTables { | ||||
| 		if disabledTable == db.Statement.Table { | ||||
| 			return false | ||||
| 		} | ||||
| 	} | ||||
| @@ -247,11 +325,7 @@ func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 检查是否为复杂查询 | ||||
| 	if !p.config.CacheComplexSQL && p.isComplexQuery(db) { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| @@ -292,10 +366,27 @@ func (p *GormCachePlugin) generateCacheKey(db *gorm.DB) string { | ||||
|  | ||||
| // hashSQL 对SQL语句和参数进行hash | ||||
| func (p *GormCachePlugin) hashSQL(sql string, vars []interface{}) string { | ||||
| 	// 将SQL和参数组合 | ||||
| 	// 修复:改进SQL hash生成,确保唯一性 | ||||
| 	combined := sql | ||||
| 	for _, v := range vars { | ||||
| 		combined += fmt.Sprintf(":%v", v) | ||||
| 		// 使用更精确的格式化,避免类型信息丢失 | ||||
| 		switch val := v.(type) { | ||||
| 		case string: | ||||
| 			combined += ":" + val | ||||
| 		case int, int32, int64: | ||||
| 			combined += fmt.Sprintf(":%d", val) | ||||
| 		case float32, float64: | ||||
| 			combined += fmt.Sprintf(":%f", val) | ||||
| 		case bool: | ||||
| 			combined += fmt.Sprintf(":%t", val) | ||||
| 		default: | ||||
| 			// 对于复杂类型,使用JSON序列化 | ||||
| 			if jsonData, err := json.Marshal(v); err == nil { | ||||
| 				combined += ":" + string(jsonData) | ||||
| 			} else { | ||||
| 				combined += fmt.Sprintf(":%v", v) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 计算MD5 hash | ||||
| @@ -328,9 +419,22 @@ func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db * | ||||
| 		return fmt.Errorf("查询结果为空") | ||||
| 	} | ||||
|  | ||||
| 	// 修复:改进缓存数据保存逻辑 | ||||
| 	var dataToCache interface{} | ||||
| 	 | ||||
| 	// 如果dest是切片,需要特殊处理 | ||||
| 	destValue := reflect.ValueOf(dest) | ||||
| 	if destValue.Kind() == reflect.Ptr && destValue.Elem().Kind() == reflect.Slice { | ||||
| 		// 对于切片,直接使用原始数据 | ||||
| 		dataToCache = destValue.Elem().Interface() | ||||
| 	} else { | ||||
| 		// 对于单个对象,也直接使用原始数据 | ||||
| 		dataToCache = dest | ||||
| 	} | ||||
|  | ||||
| 	// 构建缓存结果 | ||||
| 	result := CachedResult{ | ||||
| 		Data:      dest, | ||||
| 		Data:      dataToCache, | ||||
| 		RowCount:  db.Statement.RowsAffected, | ||||
| 		Timestamp: time.Now(), | ||||
| 	} | ||||
| @@ -340,6 +444,10 @@ func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db * | ||||
|  | ||||
| 	// 保存到缓存 | ||||
| 	if err := p.cache.Set(ctx, cacheKey, result, ttl); err != nil { | ||||
| 		p.logger.Error("保存查询结果到缓存失败", | ||||
| 			zap.String("cache_key", cacheKey), | ||||
| 			zap.Error(err), | ||||
| 		) | ||||
| 		return fmt.Errorf("保存到缓存失败: %w", err) | ||||
| 	} | ||||
|  | ||||
| @@ -364,25 +472,35 @@ func (p *GormCachePlugin) restoreFromCache(db *gorm.DB, cachedResult *CachedResu | ||||
| 		return fmt.Errorf("目标对象必须是指针") | ||||
| 	} | ||||
|  | ||||
| 	// 将缓存数据复制到目标 | ||||
| 	// 修复:改进缓存数据恢复逻辑 | ||||
| 	cachedValue := reflect.ValueOf(cachedResult.Data) | ||||
| 	if !cachedValue.Type().AssignableTo(destValue.Elem().Type()) { | ||||
| 	 | ||||
| 	// 如果类型完全匹配,直接赋值 | ||||
| 	if cachedValue.Type().AssignableTo(destValue.Elem().Type()) { | ||||
| 		destValue.Elem().Set(cachedValue) | ||||
| 	} else { | ||||
| 		// 尝试JSON转换 | ||||
| 		jsonData, err := json.Marshal(cachedResult.Data) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("缓存数据类型不匹配") | ||||
| 			p.logger.Error("序列化缓存数据失败", zap.Error(err)) | ||||
| 			return fmt.Errorf("缓存数据类型转换失败: %w", err) | ||||
| 		} | ||||
|  | ||||
| 		if err := json.Unmarshal(jsonData, db.Statement.Dest); err != nil { | ||||
| 			p.logger.Error("反序列化缓存数据失败",  | ||||
| 				zap.String("json_data", string(jsonData)),  | ||||
| 				zap.Error(err)) | ||||
| 			return fmt.Errorf("JSON反序列化失败: %w", err) | ||||
| 		} | ||||
| 	} else { | ||||
| 		destValue.Elem().Set(cachedValue) | ||||
| 	} | ||||
|  | ||||
| 	// 设置影响行数 | ||||
| 	db.Statement.RowsAffected = cachedResult.RowCount | ||||
|  | ||||
| 	p.logger.Debug("从缓存恢复数据成功",  | ||||
| 		zap.Int64("rows", cachedResult.RowCount), | ||||
| 		zap.Time("timestamp", cachedResult.Timestamp)) | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -400,37 +518,79 @@ func (p *GormCachePlugin) getTTL(db *gorm.DB) time.Duration { | ||||
|  | ||||
| // invalidateTableCache 失效表相关缓存 | ||||
| func (p *GormCachePlugin) invalidateTableCache(ctx context.Context, table string) { | ||||
| 	if ctx == nil { | ||||
| 		ctx = context.Background() | ||||
| 	} | ||||
|  | ||||
| 	// 延迟失效(避免并发问题) | ||||
| 	if p.config.InvalidateDelay > 0 { | ||||
| 		time.AfterFunc(p.config.InvalidateDelay, func() { | ||||
| 			p.doInvalidateTableCache(ctx, table) | ||||
| 		}) | ||||
| 	} else { | ||||
| 		p.doInvalidateTableCache(ctx, table) | ||||
| 	// 使用去重机制,避免重复的缓存失效操作 | ||||
| 	p.queueMutex.Lock() | ||||
| 	defer p.queueMutex.Unlock() | ||||
| 	 | ||||
| 	// 如果已经有相同的失效操作在队列中,取消之前的定时器 | ||||
| 	if timer, exists := p.invalidationQueue[table]; exists { | ||||
| 		timer.Stop() | ||||
| 		delete(p.invalidationQueue, table) | ||||
| 	} | ||||
| 	 | ||||
| 	// 创建独立的上下文,避免受到原始请求上下文的影响 | ||||
| 	// 设置合理的超时时间,避免缓存失效操作阻塞 | ||||
| 	cacheCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) | ||||
| 	 | ||||
| 	// 创建新的定时器 | ||||
| 	timer := time.AfterFunc(p.config.InvalidateDelay, func() { | ||||
| 		// 执行缓存失效 | ||||
| 		p.doInvalidateTableCache(cacheCtx, table) | ||||
| 		 | ||||
| 		// 清理定时器引用 | ||||
| 		p.queueMutex.Lock() | ||||
| 		delete(p.invalidationQueue, table) | ||||
| 		p.queueMutex.Unlock() | ||||
| 		 | ||||
| 		// 取消上下文 | ||||
| 		cancel() | ||||
| 	}) | ||||
| 	 | ||||
| 	// 将定时器加入队列 | ||||
| 	p.invalidationQueue[table] = timer | ||||
| 	 | ||||
| 	p.logger.Debug("缓存失效操作已加入队列", | ||||
| 		zap.String("table", table), | ||||
| 		zap.Duration("delay", p.config.InvalidateDelay), | ||||
| 	) | ||||
| } | ||||
|  | ||||
| // doInvalidateTableCache 执行缓存失效 | ||||
| func (p *GormCachePlugin) doInvalidateTableCache(ctx context.Context, table string) { | ||||
| 	pattern := fmt.Sprintf("%s:%s:*", p.config.TablePrefix, table) | ||||
|  | ||||
| 	if err := p.cache.DeletePattern(ctx, pattern); err != nil { | ||||
| 		p.logger.Warn("失效表缓存失败", | ||||
| 	// 添加重试机制,提高缓存失效的可靠性 | ||||
| 	maxRetries := 3 | ||||
| 	for attempt := 1; attempt <= maxRetries; attempt++ { | ||||
| 		if err := p.cache.DeletePattern(ctx, pattern); err != nil { | ||||
| 			if attempt < maxRetries { | ||||
| 				p.logger.Warn("缓存失效失败,准备重试", | ||||
| 					zap.String("table", table), | ||||
| 					zap.String("pattern", pattern), | ||||
| 					zap.Int("attempt", attempt), | ||||
| 					zap.Error(err), | ||||
| 				) | ||||
| 				// 短暂延迟后重试 | ||||
| 				time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) | ||||
| 				continue | ||||
| 			} | ||||
| 			 | ||||
| 			p.logger.Warn("失效表缓存失败", | ||||
| 				zap.String("table", table), | ||||
| 				zap.String("pattern", pattern), | ||||
| 				zap.Int("attempts", maxRetries), | ||||
| 				zap.Error(err), | ||||
| 			) | ||||
| 			return | ||||
| 		} | ||||
| 		 | ||||
| 		// 成功删除,记录日志并退出 | ||||
| 		p.logger.Debug("表缓存已失效", | ||||
| 			zap.String("table", table), | ||||
| 			zap.String("pattern", pattern), | ||||
| 			zap.Error(err), | ||||
| 		) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	p.logger.Debug("表缓存已失效", | ||||
| 		zap.String("table", table), | ||||
| 		zap.String("pattern", pattern), | ||||
| 	) | ||||
| } | ||||
|  | ||||
| // updateStats 更新统计信息 | ||||
|   | ||||
							
								
								
									
										113
									
								
								internal/shared/crypto/crypto.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										113
									
								
								internal/shared/crypto/crypto.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,113 @@ | ||||
| package crypto | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/md5" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/hex" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| ) | ||||
|  | ||||
| // PKCS7填充 | ||||
| func PKCS7Padding(ciphertext []byte, blockSize int) []byte { | ||||
| 	padding := blockSize - len(ciphertext)%blockSize | ||||
| 	padtext := bytes.Repeat([]byte{byte(padding)}, padding) | ||||
| 	return append(ciphertext, padtext...) | ||||
| } | ||||
|  | ||||
| // 去除PKCS7填充 | ||||
| func PKCS7UnPadding(origData []byte) ([]byte, error) { | ||||
| 	length := len(origData) | ||||
| 	if length == 0 { | ||||
| 		return nil, errors.New("input data error") | ||||
| 	} | ||||
| 	unpadding := int(origData[length-1]) | ||||
| 	if unpadding > length { | ||||
| 		return nil, errors.New("unpadding size is invalid") | ||||
| 	} | ||||
|  | ||||
| 	// 检查填充字节是否一致 | ||||
| 	for i := 0; i < unpadding; i++ { | ||||
| 		if origData[length-1-i] != byte(unpadding) { | ||||
| 			return nil, errors.New("invalid padding") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return origData[:(length - unpadding)], nil | ||||
| } | ||||
|  | ||||
| // AES CBC模式加密,Base64传入传出 | ||||
| func AesEncrypt(plainText []byte, key string) (string, error) { | ||||
| 	keyBytes, err := hex.DecodeString(key) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	block, err := aes.NewCipher(keyBytes) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	blockSize := block.BlockSize() | ||||
| 	plainText = PKCS7Padding(plainText, blockSize) | ||||
|  | ||||
| 	cipherText := make([]byte, blockSize+len(plainText)) | ||||
| 	iv := cipherText[:blockSize] // 使用前blockSize字节作为IV | ||||
| 	_, err = io.ReadFull(rand.Reader, iv) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	mode := cipher.NewCBCEncrypter(block, iv) | ||||
| 	mode.CryptBlocks(cipherText[blockSize:], plainText) | ||||
|  | ||||
| 	return base64.StdEncoding.EncodeToString(cipherText), nil | ||||
| } | ||||
|  | ||||
| // AES CBC模式解密,Base64传入传出 | ||||
| func AesDecrypt(cipherTextBase64 string, key string) ([]byte, error) { | ||||
| 	keyBytes, err := hex.DecodeString(key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	block, err := aes.NewCipher(keyBytes) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	blockSize := block.BlockSize() | ||||
| 	if len(cipherText) < blockSize { | ||||
| 		return nil, errors.New("ciphertext too short") | ||||
| 	} | ||||
|  | ||||
| 	iv := cipherText[:blockSize] | ||||
| 	cipherText = cipherText[blockSize:] | ||||
|  | ||||
| 	if len(cipherText)%blockSize != 0 { | ||||
| 		return nil, errors.New("ciphertext is not a multiple of the block size") | ||||
| 	} | ||||
|  | ||||
| 	mode := cipher.NewCBCDecrypter(block, iv) | ||||
| 	mode.CryptBlocks(cipherText, cipherText) | ||||
|  | ||||
| 	plainText, err := PKCS7UnPadding(cipherText) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return plainText, nil | ||||
| } | ||||
|  | ||||
| // Md5Encrypt 用于对传入的message进行MD5加密 | ||||
| func Md5Encrypt(message string) string { | ||||
| 	hash := md5.New() | ||||
| 	hash.Write([]byte(message))              // 将字符串转换为字节切片并写入 | ||||
| 	return hex.EncodeToString(hash.Sum(nil)) // 将哈希值转换为16进制字符串并返回 | ||||
| } | ||||
							
								
								
									
										63
									
								
								internal/shared/crypto/generate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								internal/shared/crypto/generate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| package crypto | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"encoding/hex" | ||||
| 	"io" | ||||
| 	mathrand "math/rand" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // 生成AES-128密钥的函数,符合市面规范 | ||||
| func GenerateSecretKey() (string, error) { | ||||
| 	key := make([]byte, 16) // 16字节密钥 | ||||
| 	_, err := io.ReadFull(rand.Reader, key) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return hex.EncodeToString(key), nil | ||||
| } | ||||
|  | ||||
| func GenerateSecretId() (string, error) { | ||||
| 	// 创建一个字节数组,用于存储随机数据 | ||||
| 	bytes := make([]byte, 8) // 因为每个字节表示两个16进制字符 | ||||
|  | ||||
| 	// 读取随机字节到数组中 | ||||
| 	_, err := rand.Read(bytes) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	// 将字节数组转换为16进制字符串 | ||||
| 	return hex.EncodeToString(bytes), nil | ||||
| } | ||||
|  | ||||
| // GenerateTransactionID 生成16位数的交易单号 | ||||
| func GenerateTransactionID() string { | ||||
| 	length := 16 | ||||
| 	// 获取当前时间戳 | ||||
| 	timestamp := time.Now().UnixNano() | ||||
|  | ||||
| 	// 转换为字符串 | ||||
| 	timeStr := strconv.FormatInt(timestamp, 10) | ||||
|  | ||||
| 	// 生成随机数 | ||||
| 	mathrand.Seed(time.Now().UnixNano()) | ||||
| 	randomPart := strconv.Itoa(mathrand.Intn(1000000)) | ||||
|  | ||||
| 	// 组合时间戳和随机数 | ||||
| 	combined := timeStr + randomPart | ||||
|  | ||||
| 	// 如果长度超出指定值,则截断;如果不够,则填充随机字符 | ||||
| 	if len(combined) >= length { | ||||
| 		return combined[:length] | ||||
| 	} | ||||
|  | ||||
| 	// 如果长度不够,填充0 | ||||
| 	for len(combined) < length { | ||||
| 		combined += strconv.Itoa(mathrand.Intn(10)) // 填充随机数 | ||||
| 	} | ||||
|  | ||||
| 	return combined | ||||
| } | ||||
							
								
								
									
										150
									
								
								internal/shared/crypto/west_crypto.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								internal/shared/crypto/west_crypto.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,150 @@ | ||||
| package crypto | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"crypto/aes" | ||||
| 	"crypto/cipher" | ||||
| 	"crypto/sha1" | ||||
| 	"encoding/base64" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	KEY_SIZE = 16 // AES-128, 16 bytes | ||||
| ) | ||||
|  | ||||
| // Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding | ||||
| func WestDexEncrypt(data, secretKey string) (string, error) { | ||||
| 	key := generateAESKey(KEY_SIZE*8, []byte(secretKey)) | ||||
| 	ciphertext, err := aesEncrypt([]byte(data), key) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| 	return base64.StdEncoding.EncodeToString(ciphertext), nil | ||||
| } | ||||
|  | ||||
| // Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding | ||||
| func WestDexDecrypt(encodedData, secretKey string) ([]byte, error) { | ||||
| 	ciphertext, err := base64.StdEncoding.DecodeString(encodedData) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	key := generateAESKey(KEY_SIZE*8, []byte(secretKey)) | ||||
| 	plaintext, err := aesDecrypt(ciphertext, key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return plaintext, nil | ||||
| } | ||||
|  | ||||
| // generateAESKey generates a key for AES encryption using a SHA-1 based PRNG | ||||
| func generateAESKey(length int, password []byte) []byte { | ||||
| 	h := sha1.New() | ||||
| 	h.Write(password) | ||||
| 	state := h.Sum(nil) | ||||
|  | ||||
| 	keyBytes := make([]byte, 0, length/8) | ||||
| 	for len(keyBytes) < length/8 { | ||||
| 		h := sha1.New() | ||||
| 		h.Write(state) | ||||
| 		state = h.Sum(nil) | ||||
| 		keyBytes = append(keyBytes, state...) | ||||
| 	} | ||||
|  | ||||
| 	return keyBytes[:length/8] | ||||
| } | ||||
|  | ||||
| // aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding | ||||
| func aesEncrypt(plaintext, key []byte) ([]byte, error) { | ||||
| 	block, err := aes.NewCipher(key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize()) | ||||
| 	ciphertext := make([]byte, len(paddedPlaintext)) | ||||
| 	mode := newECBEncrypter(block) | ||||
| 	mode.CryptBlocks(ciphertext, paddedPlaintext) | ||||
| 	return ciphertext, nil | ||||
| } | ||||
|  | ||||
| // aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding | ||||
| func aesDecrypt(ciphertext, key []byte) ([]byte, error) { | ||||
| 	block, err := aes.NewCipher(key) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	plaintext := make([]byte, len(ciphertext)) | ||||
| 	mode := newECBDecrypter(block) | ||||
| 	mode.CryptBlocks(plaintext, ciphertext) | ||||
| 	return pkcs5Unpadding(plaintext), nil | ||||
| } | ||||
|  | ||||
| // pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding | ||||
| func pkcs5Padding(src []byte, blockSize int) []byte { | ||||
| 	padding := blockSize - len(src)%blockSize | ||||
| 	padtext := bytes.Repeat([]byte{byte(padding)}, padding) | ||||
| 	return append(src, padtext...) | ||||
| } | ||||
|  | ||||
| // pkcs5Unpadding removes PKCS5 padding from the input | ||||
| func pkcs5Unpadding(src []byte) []byte { | ||||
| 	length := len(src) | ||||
| 	unpadding := int(src[length-1]) | ||||
| 	return src[:(length - unpadding)] | ||||
| } | ||||
|  | ||||
| // ECB mode encryption/decryption | ||||
| type ecb struct { | ||||
| 	b         cipher.Block | ||||
| 	blockSize int | ||||
| } | ||||
|  | ||||
| func newECB(b cipher.Block) *ecb { | ||||
| 	return &ecb{ | ||||
| 		b:         b, | ||||
| 		blockSize: b.BlockSize(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ecbEncrypter ecb | ||||
|  | ||||
| func newECBEncrypter(b cipher.Block) cipher.BlockMode { | ||||
| 	return (*ecbEncrypter)(newECB(b)) | ||||
| } | ||||
|  | ||||
| func (x *ecbEncrypter) BlockSize() int { return x.blockSize } | ||||
|  | ||||
| func (x *ecbEncrypter) CryptBlocks(dst, src []byte) { | ||||
| 	if len(src)%x.blockSize != 0 { | ||||
| 		panic("crypto/cipher: input not full blocks") | ||||
| 	} | ||||
| 	if len(dst) < len(src) { | ||||
| 		panic("crypto/cipher: output smaller than input") | ||||
| 	} | ||||
| 	for len(src) > 0 { | ||||
| 		x.b.Encrypt(dst, src[:x.blockSize]) | ||||
| 		src = src[x.blockSize:] | ||||
| 		dst = dst[x.blockSize:] | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type ecbDecrypter ecb | ||||
|  | ||||
| func newECBDecrypter(b cipher.Block) cipher.BlockMode { | ||||
| 	return (*ecbDecrypter)(newECB(b)) | ||||
| } | ||||
|  | ||||
| func (x *ecbDecrypter) BlockSize() int { return x.blockSize } | ||||
|  | ||||
| func (x *ecbDecrypter) CryptBlocks(dst, src []byte) { | ||||
| 	if len(src)%x.blockSize != 0 { | ||||
| 		panic("crypto/cipher: input not full blocks") | ||||
| 	} | ||||
| 	if len(dst) < len(src) { | ||||
| 		panic("crypto/cipher: output smaller than input") | ||||
| 	} | ||||
| 	for len(src) > 0 { | ||||
| 		x.b.Decrypt(dst, src[:x.blockSize]) | ||||
| 		src = src[x.blockSize:] | ||||
| 		dst = dst[x.blockSize:] | ||||
| 	} | ||||
| } | ||||
| @@ -8,6 +8,7 @@ import ( | ||||
| 	"go.uber.org/zap" | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"tyapi-server/internal/shared/cache" | ||||
| 	"tyapi-server/internal/shared/interfaces" | ||||
| ) | ||||
|  | ||||
| @@ -26,41 +27,112 @@ func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName stri | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ================ 智能缓存决策方法 ================ | ||||
|  | ||||
| // isTableCacheEnabled 检查表是否启用缓存 | ||||
| func (r *CachedBaseRepositoryImpl) isTableCacheEnabled() bool { | ||||
| 	// 使用全局缓存配置管理器 | ||||
| 	if cache.GlobalCacheConfigManager != nil { | ||||
| 		return cache.GlobalCacheConfigManager.IsTableCacheEnabled(r.tableName) | ||||
| 	} | ||||
| 	 | ||||
| 	// 如果全局管理器未初始化,默认启用缓存 | ||||
| 	r.logger.Warn("全局缓存配置管理器未初始化,默认启用缓存",  | ||||
| 		zap.String("table", r.tableName)) | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // shouldUseCacheForTable 智能判断是否应该对当前表使用缓存 | ||||
| func (r *CachedBaseRepositoryImpl) shouldUseCacheForTable() bool { | ||||
| 	// 检查表是否启用缓存 | ||||
| 	if !r.isTableCacheEnabled() { | ||||
| 		r.logger.Debug("表未启用缓存,跳过缓存操作",  | ||||
| 			zap.String("table", r.tableName)) | ||||
| 		return false | ||||
| 	} | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // ================ 智能缓存方法 ================ | ||||
|  | ||||
| // GetWithCache 带缓存的单条查询 | ||||
| // GetWithCache 带缓存的单条查询(智能决策) | ||||
| func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl) | ||||
| 	db := r.GetDB(ctx) | ||||
| 	 | ||||
| 	// 智能决策:根据表配置决定是否使用缓存 | ||||
| 	if r.shouldUseCacheForTable() { | ||||
| 		db = db.Set("cache:enabled", true).Set("cache:ttl", ttl) | ||||
| 		r.logger.Debug("执行带缓存查询",  | ||||
| 			zap.String("table", r.tableName), | ||||
| 			zap.Duration("ttl", ttl), | ||||
| 			zap.String("where", where)) | ||||
| 	} else { | ||||
| 		db = db.Set("cache:disabled", true) | ||||
| 		r.logger.Debug("执行无缓存查询",  | ||||
| 			zap.String("table", r.tableName), | ||||
| 			zap.String("where", where)) | ||||
| 	} | ||||
|  | ||||
| 	return db.Where(where, args...).First(dest).Error | ||||
| } | ||||
|  | ||||
| // FindWithCache 带缓存的多条查询 | ||||
| // FindWithCache 带缓存的多条查询(智能决策) | ||||
| func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl) | ||||
| 	db := r.GetDB(ctx) | ||||
| 	 | ||||
| 	// 智能决策:根据表配置决定是否使用缓存 | ||||
| 	if r.shouldUseCacheForTable() { | ||||
| 		db = db.Set("cache:enabled", true).Set("cache:ttl", ttl) | ||||
| 		r.logger.Debug("执行带缓存批量查询",  | ||||
| 			zap.String("table", r.tableName), | ||||
| 			zap.Duration("ttl", ttl), | ||||
| 			zap.String("where", where)) | ||||
| 	} else { | ||||
| 		db = db.Set("cache:disabled", true) | ||||
| 		r.logger.Debug("执行无缓存批量查询",  | ||||
| 			zap.String("table", r.tableName), | ||||
| 			zap.String("where", where)) | ||||
| 	} | ||||
|  | ||||
| 	return db.Where(where, args...).Find(dest).Error | ||||
| } | ||||
|  | ||||
| // CountWithCache 带缓存的计数查询 | ||||
| // CountWithCache 带缓存的计数查询(智能决策) | ||||
| func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl). | ||||
| 		Model(entity) | ||||
| 	db := r.GetDB(ctx).Model(entity) | ||||
| 	 | ||||
| 	// 智能决策:根据表配置决定是否使用缓存 | ||||
| 	if r.shouldUseCacheForTable() { | ||||
| 		db = db.Set("cache:enabled", true).Set("cache:ttl", ttl) | ||||
| 		r.logger.Debug("执行带缓存计数查询",  | ||||
| 			zap.String("table", r.tableName), | ||||
| 			zap.Duration("ttl", ttl), | ||||
| 			zap.String("where", where)) | ||||
| 	} else { | ||||
| 		db = db.Set("cache:disabled", true) | ||||
| 		r.logger.Debug("执行无缓存计数查询",  | ||||
| 			zap.String("table", r.tableName), | ||||
| 			zap.String("where", where)) | ||||
| 	} | ||||
|  | ||||
| 	return db.Where(where, args...).Count(count).Error | ||||
| } | ||||
|  | ||||
| // ListWithCache 带缓存的列表查询 | ||||
| // ListWithCache 带缓存的列表查询(智能决策) | ||||
| func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl) | ||||
| 	db := r.GetDB(ctx) | ||||
| 	 | ||||
| 	// 智能决策:根据表配置决定是否使用缓存 | ||||
| 	if r.shouldUseCacheForTable() { | ||||
| 		db = db.Set("cache:enabled", true).Set("cache:ttl", ttl) | ||||
| 		r.logger.Debug("执行带缓存列表查询",  | ||||
| 			zap.String("table", r.tableName), | ||||
| 			zap.Duration("ttl", ttl)) | ||||
| 	} else { | ||||
| 		db = db.Set("cache:disabled", true) | ||||
| 		r.logger.Debug("执行无缓存列表查询",  | ||||
| 			zap.String("table", r.tableName)) | ||||
| 	} | ||||
|  | ||||
| 	// 应用where条件 | ||||
| 	if options.Where != "" { | ||||
| @@ -90,12 +162,12 @@ func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest inter | ||||
|  | ||||
| // CacheListOptions 缓存列表查询选项 | ||||
| type CacheListOptions struct { | ||||
| 	Where     string        `json:"where"` | ||||
| 	Args      []interface{} `json:"args"` | ||||
| 	Order     string        `json:"order"` | ||||
| 	Limit     int           `json:"limit"` | ||||
| 	Offset    int           `json:"offset"` | ||||
| 	Preloads  []string      `json:"preloads"` | ||||
| 	Where    string        `json:"where"` | ||||
| 	Args     []interface{} `json:"args"` | ||||
| 	Order    string        `json:"order"` | ||||
| 	Limit    int           `json:"limit"` | ||||
| 	Offset   int           `json:"offset"` | ||||
| 	Preloads []string      `json:"preloads"` | ||||
| } | ||||
|  | ||||
| // ================ 缓存控制方法 ================ | ||||
| @@ -142,6 +214,10 @@ func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl { | ||||
|  | ||||
| // SmartGetByID 智能ID查询(自动缓存) | ||||
| func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error { | ||||
| 	r.logger.Debug("执行智能ID查询",  | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.String("id", id)) | ||||
| 	 | ||||
| 	return r.GetWithCache(ctx, dest, 30*time.Minute, "id = ?", id) | ||||
| } | ||||
|  | ||||
| @@ -151,7 +227,7 @@ func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest int | ||||
| 	if len(ttl) > 0 { | ||||
| 		cacheTTL = ttl[0] | ||||
| 	} | ||||
| 	 | ||||
|  | ||||
| 	return r.GetWithCache(ctx, dest, cacheTTL, field+" = ?", value) | ||||
| } | ||||
|  | ||||
| @@ -161,6 +237,12 @@ func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface | ||||
| 	cacheTTL := r.calculateCacheTTL(options) | ||||
| 	useCache := r.shouldUseCache(options) | ||||
|  | ||||
| 	r.logger.Debug("执行智能列表查询",  | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.Bool("use_cache", useCache), | ||||
| 		zap.Duration("cache_ttl", cacheTTL)) | ||||
|  | ||||
| 	// 修复:确保缓存标记在查询前正确设置 | ||||
| 	db := r.GetDB(ctx) | ||||
| 	if useCache { | ||||
| 		db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL) | ||||
| @@ -254,14 +336,14 @@ func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions | ||||
|  | ||||
| // WarmupCommonQueries 预热常用查询 | ||||
| func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error { | ||||
| 	r.logger.Info("开始预热缓存",  | ||||
| 	r.logger.Info("开始预热缓存", | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.Int("queries", len(queries)), | ||||
| 	) | ||||
|  | ||||
| 	for _, query := range queries { | ||||
| 		if err := r.executeWarmupQuery(ctx, query); err != nil { | ||||
| 			r.logger.Warn("缓存预热失败",  | ||||
| 			r.logger.Warn("缓存预热失败", | ||||
| 				zap.String("query", query.Name), | ||||
| 				zap.Error(err), | ||||
| 			) | ||||
| @@ -273,11 +355,11 @@ func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, quer | ||||
|  | ||||
| // WarmupQuery 预热查询定义 | ||||
| type WarmupQuery struct { | ||||
| 	Name     string        `json:"name"` | ||||
| 	SQL      string        `json:"sql"` | ||||
| 	Args     []interface{} `json:"args"` | ||||
| 	TTL      time.Duration `json:"ttl"` | ||||
| 	Dest     interface{}   `json:"dest"` | ||||
| 	Name string        `json:"name"` | ||||
| 	SQL  string        `json:"sql"` | ||||
| 	Args []interface{} `json:"args"` | ||||
| 	TTL  time.Duration `json:"ttl"` | ||||
| 	Dest interface{}   `json:"dest"` | ||||
| } | ||||
|  | ||||
| // executeWarmupQuery 执行预热查询 | ||||
| @@ -313,7 +395,7 @@ func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interfa | ||||
| 		if err := r.CreateEntity(ctx, newEntity); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		 | ||||
|  | ||||
| 		// 将新创建的实体复制到dest | ||||
| 		// 这里需要反射或其他方式复制 | ||||
| 		return nil | ||||
| @@ -333,7 +415,7 @@ func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids [] | ||||
|  | ||||
| // RefreshCache 刷新缓存 | ||||
| func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error { | ||||
| 	r.logger.Info("刷新缓存",  | ||||
| 	r.logger.Info("刷新缓存", | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.String("pattern", pattern), | ||||
| 	) | ||||
| @@ -348,9 +430,9 @@ func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern str | ||||
| // GetCacheInfo 获取缓存信息 | ||||
| func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} { | ||||
| 	return map[string]interface{}{ | ||||
| 		"table_name":       r.tableName, | ||||
| 		"cache_enabled":    true, | ||||
| 		"default_ttl":      "30m", | ||||
| 		"table_name":    r.tableName, | ||||
| 		"cache_enabled": true, | ||||
| 		"default_ttl":   "30m", | ||||
| 		"cache_patterns": []string{ | ||||
| 			fmt.Sprintf("gorm_cache:%s:*", r.tableName), | ||||
| 		}, | ||||
| @@ -359,9 +441,9 @@ func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} { | ||||
|  | ||||
| // LogCacheOperation 记录缓存操作 | ||||
| func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) { | ||||
| 	r.logger.Debug("缓存操作",  | ||||
| 	r.logger.Debug("缓存操作", | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.String("operation", operation), | ||||
| 		zap.String("details", details), | ||||
| 	) | ||||
| }  | ||||
| } | ||||
|   | ||||
| @@ -2,18 +2,43 @@ package esign | ||||
|  | ||||
| import "fmt" | ||||
|  | ||||
| type EsignContractConfig struct { | ||||
| 	Name       string `json:"name" yaml:"name"` | ||||
| 	ExpireDays int    `json:"expireDays" yaml:"expire_days"` | ||||
| 	RetryCount int    `json:"retryCount" yaml:"retry_count"` | ||||
| } | ||||
|  | ||||
| type EsignAuthConfig struct { | ||||
| 	OrgAuthModes         []string `json:"orgAuthModes" yaml:"org_auth_modes"` | ||||
| 	DefaultAuthMode      string   `json:"defaultAuthMode" yaml:"default_auth_mode"` | ||||
| 	PsnAuthModes         []string `json:"psnAuthModes" yaml:"psn_auth_modes"` | ||||
| 	WillingnessAuthModes []string `json:"willingnessAuthModes" yaml:"willingness_auth_modes"` | ||||
| 	RedirectUrl          string   `json:"redirectUrl" yaml:"redirect_url"` | ||||
| } | ||||
|  | ||||
| type EsignSignConfig struct { | ||||
| 	AutoFinish     bool   `json:"autoFinish" yaml:"auto_finish"` | ||||
| 	SignFieldStyle int    `json:"signFieldStyle" yaml:"sign_field_style"` | ||||
| 	ClientType     string `json:"clientType" yaml:"client_type"` | ||||
| 	RedirectUrl    string `json:"redirectUrl" yaml:"redirect_url"` | ||||
| } | ||||
|  | ||||
| // Config e签宝服务配置结构体 | ||||
| // 包含应用ID、密钥、服务器URL和模板ID等基础配置信息 | ||||
| // 新增Contract、Auth、Sign配置 | ||||
| type Config struct { | ||||
| 	AppID      string `json:"appId"`      // 应用ID | ||||
| 	AppSecret  string `json:"appSecret"`  // 应用密钥 | ||||
| 	ServerURL  string `json:"serverUrl"`  // 服务器URL | ||||
| 	TemplateID string `json:"templateId"` // 模板ID | ||||
| 	AppID         string              `json:"appId" yaml:"app_id"` | ||||
| 	AppSecret     string              `json:"appSecret" yaml:"app_secret"` | ||||
| 	ServerURL     string              `json:"serverUrl" yaml:"server_url"` | ||||
| 	TemplateID    string              `json:"templateId" yaml:"template_id"` | ||||
| 	Contract      *EsignContractConfig `json:"contract" yaml:"contract"` | ||||
| 	Auth          *EsignAuthConfig     `json:"auth" yaml:"auth"` | ||||
| 	Sign          *EsignSignConfig     `json:"sign" yaml:"sign"` | ||||
| } | ||||
|  | ||||
| // NewConfig 创建新的配置实例 | ||||
| // 提供配置验证和默认值设置 | ||||
| func NewConfig(appID, appSecret, serverURL, templateID string) (*Config, error) { | ||||
| func NewConfig(appID, appSecret, serverURL, templateID string, contract *EsignContractConfig, auth *EsignAuthConfig, sign *EsignSignConfig) (*Config, error) { | ||||
| 	if appID == "" { | ||||
| 		return nil, fmt.Errorf("应用ID不能为空") | ||||
| 	} | ||||
| @@ -26,12 +51,15 @@ func NewConfig(appID, appSecret, serverURL, templateID string) (*Config, error) | ||||
| 	if templateID == "" { | ||||
| 		return nil, fmt.Errorf("模板ID不能为空") | ||||
| 	} | ||||
| 	 | ||||
|  | ||||
| 	return &Config{ | ||||
| 		AppID:      appID, | ||||
| 		AppSecret:  appSecret, | ||||
| 		ServerURL:  serverURL, | ||||
| 		TemplateID: templateID, | ||||
| 		Contract:   contract, | ||||
| 		Auth:       auth, | ||||
| 		Sign:       sign, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| @@ -60,7 +88,7 @@ const ( | ||||
| 	AuthModeBank    = "PSN_BANK"    // 银行卡认证 | ||||
|  | ||||
| 	// 意愿认证模式 | ||||
| 	WillingnessAuthSMS = "CODE_SMS" // 短信验证码 | ||||
| 	WillingnessAuthSMS   = "CODE_SMS"   // 短信验证码 | ||||
| 	WillingnessAuthEmail = "CODE_EMAIL" // 邮箱验证码 | ||||
|  | ||||
| 	// 证件类型常量 | ||||
| @@ -69,7 +97,7 @@ const ( | ||||
|  | ||||
| 	// 签署区样式常量 | ||||
| 	SignFieldStyleNormal = 1 // 普通签章 | ||||
| 	SignFieldStyleSeam    = 2 // 骑缝签章 | ||||
| 	SignFieldStyleSeam   = 2 // 骑缝签章 | ||||
|  | ||||
| 	// 签署人类型常量 | ||||
| 	SignerTypePerson = 0 // 个人 | ||||
| @@ -80,4 +108,4 @@ const ( | ||||
|  | ||||
| 	// 客户端类型常量 | ||||
| 	ClientTypeAll = "ALL" // 所有客户端 | ||||
| )  | ||||
| ) | ||||
|   | ||||
| @@ -13,6 +13,24 @@ func Example() { | ||||
| 		"your_app_secret", | ||||
| 		"https://smlopenapi.esign.cn", | ||||
| 		"your_template_id", | ||||
| 		&EsignContractConfig{ | ||||
| 			Name:       "测试合同", | ||||
| 			ExpireDays: 30, | ||||
| 			RetryCount: 3, | ||||
| 		}, | ||||
| 		&EsignAuthConfig{ | ||||
| 			OrgAuthModes:         []string{"ORG"}, | ||||
| 			DefaultAuthMode:      "ORG", | ||||
| 			PsnAuthModes:         []string{"PSN"}, | ||||
| 			WillingnessAuthModes: []string{"WILLINGNESS"}, | ||||
| 			RedirectUrl:          "https://www.tianyuanapi.com/certification/callback", | ||||
| 		}, | ||||
| 		&EsignSignConfig{ | ||||
| 			AutoFinish:     true, | ||||
| 			SignFieldStyle: 1, | ||||
| 			ClientType:     "ALL", | ||||
| 			RedirectUrl:    "https://www.tianyuanapi.com/certification/callback", | ||||
| 		}, | ||||
| 	) | ||||
| 	if err != nil { | ||||
| 		log.Fatal("配置创建失败:", err) | ||||
| @@ -122,7 +140,22 @@ func Example() { | ||||
| // ExampleBasicUsage 基础用法示例 | ||||
| func ExampleBasicUsage() { | ||||
| 	// 最简单的用法 - 一行代码完成合同签署 | ||||
| 	config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id") | ||||
| 	config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{ | ||||
| 		Name:       "测试合同", | ||||
| 		ExpireDays: 30, | ||||
| 		RetryCount: 3, | ||||
| 	}, &EsignAuthConfig{ | ||||
| 		OrgAuthModes:         []string{"ORG"}, | ||||
| 		DefaultAuthMode:      "ORG", | ||||
| 		PsnAuthModes:         []string{"PSN"}, | ||||
| 		WillingnessAuthModes: []string{"WILLINGNESS"}, | ||||
| 		RedirectUrl:          "https://www.tianyuanapi.com/certification/callback", | ||||
| 	}, &EsignSignConfig{ | ||||
| 		AutoFinish:     true, | ||||
| 		SignFieldStyle: 1, | ||||
| 		ClientType:     "ALL", | ||||
| 		RedirectUrl:    "https://www.tianyuanapi.com/certification/callback", | ||||
| 	}) | ||||
| 	client := NewClient(config) | ||||
|  | ||||
| 	// 快速合同签署 | ||||
| @@ -143,7 +176,22 @@ func ExampleBasicUsage() { | ||||
|  | ||||
| // ExampleWithCustomData 自定义数据示例 | ||||
| func ExampleWithCustomData() { | ||||
| 	config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id") | ||||
| 	config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{ | ||||
| 		Name:       "测试合同", | ||||
| 		ExpireDays: 30, | ||||
| 		RetryCount: 3, | ||||
| 	}, &EsignAuthConfig{ | ||||
| 		OrgAuthModes:         []string{"ORG"}, | ||||
| 		DefaultAuthMode:      "ORG", | ||||
| 		PsnAuthModes:         []string{"PSN"}, | ||||
| 		WillingnessAuthModes: []string{"WILLINGNESS"}, | ||||
| 		RedirectUrl:          "https://www.tianyuanapi.com/certification/callback", | ||||
| 	}, &EsignSignConfig{ | ||||
| 		AutoFinish:     true, | ||||
| 		SignFieldStyle: 1, | ||||
| 		ClientType:     "ALL", | ||||
| 		RedirectUrl:    "https://www.tianyuanapi.com/certification/callback", | ||||
| 	}) | ||||
| 	client := NewClient(config) | ||||
|  | ||||
| 	// 使用自定义模板数据 | ||||
| @@ -171,7 +219,22 @@ func ExampleWithCustomData() { | ||||
|  | ||||
| // ExampleEnterpriseAuth 企业认证示例 | ||||
| func ExampleEnterpriseAuth() { | ||||
| 	config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id") | ||||
| 	config, _ := NewConfig("app_id", "app_secret", "server_url", "template_id", &EsignContractConfig{ | ||||
| 		Name:       "测试合同", | ||||
| 		ExpireDays: 30, | ||||
| 		RetryCount: 3, | ||||
| 	}, &EsignAuthConfig{ | ||||
| 		OrgAuthModes:         []string{"ORG"}, | ||||
| 		DefaultAuthMode:      "ORG", | ||||
| 		PsnAuthModes:         []string{"PSN"}, | ||||
| 		WillingnessAuthModes: []string{"WILLINGNESS"}, | ||||
| 		RedirectUrl:          "https://www.tianyuanapi.com/certification/callback", | ||||
| 	}, &EsignSignConfig{ | ||||
| 		AutoFinish: true, | ||||
| 		SignFieldStyle: 1, | ||||
| 		ClientType: "ALL", | ||||
| 		RedirectUrl: "https://www.tianyuanapi.com/certification/callback", | ||||
| 	}) | ||||
| 	client := NewClient(config) | ||||
|  | ||||
| 	// 企业认证 | ||||
|   | ||||
| @@ -34,8 +34,8 @@ func (s *FileOpsService) UpdateConfig(config *Config) { | ||||
| func (s *FileOpsService) DownloadSignedFile(signFlowId string) (*DownloadSignedFileResponse, error) { | ||||
| 	fmt.Println("开始下载已签署文件及附属材料...") | ||||
|  | ||||
| 	// 发送API请求 | ||||
| 	urlPath := fmt.Sprintf("/v3/sign-flow/%s/attachments", signFlowId) | ||||
| 	// 按照最新e签宝文档,接口路径应为 /v3/sign-flow/{signFlowId}/file-download-url | ||||
| 	urlPath := fmt.Sprintf("/v3/sign-flow/%s/file-download-url", signFlowId) | ||||
| 	responseBody, err := s.httpClient.Request("GET", urlPath, nil) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("下载已签署文件失败: %v", err) | ||||
|   | ||||
| @@ -44,8 +44,8 @@ func (h *HTTPClient) UpdateConfig(config *Config) { | ||||
| func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error) { | ||||
| 	// 生成签名所需参数 | ||||
| 	timestamp := getCurrentTimestamp() | ||||
| 	nonce := generateNonce() | ||||
| 	date := getCurrentDate() | ||||
| 	// date := getCurrentDate() | ||||
| 	date := "" | ||||
|  | ||||
| 	// 计算Content-MD5 | ||||
| 	contentMD5 := "" | ||||
| @@ -53,14 +53,14 @@ func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error | ||||
| 		contentMD5 = getContentMD5(body) | ||||
| 	} | ||||
|  | ||||
| 	// 根据Java示例,Headers为空字符串 | ||||
| 	headers := "" | ||||
|  | ||||
| 	// 生成签名 | ||||
| 	// 生成签名(用原始urlPath) | ||||
| 	signature := generateSignature(h.config.AppSecret, method, "*/*", contentMD5, "application/json", date, headers, urlPath) | ||||
|  | ||||
| 	// 创建HTTP请求 | ||||
| 	url := h.config.ServerURL + urlPath | ||||
| 	// 实际请求url用encode后的urlPath | ||||
| 	encodedURLPath := encodeURLQueryParams(urlPath) | ||||
| 	url := h.config.ServerURL + encodedURLPath | ||||
| 	req, err := http.NewRequest(method, url, bytes.NewBuffer(body)) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("创建HTTP请求失败: %v", err) | ||||
| @@ -69,12 +69,11 @@ func (h *HTTPClient) Request(method, urlPath string, body []byte) ([]byte, error | ||||
| 	// 设置请求头 | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	req.Header.Set("Content-MD5", contentMD5) | ||||
| 	req.Header.Set("Date", date) | ||||
| 	// req.Header.Set("Date", date) | ||||
| 	req.Header.Set("Accept", "*/*") | ||||
| 	req.Header.Set("X-Tsign-Open-App-Id", h.config.AppID) | ||||
| 	req.Header.Set("X-Tsign-Open-Auth-Mode", "Signature") | ||||
| 	req.Header.Set("X-Tsign-Open-Ca-Timestamp", timestamp) | ||||
| 	req.Header.Set("X-Tsign-Open-Nonce", nonce) | ||||
| 	req.Header.Set("X-Tsign-Open-Ca-Signature", signature) | ||||
|  | ||||
| 	// 发送请求 | ||||
| @@ -197,3 +196,24 @@ func sortURLQueryParams(urlPath string) string { | ||||
| 	} | ||||
| 	return basePath | ||||
| } | ||||
|  | ||||
| // encodeURLQueryParams 对urlPath中的query参数值进行encode | ||||
| func encodeURLQueryParams(urlPath string) string { | ||||
| 	if !strings.Contains(urlPath, "?") { | ||||
| 		return urlPath | ||||
| 	} | ||||
| 	parts := strings.SplitN(urlPath, "?", 2) | ||||
| 	basePath := parts[0] | ||||
| 	queryString := parts[1] | ||||
| 	values, err := url.ParseQuery(queryString) | ||||
| 	if err != nil { | ||||
| 		return urlPath | ||||
| 	} | ||||
| 	var encodedPairs []string | ||||
| 	for key, vals := range values { | ||||
| 		for _, val := range vals { | ||||
| 			encodedPairs = append(encodedPairs, key+"="+url.QueryEscape(val)) | ||||
| 		} | ||||
| 	} | ||||
| 	return basePath + "?" + strings.Join(encodedPairs, "&") | ||||
| } | ||||
|   | ||||
| @@ -66,6 +66,9 @@ func (s *OrgAuthService) GetAuthURL(req *OrgAuthRequest) (string, string, string | ||||
| 			}, | ||||
| 		}, | ||||
| 		ClientType: ClientTypeAll, | ||||
| 		RedirectConfig: &RedirectConfig{ | ||||
| 			RedirectUrl: s.config.Auth.RedirectUrl, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// 序列化请求数据 | ||||
|   | ||||
| @@ -43,7 +43,7 @@ func (s *SignFlowService) Create(req *CreateSignFlowRequest) (string, error) { | ||||
| 		Docs: []DocInfo{ | ||||
| 			{ | ||||
| 				FileId:   req.FileID, | ||||
| 				FileName: "天远数据API合作协议.pdf", | ||||
| 				FileName: s.config.Contract.Name, | ||||
| 			}, | ||||
| 		}, | ||||
| 		SignFlowConfig: s.buildSignFlowConfig(), | ||||
| @@ -141,9 +141,9 @@ func (s *SignFlowService) buildPartyASigner(fileID string) SignerInfo { | ||||
| 					AutoSign:       true, | ||||
| 					SignFieldStyle: SignFieldStyleNormal, | ||||
| 					SignFieldPosition: &SignFieldPosition{ | ||||
| 						PositionPage: "1", | ||||
| 						PositionPage: "8", | ||||
| 						PositionX:    200, | ||||
| 						PositionY:    200, | ||||
| 						PositionY:    430, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @@ -188,9 +188,9 @@ func (s *SignFlowService) buildPartyBSigner(fileID, signerAccount, signerName, t | ||||
| 					AutoSign:       false, | ||||
| 					SignFieldStyle: SignFieldStyleNormal, | ||||
| 					SignFieldPosition: &SignFieldPosition{ | ||||
| 						PositionPage: "1", | ||||
| 						PositionX:    458, | ||||
| 						PositionY:    200, | ||||
| 						PositionPage: "8", | ||||
| 						PositionX:    450, | ||||
| 						PositionY:    430, | ||||
| 					}, | ||||
| 				}, | ||||
| 			}, | ||||
| @@ -201,9 +201,9 @@ func (s *SignFlowService) buildPartyBSigner(fileID, signerAccount, signerName, t | ||||
| // buildSignFlowConfig 构建签署流程配置 | ||||
| func (s *SignFlowService) buildSignFlowConfig() SignFlowConfig { | ||||
| 	return SignFlowConfig{ | ||||
| 		SignFlowTitle:      "天远数据API合作协议签署", | ||||
| 		SignFlowExpireTime: calculateExpireTime(7), // 7天后过期 | ||||
| 		AutoFinish:         true,                   // 所有签署方完成后自动完结 | ||||
| 		SignFlowTitle:      s.config.Contract.Name, | ||||
| 		SignFlowExpireTime: calculateExpireTime(s.config.Contract.ExpireDays), | ||||
| 		AutoFinish:         s.config.Sign.AutoFinish, | ||||
| 		AuthConfig: &AuthConfig{ | ||||
| 			PsnAvailableAuthModes: []string{AuthModeMobile3}, | ||||
| 			WillingnessAuthModes:  []string{WillingnessAuthSMS}, | ||||
| @@ -211,5 +211,8 @@ func (s *SignFlowService) buildSignFlowConfig() SignFlowConfig { | ||||
| 		ContractConfig: &ContractConfig{ | ||||
| 			AllowToRescind: false, | ||||
| 		}, | ||||
| 		RedirectConfig: &RedirectConfig{ | ||||
| 			RedirectUrl: s.config.Sign.RedirectUrl, | ||||
| 		}, | ||||
| 	} | ||||
| }  | ||||
| @@ -36,7 +36,7 @@ func (s *TemplateService) Fill(components []Component) (*FillTemplate, error) { | ||||
| 	fmt.Println("开始填写模板生成文件...") | ||||
|  | ||||
| 	// 生成带时间戳的文件名 | ||||
| 	fileName := generateFileName("天远数据API合作协议", "pdf") | ||||
| 	fileName := generateFileName(s.config.Contract.Name, "pdf") | ||||
|  | ||||
| 	// 构建请求数据 | ||||
| 	requestData := FillTemplateRequest{ | ||||
|   | ||||
| @@ -5,10 +5,40 @@ import ( | ||||
| 	"crypto/md5" | ||||
| 	"crypto/sha256" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/hex" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"sort" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // appendSignDataString 拼接待签名字符串 | ||||
| func appendSignDataString(httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters string) string { | ||||
| 	if accept == "" { | ||||
| 		accept = "*/*" | ||||
| 	} | ||||
| 	if contentType == "" { | ||||
| 		contentType = "application/json; charset=UTF-8" | ||||
| 	} | ||||
| 	// 前四项 | ||||
| 	signStr := httpMethod + "\n" + accept + "\n" + contentMD5 + "\n" + contentType + "\n" | ||||
| 	// 处理 date | ||||
| 	if date == "" { | ||||
| 		signStr += "\n" | ||||
| 	} else { | ||||
| 		signStr += date + "\n" | ||||
| 	} | ||||
| 	// 处理 headers | ||||
| 	if headers == "" { | ||||
| 		signStr += pathAndParameters | ||||
| 	} else { | ||||
| 		signStr += headers + "\n" + pathAndParameters | ||||
| 	} | ||||
| 	return signStr | ||||
| } | ||||
|  | ||||
| // generateSignature 生成e签宝API请求签名 | ||||
| // 使用HMAC-SHA256算法对请求参数进行签名 | ||||
| // | ||||
| @@ -24,8 +54,8 @@ import ( | ||||
| // | ||||
| // 返回: Base64编码的签名字符串 | ||||
| func generateSignature(appSecret, httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters string) string { | ||||
| 	// 构建待签名字符串,按照e签宝API规范拼接 | ||||
| 	signStr := httpMethod + "\n" + accept + "\n" + contentMD5 + "\n" + contentType + "\n" + date + "\n" + headers + pathAndParameters | ||||
| 	// 构建待签名字符串,按照e签宝API规范拼接(兼容Python实现细节) | ||||
| 	signStr := appendSignDataString(httpMethod, accept, contentMD5, contentType, date, headers, pathAndParameters) | ||||
|  | ||||
| 	// 使用HMAC-SHA256计算签名 | ||||
| 	h := hmac.New(sha256.New, []byte(appSecret)) | ||||
| @@ -102,3 +132,66 @@ func generateFileName(baseName, extension string) string { | ||||
| func calculateExpireTime(days int) int64 { | ||||
| 	return time.Now().AddDate(0, 0, days).UnixMilli() | ||||
| } | ||||
|  | ||||
| // verifySignature 验证e签宝回调签名 | ||||
| func VerifySignature(callbackData interface{}, headers map[string]string, queryParams map[string]string, appSecret string) error { | ||||
| 	// 1. 获取签名相关参数 | ||||
| 	signature, ok := headers["X-Tsign-Open-Signature"] | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("缺少签名头: X-Tsign-Open-Signature") | ||||
| 	} | ||||
|  | ||||
| 	timestamp, ok := headers["X-Tsign-Open-Timestamp"] | ||||
| 	if !ok { | ||||
| 		return fmt.Errorf("缺少时间戳头: X-Tsign-Open-Timestamp") | ||||
| 	} | ||||
|  | ||||
| 	// 2. 构建查询参数字符串 | ||||
| 	var queryKeys []string | ||||
| 	for key := range queryParams { | ||||
| 		queryKeys = append(queryKeys, key) | ||||
| 	} | ||||
| 	sort.Strings(queryKeys) // 按ASCII码升序排序 | ||||
|  | ||||
| 	var queryValues []string | ||||
| 	for _, key := range queryKeys { | ||||
| 		queryValues = append(queryValues, queryParams[key]) | ||||
| 	} | ||||
| 	queryString := strings.Join(queryValues, "") | ||||
|  | ||||
| 	// 3. 获取请求体数据 | ||||
| 	bodyData, err := getRequestBodyString(callbackData) | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("获取请求体数据失败: %w", err) | ||||
| 	} | ||||
|  | ||||
| 	// 4. 构建验签数据 | ||||
| 	data := timestamp + queryString + bodyData | ||||
|  | ||||
| 	// 5. 计算签名 | ||||
| 	expectedSignature := calculateSignature(data, appSecret) | ||||
|  | ||||
| 	// 6. 比较签名 | ||||
| 	if strings.ToLower(expectedSignature) != strings.ToLower(signature) { | ||||
| 		return fmt.Errorf("签名验证失败") | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // calculateSignature 计算HMAC-SHA256签名 | ||||
| func calculateSignature(data, secret string) string { | ||||
| 	h := hmac.New(sha256.New, []byte(secret)) | ||||
| 	h.Write([]byte(data)) | ||||
| 	return strings.ToUpper(hex.EncodeToString(h.Sum(nil))) | ||||
| } | ||||
|  | ||||
| // getRequestBodyString 获取请求体字符串 | ||||
| func getRequestBodyString(callbackData interface{}) (string, error) { | ||||
| 	// 将map转换为JSON字符串 | ||||
| 	jsonBytes, err := json.Marshal(callbackData) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("JSON序列化失败: %w", err) | ||||
| 	} | ||||
| 	return string(jsonBytes), nil | ||||
| } | ||||
|   | ||||
| @@ -5,6 +5,7 @@ import ( | ||||
| 	"net/http" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"github.com/go-playground/validator/v10" | ||||
| ) | ||||
|  | ||||
| // HTTPHandler HTTP处理器接口 | ||||
| @@ -95,6 +96,11 @@ type RequestValidator interface { | ||||
|  | ||||
| 	// 绑定和验证 | ||||
| 	BindAndValidate(c *gin.Context, dto interface{}) error | ||||
|  | ||||
| 	// 业务逻辑验证方法 | ||||
| 	GetValidator() *validator.Validate | ||||
| 	ValidateValue(field interface{}, tag string) error | ||||
| 	ValidateStruct(s interface{}) error | ||||
| } | ||||
|  | ||||
| // PaginationMeta 分页元数据 | ||||
|   | ||||
							
								
								
									
										50
									
								
								internal/shared/middleware/api_auth.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								internal/shared/middleware/api_auth.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"tyapi-server/internal/config" | ||||
| 	"tyapi-server/internal/shared/interfaces" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"go.uber.org/zap" | ||||
| ) | ||||
|  | ||||
| // ApiAuthMiddleware API认证中间件 | ||||
| type ApiAuthMiddleware struct { | ||||
| 	config          *config.Config | ||||
| 	logger          *zap.Logger | ||||
| 	responseBuilder interfaces.ResponseBuilder | ||||
| } | ||||
|  | ||||
| // NewApiAuthMiddleware 创建API认证中间件 | ||||
| func NewApiAuthMiddleware(cfg *config.Config, logger *zap.Logger, responseBuilder interfaces.ResponseBuilder) *ApiAuthMiddleware { | ||||
| 	return &ApiAuthMiddleware{ | ||||
| 		config:          cfg, | ||||
| 		logger:          logger, | ||||
| 		responseBuilder: responseBuilder, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // GetName 返回中间件名称 | ||||
| func (m *ApiAuthMiddleware) GetName() string { | ||||
| 	return "api_auth" | ||||
| } | ||||
|  | ||||
| // GetPriority 返回中间件优先级 | ||||
| func (m *ApiAuthMiddleware) GetPriority() int { | ||||
| 	return 60 // 中等优先级,在日志之后,业务处理之前 | ||||
| } | ||||
|  | ||||
| // Handle 返回中间件处理函数 | ||||
| func (m *ApiAuthMiddleware) Handle() gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
| 		// 获取客户端IP地址,并存入上下文 | ||||
| 		clientIP := c.ClientIP() | ||||
| 		c.Set("client_ip", clientIP) | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // IsGlobal 是否为全局中间件 | ||||
| func (m *ApiAuthMiddleware) IsGlobal() bool { | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										72
									
								
								internal/shared/middleware/domain.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										72
									
								
								internal/shared/middleware/domain.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,72 @@ | ||||
| package middleware | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"tyapi-server/internal/config" | ||||
|  | ||||
| 	"github.com/gin-gonic/gin" | ||||
| 	"go.uber.org/zap" | ||||
| ) | ||||
|  | ||||
| // DomainAuthMiddleware 域名认证中间件 | ||||
| type DomainAuthMiddleware struct { | ||||
| 	config *config.Config | ||||
| 	logger *zap.Logger | ||||
| } | ||||
|  | ||||
| // NewDomainAuthMiddleware 创建域名认证中间件 | ||||
| func NewDomainAuthMiddleware(cfg *config.Config, logger *zap.Logger) *DomainAuthMiddleware { | ||||
| 	return &DomainAuthMiddleware{ | ||||
| 		config: cfg, | ||||
| 		logger: logger, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // GetName 返回中间件名称 | ||||
| func (m *DomainAuthMiddleware) GetName() string { | ||||
| 	return "domain_auth" | ||||
| } | ||||
|  | ||||
| // GetPriority 返回中间件优先级 | ||||
| func (m *DomainAuthMiddleware) GetPriority() int { | ||||
| 	return 100 | ||||
| } | ||||
|  | ||||
| // Handle 返回中间件处理函数 | ||||
| func (m *DomainAuthMiddleware) Handle(domain string) gin.HandlerFunc { | ||||
| 	return func(c *gin.Context) { | ||||
|  | ||||
| 		// 开发环境下跳过外部验证 | ||||
| 		if m.config.App.IsDevelopment() { | ||||
| 			m.logger.Info("开发环境:跳过域名验证", | ||||
| 				zap.String("domain", domain)) | ||||
| 			c.Next() | ||||
| 		} | ||||
| 		if domain == "" { | ||||
| 			domain = m.config.API.Domain | ||||
| 		} | ||||
| 		host := c.Request.Host | ||||
|  | ||||
| 		// 移除端口部分 | ||||
| 		if idx := strings.Index(host, ":"); idx != -1 { | ||||
| 			host = host[:idx] | ||||
| 		} | ||||
|  | ||||
| 		if host == domain { | ||||
| 			// 设置域名匹配标记 | ||||
| 			c.Set("domainMatched", domain) | ||||
| 			c.Next() | ||||
| 		} else { | ||||
| 			// 不匹配域名,跳过当前组处理,继续执行其他路由 | ||||
| 			c.Abort() | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		c.Next() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // IsGlobal 是否为全局中间件 | ||||
| func (m *DomainAuthMiddleware) IsGlobal() bool { | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										282
									
								
								internal/shared/payment/alipay.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										282
									
								
								internal/shared/payment/alipay.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,282 @@ | ||||
| package payment | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/hex" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"sync/atomic" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/shopspring/decimal" | ||||
| 	"github.com/smartwalle/alipay/v3" | ||||
| ) | ||||
|  | ||||
| type AlipayConfig struct { | ||||
| 	AppID           string | ||||
| 	PrivateKey      string | ||||
| 	AlipayPublicKey string | ||||
| 	IsProduction    bool | ||||
| 	NotifyUrl       string | ||||
| 	ReturnURL       string // 同步回调地址 | ||||
| } | ||||
| type AliPayService struct { | ||||
| 	config       AlipayConfig | ||||
| 	AlipayClient *alipay.Client | ||||
| } | ||||
|  | ||||
| // NewAliPayService 是一个构造函数,用于初始化 AliPayService | ||||
| func NewAliPayService(config AlipayConfig) *AliPayService { | ||||
| 	client, err := alipay.New(config.AppID, config.PrivateKey, config.IsProduction) | ||||
| 	if err != nil { | ||||
| 		panic(fmt.Sprintf("创建支付宝客户端失败: %v", err)) | ||||
| 	} | ||||
|  | ||||
| 	// 加载支付宝公钥 | ||||
| 	err = client.LoadAliPayPublicKey(config.AlipayPublicKey) | ||||
| 	if err != nil { | ||||
| 		panic(fmt.Sprintf("加载支付宝公钥失败: %v", err)) | ||||
| 	} | ||||
| 	return &AliPayService{ | ||||
| 		config:       config, | ||||
| 		AlipayClient: client, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (a *AliPayService) CreateAlipayAppOrder(amount decimal.Decimal, subject string, outTradeNo string) (string, error) { | ||||
| 	client := a.AlipayClient | ||||
| 	totalAmount := amount.StringFixed(2) // 保留2位小数 | ||||
| 	// 构造移动支付请求 | ||||
| 	p := alipay.TradeAppPay{ | ||||
| 		Trade: alipay.Trade{ | ||||
| 			Subject:     subject, | ||||
| 			OutTradeNo:  outTradeNo, | ||||
| 			TotalAmount: totalAmount, | ||||
| 			ProductCode: "QUICK_MSECURITY_PAY", // 移动端支付专用代码 | ||||
| 			NotifyURL:   a.config.NotifyUrl,    // 异步回调通知地址 | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// 获取APP支付字符串,这里会签名 | ||||
| 	payStr, err := client.TradeAppPay(p) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("创建支付宝订单失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return payStr, nil | ||||
| } | ||||
|  | ||||
| // CreateAlipayH5Order 创建支付宝H5支付订单 | ||||
| func (a *AliPayService) CreateAlipayH5Order(amount decimal.Decimal, subject string, outTradeNo string) (string, error) { | ||||
| 	client := a.AlipayClient | ||||
| 	totalAmount := amount.StringFixed(2) // 保留2位小数 | ||||
| 	// 构造H5支付请求 | ||||
| 	p := alipay.TradeWapPay{ | ||||
| 		Trade: alipay.Trade{ | ||||
| 			Subject:     subject, | ||||
| 			OutTradeNo:  outTradeNo, | ||||
| 			TotalAmount: totalAmount, | ||||
| 			ProductCode: "QUICK_WAP_PAY",    // H5支付专用产品码 | ||||
| 			NotifyURL:   a.config.NotifyUrl, // 异步回调通知地址 | ||||
| 			ReturnURL:   a.config.ReturnURL, | ||||
| 		}, | ||||
| 	} | ||||
| 	// 获取H5支付请求字符串,这里会签名 | ||||
| 	payUrl, err := client.TradeWapPay(p) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("创建支付宝H5订单失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return payUrl.String(), nil | ||||
| } | ||||
|  | ||||
| // CreateAlipayPCOrder 创建支付宝PC端支付订单 | ||||
| func (a *AliPayService) CreateAlipayPCOrder(amount decimal.Decimal, subject string, outTradeNo string) (string, error) { | ||||
| 	client := a.AlipayClient | ||||
| 	totalAmount := amount.StringFixed(2) // 保留2位小数 | ||||
|  | ||||
| 	// 构造PC端支付请求 | ||||
| 	p := alipay.TradePagePay{ | ||||
| 		Trade: alipay.Trade{ | ||||
| 			Subject:     subject, | ||||
| 			OutTradeNo:  outTradeNo, | ||||
| 			TotalAmount: totalAmount, | ||||
| 			ProductCode: "FAST_INSTANT_TRADE_PAY", // PC端支付专用产品码 | ||||
| 			NotifyURL:   a.config.NotifyUrl,       // 异步回调通知地址 | ||||
| 			ReturnURL:   a.config.ReturnURL,       // 同步回调地址 | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// 获取PC端支付URL,这里会签名 | ||||
| 	payUrl, err := client.TradePagePay(p) | ||||
| 	if err != nil { | ||||
| 		return "", fmt.Errorf("创建支付宝PC端订单失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return payUrl.String(), nil | ||||
| } | ||||
|  | ||||
| // CreateAlipayOrder 根据平台类型创建支付宝支付订单 | ||||
| func (a *AliPayService) CreateAlipayOrder(ctx context.Context, platform string, amount decimal.Decimal, subject string, outTradeNo string) (string, error) { | ||||
| 	switch platform { | ||||
| 	case "app": | ||||
| 		// 调用App支付的创建方法 | ||||
| 		return a.CreateAlipayAppOrder(amount, subject, outTradeNo) | ||||
| 	case "h5": | ||||
| 		// 调用H5支付的创建方法,并传入 returnUrl | ||||
| 		return a.CreateAlipayH5Order(amount, subject, outTradeNo) | ||||
| 	case "pc": | ||||
| 		// 调用PC端支付的创建方法 | ||||
| 		return a.CreateAlipayPCOrder(amount, subject, outTradeNo) | ||||
| 	default: | ||||
| 		return "", fmt.Errorf("不支持的支付平台: %s", platform) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // AliRefund 发起支付宝退款 | ||||
| func (a *AliPayService) AliRefund(ctx context.Context, outTradeNo string, refundAmount decimal.Decimal) (*alipay.TradeRefundRsp, error) { | ||||
| 	refund := alipay.TradeRefund{ | ||||
| 		OutTradeNo:   outTradeNo, | ||||
| 		RefundAmount: refundAmount.StringFixed(2), // 保留2位小数 | ||||
| 		OutRequestNo: fmt.Sprintf("%s-refund", outTradeNo), | ||||
| 	} | ||||
|  | ||||
| 	// 发起退款请求 | ||||
| 	refundResp, err := a.AlipayClient.TradeRefund(ctx, refund) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("支付宝退款请求错误:%v", err) | ||||
| 	} | ||||
| 	return refundResp, nil | ||||
| } | ||||
|  | ||||
| // HandleAliPaymentNotification 支付宝支付回调 | ||||
| func (a *AliPayService) HandleAliPaymentNotification(r *http.Request) (*alipay.Notification, error) { | ||||
| 	// 解析表单 | ||||
| 	err := r.ParseForm() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("解析请求表单失败:%v", err) | ||||
| 	} | ||||
| 	// 解析并验证通知,DecodeNotification 会自动验证签名 | ||||
| 	notification, err := a.AlipayClient.DecodeNotification(r.Form) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("验证签名失败: %v", err) | ||||
| 	} | ||||
| 	return notification, nil | ||||
| } | ||||
|  | ||||
| func (a *AliPayService) IsAlipayPaymentSuccess(notification *alipay.Notification) bool { | ||||
| 	return notification.TradeStatus == alipay.TradeStatusSuccess | ||||
| } | ||||
|  | ||||
| func (a *AliPayService) QueryOrderStatus(ctx context.Context, outTradeNo string) (*alipay.TradeQueryRsp, error) { | ||||
| 	queryRequest := alipay.TradeQuery{ | ||||
| 		OutTradeNo: outTradeNo, | ||||
| 	} | ||||
|  | ||||
| 	// 发起查询请求 | ||||
| 	resp, err := a.AlipayClient.TradeQuery(ctx, queryRequest) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("查询支付宝订单失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 返回交易状态 | ||||
| 	if resp.IsSuccess() { | ||||
| 		return resp, nil | ||||
| 	} | ||||
|  | ||||
| 	return nil, fmt.Errorf("查询支付宝订单失败: %v", resp.SubMsg) | ||||
| } | ||||
|  | ||||
| // 添加全局原子计数器 | ||||
| var alipayOrderCounter uint32 = 0 | ||||
|  | ||||
| // GenerateOutTradeNo 生成唯一订单号的函数 - 优化版本 | ||||
| func (a *AliPayService) GenerateOutTradeNo() string { | ||||
|  | ||||
| 	// 获取当前时间戳(毫秒级) | ||||
| 	timestamp := time.Now().UnixMilli() | ||||
| 	timeStr := strconv.FormatInt(timestamp, 10) | ||||
|  | ||||
| 	// 原子递增计数器 | ||||
| 	counter := atomic.AddUint32(&alipayOrderCounter, 1) | ||||
|  | ||||
| 	// 生成4字节真随机数 | ||||
| 	randomBytes := make([]byte, 4) | ||||
| 	_, err := rand.Read(randomBytes) | ||||
| 	if err != nil { | ||||
| 		// 如果随机数生成失败,回退到使用时间纳秒数据 | ||||
| 		randomBytes = []byte(strconv.FormatInt(time.Now().UnixNano()%1000000, 16)) | ||||
| 	} | ||||
| 	randomHex := hex.EncodeToString(randomBytes) | ||||
|  | ||||
| 	// 组合所有部分: 前缀 + 时间戳 + 计数器 + 随机数 | ||||
| 	orderNo := fmt.Sprintf("%s%06x%s", timeStr[:10], counter%0xFFFFFF, randomHex[:6]) | ||||
|  | ||||
| 	// 确保长度不超过32字符(大多数支付平台的限制) | ||||
| 	if len(orderNo) > 32 { | ||||
| 		orderNo = orderNo[:32] | ||||
| 	} | ||||
|  | ||||
| 	return orderNo | ||||
| } | ||||
|  | ||||
| // AliTransfer 支付宝单笔转账到支付宝账户(提现功能) | ||||
| func (a *AliPayService) AliTransfer( | ||||
| 	ctx context.Context, | ||||
| 	payeeAccount string, // 收款方支付宝账户 | ||||
| 	payeeName string, // 收款方姓名 | ||||
| 	amount decimal.Decimal, // 转账金额 | ||||
| 	remark string, // 转账备注 | ||||
| 	outBizNo string, // 商户转账唯一订单号(可使用GenerateOutTradeNo生成) | ||||
| ) (*alipay.FundTransUniTransferRsp, error) { | ||||
| 	// 参数校验 | ||||
| 	if payeeAccount == "" { | ||||
| 		return nil, fmt.Errorf("收款账户不能为空") | ||||
| 	} | ||||
| 	if amount.LessThanOrEqual(decimal.Zero) { | ||||
| 		return nil, fmt.Errorf("转账金额必须大于0") | ||||
| 	} | ||||
|  | ||||
| 	// 构造转账请求 | ||||
| 	req := alipay.FundTransUniTransfer{ | ||||
| 		OutBizNo:    outBizNo, | ||||
| 		TransAmount: amount.StringFixed(2),  // 保留2位小数 | ||||
| 		ProductCode: "TRANS_ACCOUNT_NO_PWD", // 单笔无密转账到支付宝账户 | ||||
| 		BizScene:    "DIRECT_TRANSFER",      // 单笔转账 | ||||
| 		OrderTitle:  "账户提现",                 // 转账标题 | ||||
| 		Remark:      remark, | ||||
| 		PayeeInfo: &alipay.PayeeInfo{ | ||||
| 			Identity:     payeeAccount, | ||||
| 			IdentityType: "ALIPAY_LOGON_ID", // 根据账户类型选择: | ||||
| 			Name:         payeeName, | ||||
| 			// ALIPAY_USER_ID/ALIPAY_LOGON_ID | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	// 执行转账请求 | ||||
| 	transferRsp, err := a.AlipayClient.FundTransUniTransfer(ctx, req) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("支付宝转账请求失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return transferRsp, nil | ||||
| } | ||||
| func (a *AliPayService) QueryTransferStatus( | ||||
| 	ctx context.Context, | ||||
| 	outBizNo string, | ||||
| ) (*alipay.FundTransOrderQueryRsp, error) { | ||||
| 	req := alipay.FundTransOrderQuery{ | ||||
| 		OutBizNo: outBizNo, | ||||
| 	} | ||||
| 	response, err := a.AlipayClient.FundTransOrderQuery(ctx, req) | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("支付宝接口调用失败: %v", err) | ||||
| 	} | ||||
| 	// 处理响应 | ||||
| 	if response.Code.IsFailure() { | ||||
| 		return nil, fmt.Errorf("支付宝返回错误: %s-%s", response.Code, response.Msg) | ||||
| 	} | ||||
| 	return response, nil | ||||
| } | ||||
							
								
								
									
										246
									
								
								internal/shared/validator/AUTH_DATE_VALIDATOR.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										246
									
								
								internal/shared/validator/AUTH_DATE_VALIDATOR.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,246 @@ | ||||
| # AuthDate 授权日期验证器 | ||||
|  | ||||
| ## 概述 | ||||
|  | ||||
| `authDate` 是一个自定义验证器,用于验证授权日期格式和有效性。该验证器确保日期格式正确,并且日期范围必须包括今天。 | ||||
|  | ||||
| ## 验证规则 | ||||
|  | ||||
| ### 1. 格式要求 | ||||
| - 必须为 `YYYYMMDD-YYYYMMDD` 格式 | ||||
| - 两个日期之间用连字符 `-` 分隔 | ||||
| - 每个日期必须是8位数字 | ||||
|  | ||||
| ### 2. 日期有效性 | ||||
| - 开始日期不能晚于结束日期 | ||||
| - 日期范围必须包括今天(如果两个日期都是今天也行) | ||||
| - 支持闰年验证 | ||||
| - 验证月份和日期的有效性 | ||||
|  | ||||
| ### 3. 业务逻辑 | ||||
| - 空值由 `required` 标签处理,本验证器返回 `true` | ||||
| - 日期范围必须覆盖今天,确保授权在有效期内 | ||||
|  | ||||
| ## 使用示例 | ||||
|  | ||||
| ### 在 DTO 中使用 | ||||
|  | ||||
| ```go | ||||
| type FLXG0V4BReq struct { | ||||
|     Name     string `json:"name" validate:"required,name"` | ||||
|     IDCard   string `json:"id_card" validate:"required,idCard"` | ||||
|     AuthDate string `json:"auth_date" validate:"required,authDate"` | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ### 在结构体中使用 | ||||
|  | ||||
| ```go | ||||
| type AuthorizationRequest struct { | ||||
|     UserID   string `json:"user_id" validate:"required"` | ||||
|     AuthDate string `json:"auth_date" validate:"required,authDate"` | ||||
|     Scope    string `json:"scope" validate:"required"` | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 有效示例 | ||||
|  | ||||
| ### ✅ 有效的日期范围 | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240101-20240131"  // 1月1日到1月31日(如果今天是1月15日) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240115-20240115"  // 今天到今天 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240110-20240120"  // 昨天到明天(如果今天是1月15日) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240101-20240201"  // 上个月到下个月(如果今天是1月15日) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ### ❌ 无效的日期范围 | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240116-20240120"  // 明天到后天(不包括今天) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240101-20240114"  // 上个月到昨天(不包括今天) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240131-20240101"  // 开始日期晚于结束日期 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240101-2024013A"  // 非数字字符 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "202401-20240131"    // 日期长度不对 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "2024010120240131"   // 缺少连字符 | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ```json | ||||
| { | ||||
|     "auth_date": "20240230-20240301"  // 无效日期(2月30日) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 错误消息 | ||||
|  | ||||
| 当验证失败时,会返回中文错误消息: | ||||
|  | ||||
| ``` | ||||
| "授权日期格式不正确,必须是YYYYMMDD-YYYYMMDD格式,且日期范围必须包括今天" | ||||
| ``` | ||||
|  | ||||
| ## 测试用例 | ||||
|  | ||||
| 验证器包含完整的测试用例,覆盖以下场景: | ||||
|  | ||||
| ### 有效场景 | ||||
| - 今天到今天 | ||||
| - 昨天到今天 | ||||
| - 今天到明天 | ||||
| - 上周到今天 | ||||
| - 今天到下周 | ||||
| - 昨天到明天 | ||||
|  | ||||
| ### 无效场景 | ||||
| - 明天到后天(不包括今天) | ||||
| - 上周到昨天(不包括今天) | ||||
| - 格式错误(缺少连字符、多个连字符、长度不对、非数字) | ||||
| - 无效日期(2月30日、13月等) | ||||
| - 开始日期晚于结束日期 | ||||
|  | ||||
| ## 实现细节 | ||||
|  | ||||
| ### 核心验证逻辑 | ||||
|  | ||||
| ```go | ||||
| func validateAuthDate(fl validator.FieldLevel) bool { | ||||
|     authDate := fl.Field().String() | ||||
|     if authDate == "" { | ||||
|         return true // 空值由required标签处理 | ||||
|     } | ||||
|  | ||||
|     // 1. 检查格式:YYYYMMDD-YYYYMMDD | ||||
|     parts := strings.Split(authDate, "-") | ||||
|     if len(parts) != 2 { | ||||
|         return false | ||||
|     } | ||||
|  | ||||
|     // 2. 解析日期 | ||||
|     startDate, err := parseYYYYMMDD(parts[0]) | ||||
|     if err != nil { | ||||
|         return false | ||||
|     } | ||||
|  | ||||
|     endDate, err := parseYYYYMMDD(parts[1]) | ||||
|     if err != nil { | ||||
|         return false | ||||
|     } | ||||
|  | ||||
|     // 3. 检查日期顺序 | ||||
|     if startDate.After(endDate) { | ||||
|         return false | ||||
|     } | ||||
|  | ||||
|     // 4. 检查是否包括今天 | ||||
|     today := time.Now().Truncate(24 * time.Hour) | ||||
|     return !startDate.After(today) && !endDate.Before(today) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ### 日期解析 | ||||
|  | ||||
| ```go | ||||
| func parseYYYYMMDD(dateStr string) (time.Time, error) { | ||||
|     if len(dateStr) != 8 { | ||||
|         return time.Time{}, fmt.Errorf("日期格式错误") | ||||
|     } | ||||
|  | ||||
|     year, _ := strconv.Atoi(dateStr[:4]) | ||||
|     month, _ := strconv.Atoi(dateStr[4:6]) | ||||
|     day, _ := strconv.Atoi(dateStr[6:8]) | ||||
|  | ||||
|     date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC) | ||||
|      | ||||
|     // 验证日期有效性 | ||||
|     expectedDateStr := date.Format("20060102") | ||||
|     if expectedDateStr != dateStr { | ||||
|         return time.Time{}, fmt.Errorf("无效日期") | ||||
|     } | ||||
|  | ||||
|     return date, nil | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## 注册方式 | ||||
|  | ||||
| 验证器已在 `RegisterCustomValidators` 函数中自动注册: | ||||
|  | ||||
| ```go | ||||
| func RegisterCustomValidators(validate *validator.Validate) { | ||||
|     // ... 其他验证器 | ||||
|     validate.RegisterValidation("auth_date", validateAuthDate) | ||||
| } | ||||
| ``` | ||||
|  | ||||
| 翻译也已自动注册: | ||||
|  | ||||
| ```go | ||||
| validate.RegisterTranslation("auth_date", trans, func(ut ut.Translator) error { | ||||
|     return ut.Add("auth_date", "{0}格式不正确,必须是YYYYMMDD-YYYYMMDD格式,且日期范围必须包括今天", true) | ||||
| }, func(ut ut.Translator, fe validator.FieldError) string { | ||||
|     t, _ := ut.T("auth_date", getFieldDisplayName(fe.Field())) | ||||
|     return t | ||||
| }) | ||||
| ``` | ||||
|  | ||||
| ## 注意事项 | ||||
|  | ||||
| 1. **时区处理**:验证器使用 UTC 时区进行日期比较 | ||||
| 2. **空值处理**:空字符串由 `required` 标签处理,本验证器返回 `true` | ||||
| 3. **日期精度**:只比较日期部分,忽略时间部分 | ||||
| 4. **闰年支持**:自动处理闰年验证 | ||||
| 5. **错误消息**:提供中文错误消息,便于用户理解 | ||||
|  | ||||
| ## 运行测试 | ||||
|  | ||||
| ```bash | ||||
| # 运行所有 authDate 相关测试 | ||||
| go test ./internal/shared/validator -v -run TestValidateAuthDate | ||||
|  | ||||
| # 运行所有验证器测试 | ||||
| go test ./internal/shared/validator -v | ||||
| ```  | ||||
							
								
								
									
										180
									
								
								internal/shared/validator/auth_date_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										180
									
								
								internal/shared/validator/auth_date_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,180 @@ | ||||
| package validator | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-playground/validator/v10" | ||||
| ) | ||||
|  | ||||
| func TestValidateAuthDate(t *testing.T) { | ||||
| 	validate := validator.New() | ||||
| 	validate.RegisterValidation("auth_date", validateAuthDate) | ||||
|  | ||||
| 	today := time.Now().Format("20060102") | ||||
| 	yesterday := time.Now().AddDate(0, 0, -1).Format("20060102") | ||||
| 	tomorrow := time.Now().AddDate(0, 0, 1).Format("20060102") | ||||
| 	lastWeek := time.Now().AddDate(0, 0, -7).Format("20060102") | ||||
| 	nextWeek := time.Now().AddDate(0, 0, 7).Format("20060102") | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		authDate string | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:     "今天到今天 - 有效", | ||||
| 			authDate: today + "-" + today, | ||||
| 			wantErr:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "昨天到今天 - 有效", | ||||
| 			authDate: yesterday + "-" + today, | ||||
| 			wantErr:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "今天到明天 - 有效", | ||||
| 			authDate: today + "-" + tomorrow, | ||||
| 			wantErr:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "上周到今天 - 有效", | ||||
| 			authDate: lastWeek + "-" + today, | ||||
| 			wantErr:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "今天到下周 - 有效", | ||||
| 			authDate: today + "-" + nextWeek, | ||||
| 			wantErr:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "昨天到明天 - 有效", | ||||
| 			authDate: yesterday + "-" + tomorrow, | ||||
| 			wantErr:  false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "明天到后天 - 无效(不包括今天)", | ||||
| 			authDate: tomorrow + "-" + time.Now().AddDate(0, 0, 2).Format("20060102"), | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "上周到昨天 - 无效(不包括今天)", | ||||
| 			authDate: lastWeek + "-" + yesterday, | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "格式错误 - 缺少连字符", | ||||
| 			authDate: "2024010120240131", | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "格式错误 - 多个连字符", | ||||
| 			authDate: "20240101-20240131-20240201", | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "格式错误 - 日期长度不对", | ||||
| 			authDate: "202401-20240131", | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "格式错误 - 非数字", | ||||
| 			authDate: "20240101-2024013A", | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "无效日期 - 2月30日", | ||||
| 			authDate: "20240230-20240301", | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "无效日期 - 13月", | ||||
| 			authDate: "20241301-20241331", | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "开始日期晚于结束日期", | ||||
| 			authDate: "20240131-20240101", | ||||
| 			wantErr:  true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "空字符串 - 由required处理", | ||||
| 			authDate: "", | ||||
| 			wantErr:  false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			err := validate.Var(tt.authDate, "auth_date") | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("validateAuthDate() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestParseYYYYMMDD(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name    string | ||||
| 		dateStr string | ||||
| 		want    time.Time | ||||
| 		wantErr bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:    "有效日期", | ||||
| 			dateStr: "20240101", | ||||
| 			want:    time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "闰年2月29日", | ||||
| 			dateStr: "20240229", | ||||
| 			want:    time.Date(2024, 2, 29, 0, 0, 0, 0, time.UTC), | ||||
| 			wantErr: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "非闰年2月29日", | ||||
| 			dateStr: "20230229", | ||||
| 			want:    time.Time{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "长度错误", | ||||
| 			dateStr: "202401", | ||||
| 			want:    time.Time{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "非数字", | ||||
| 			dateStr: "2024010A", | ||||
| 			want:    time.Time{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "无效月份", | ||||
| 			dateStr: "20241301", | ||||
| 			want:    time.Time{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:    "无效日期", | ||||
| 			dateStr: "20240230", | ||||
| 			want:    time.Time{}, | ||||
| 			wantErr: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			got, err := parseYYYYMMDD(tt.dateStr) | ||||
| 			if (err != nil) != tt.wantErr { | ||||
| 				t.Errorf("parseYYYYMMDD() error = %v, wantErr %v", err, tt.wantErr) | ||||
| 				return | ||||
| 			} | ||||
| 			if !tt.wantErr && !got.Equal(tt.want) { | ||||
| 				t.Errorf("parseYYYYMMDD() = %v, want %v", got, tt.want) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| }  | ||||
| @@ -1,8 +1,12 @@ | ||||
| package validator | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/url" | ||||
| 	"regexp" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"github.com/go-playground/validator/v10" | ||||
| ) | ||||
| @@ -38,6 +42,18 @@ func RegisterCustomValidators(validate *validator.Validate) { | ||||
| 	 | ||||
| 	// URL验证器 | ||||
| 	validate.RegisterValidation("url", validateURL) | ||||
| 	 | ||||
| 	// 企业邮箱验证器 | ||||
| 	validate.RegisterValidation("enterprise_email", validateEnterpriseEmail) | ||||
| 	 | ||||
| 	// 企业地址验证器 | ||||
| 	validate.RegisterValidation("enterprise_address", validateEnterpriseAddress) | ||||
| 	 | ||||
| 	// IP地址验证器 | ||||
| 	validate.RegisterValidation("ip", validateIP) | ||||
| 	 | ||||
| 	// 授权日期验证器 | ||||
| 	validate.RegisterValidation("auth_date", validateAuthDate) | ||||
| } | ||||
|  | ||||
| // validatePhone 手机号验证 | ||||
| @@ -111,4 +127,113 @@ func validateURL(fl validator.FieldLevel) bool { | ||||
| 	urlStr := fl.Field().String() | ||||
| 	_, err := url.ParseRequestURI(urlStr) | ||||
| 	return err == nil | ||||
| } | ||||
|  | ||||
| // validateEnterpriseEmail 企业邮箱验证 | ||||
| func validateEnterpriseEmail(fl validator.FieldLevel) bool { | ||||
| 	email := fl.Field().String() | ||||
| 	// 邮箱格式验证:用户名@域名.顶级域名 | ||||
| 	matched, _ := regexp.MatchString(`^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$`, email) | ||||
| 	return matched | ||||
| } | ||||
|  | ||||
| // validateEnterpriseAddress 企业地址验证 | ||||
| func validateEnterpriseAddress(fl validator.FieldLevel) bool { | ||||
| 	address := fl.Field().String() | ||||
| 	// 地址长度验证:2-200字符,不能只包含空格 | ||||
| 	if len(strings.TrimSpace(address)) < 2 || len(address) > 200 { | ||||
| 		return false | ||||
| 	} | ||||
| 	// 地址不能只包含特殊字符 | ||||
| 	matched, _ := regexp.MatchString(`^[^\s]+.*[^\s]+$`, strings.TrimSpace(address)) | ||||
| 	return matched | ||||
| } | ||||
|  | ||||
| // validateIP IP地址验证(支持IPv4) | ||||
| func validateIP(fl validator.FieldLevel) bool { | ||||
| 	ip := fl.Field().String() | ||||
| 	// 使用正则表达式验证IPv4格式 | ||||
| 	pattern := `^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$` | ||||
| 	matched, _ := regexp.MatchString(pattern, ip) | ||||
| 	return matched | ||||
| }  | ||||
|  | ||||
| // validateAuthDate 授权日期验证器 | ||||
| // 格式:YYYYMMDD-YYYYMMDD,之前的日期范围必须包括今天 | ||||
| func validateAuthDate(fl validator.FieldLevel) bool { | ||||
| 	authDate := fl.Field().String() | ||||
| 	if authDate == "" { | ||||
| 		return true // 空值由required标签处理 | ||||
| 	} | ||||
|  | ||||
| 	// 检查格式:YYYYMMDD-YYYYMMDD | ||||
| 	parts := strings.Split(authDate, "-") | ||||
| 	if len(parts) != 2 { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	startDateStr := parts[0] | ||||
| 	endDateStr := parts[1] | ||||
|  | ||||
| 	// 检查日期格式是否为8位数字 | ||||
| 	if len(startDateStr) != 8 || len(endDateStr) != 8 { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 解析开始日期 | ||||
| 	startDate, err := parseYYYYMMDD(startDateStr) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 解析结束日期 | ||||
| 	endDate, err := parseYYYYMMDD(endDateStr) | ||||
| 	if err != nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 检查开始日期不能晚于结束日期 | ||||
| 	if startDate.After(endDate) { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 获取今天的日期(去掉时间部分) | ||||
| 	today := time.Now().Truncate(24 * time.Hour) | ||||
|  | ||||
| 	// 检查日期范围是否包括今天 | ||||
| 	// 如果两个日期都是今天也行 | ||||
| 	return !startDate.After(today) && !endDate.Before(today) | ||||
| } | ||||
|  | ||||
| // parseYYYYMMDD 解析YYYYMMDD格式的日期字符串 | ||||
| func parseYYYYMMDD(dateStr string) (time.Time, error) { | ||||
| 	if len(dateStr) != 8 { | ||||
| 		return time.Time{}, fmt.Errorf("日期格式错误") | ||||
| 	} | ||||
|  | ||||
| 	year, err := strconv.Atoi(dateStr[:4]) | ||||
| 	if err != nil { | ||||
| 		return time.Time{}, err | ||||
| 	} | ||||
|  | ||||
| 	month, err := strconv.Atoi(dateStr[4:6]) | ||||
| 	if err != nil { | ||||
| 		return time.Time{}, err | ||||
| 	} | ||||
|  | ||||
| 	day, err := strconv.Atoi(dateStr[6:8]) | ||||
| 	if err != nil { | ||||
| 		return time.Time{}, err | ||||
| 	} | ||||
|  | ||||
| 	// 验证日期有效性 | ||||
| 	date := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.UTC) | ||||
| 	 | ||||
| 	// 检查解析后的日期是否与输入一致(防止无效日期如20230230) | ||||
| 	expectedDateStr := date.Format("20060102") | ||||
| 	if expectedDateStr != dateStr { | ||||
| 		return time.Time{}, fmt.Errorf("无效日期") | ||||
| 	} | ||||
|  | ||||
| 	return date, nil | ||||
| }  | ||||
| @@ -162,6 +162,38 @@ func registerCustomFieldTranslations(validate *validator.Validate, trans ut.Tran | ||||
| 		t, _ := ut.T("url", getFieldDisplayName(fe.Field())) | ||||
| 		return t | ||||
| 	}) | ||||
|  | ||||
| 	// 企业邮箱翻译 | ||||
| 	validate.RegisterTranslation("enterprise_email", trans, func(ut ut.Translator) error { | ||||
| 		return ut.Add("enterprise_email", "{0}必须是有效的企业邮箱地址", true) | ||||
| 	}, func(ut ut.Translator, fe validator.FieldError) string { | ||||
| 		t, _ := ut.T("enterprise_email", getFieldDisplayName(fe.Field())) | ||||
| 		return t | ||||
| 	}) | ||||
|  | ||||
| 	// 企业地址翻译 | ||||
| 	validate.RegisterTranslation("enterprise_address", trans, func(ut ut.Translator) error { | ||||
| 		return ut.Add("enterprise_address", "{0}长度必须在2-200字符之间,且不能只包含空格", true) | ||||
| 	}, func(ut ut.Translator, fe validator.FieldError) string { | ||||
| 		t, _ := ut.T("enterprise_address", getFieldDisplayName(fe.Field())) | ||||
| 		return t | ||||
| 	}) | ||||
|  | ||||
| 	// IP地址翻译 | ||||
| 	validate.RegisterTranslation("ip", trans, func(ut ut.Translator) error { | ||||
| 		return ut.Add("ip", "{0}必须是有效的IPv4地址格式", true) | ||||
| 	}, func(ut ut.Translator, fe validator.FieldError) string { | ||||
| 		t, _ := ut.T("ip", getFieldDisplayName(fe.Field())) | ||||
| 		return t | ||||
| 	}) | ||||
|  | ||||
| 	// 授权日期翻译 | ||||
| 	validate.RegisterTranslation("auth_date", trans, func(ut ut.Translator) error { | ||||
| 		return ut.Add("auth_date", "{0}格式不正确,必须是YYYYMMDD-YYYYMMDD格式,且日期范围必须包括今天", true) | ||||
| 	}, func(ut ut.Translator, fe validator.FieldError) string { | ||||
| 		t, _ := ut.T("auth_date", getFieldDisplayName(fe.Field())) | ||||
| 		return t | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // getFieldDisplayName 获取字段显示名称(中文) | ||||
| @@ -176,6 +208,9 @@ func getFieldDisplayName(field string) string { | ||||
| 		"code":                 "验证码", | ||||
| 		"username":             "用户名", | ||||
| 		"email":                "邮箱", | ||||
| 		"enterprise_email":     "企业邮箱", | ||||
| 		"enterprise_address":   "企业地址", | ||||
| 		"ip_address":           "IP地址", | ||||
| 		"display_name":         "显示名称", | ||||
| 		"scene":                "使用场景", | ||||
| 		"Password":             "密码", | ||||
| @@ -244,6 +279,10 @@ func getFieldDisplayName(field string) string { | ||||
| 		"ID":                   "ID", | ||||
| 		"ids":                  "ID列表", | ||||
| 		"IDs":                  "ID列表", | ||||
| 		"auth_date":            "授权日期", | ||||
| 		"AuthDate":             "授权日期", | ||||
| 		"id_card":              "身份证号", | ||||
| 		"IDCard":               "身份证号", | ||||
| 	} | ||||
|  | ||||
| 	if displayName, exists := fieldNames[field]; exists { | ||||
|   | ||||
| @@ -65,7 +65,7 @@ func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| 	 | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -80,7 +80,7 @@ func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
| 	 | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| @@ -213,4 +213,4 @@ func (v *RequestValidator) ValidateValue(field interface{}, tag string) error { | ||||
| // ValidateStruct 验证结构体(用于业务逻辑) | ||||
| func (v *RequestValidator) ValidateStruct(s interface{}) error { | ||||
| 	return v.validator.Struct(s) | ||||
| }  | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user