package cache import ( "context" "crypto/md5" "encoding/hex" "encoding/json" "fmt" "reflect" "strings" "sync" "time" "go.uber.org/zap" "gorm.io/gorm" "tyapi-server/internal/shared/interfaces" ) // GormCachePlugin GORM缓存插件 type GormCachePlugin struct { cache interfaces.CacheService logger *zap.Logger config CacheConfig // 缓存失效去重机制 invalidationQueue map[string]*time.Timer queueMutex sync.RWMutex } // CacheConfig 缓存配置 type CacheConfig struct { // 基础配置 DefaultTTL time.Duration `json:"default_ttl"` // 默认TTL TablePrefix string `json:"table_prefix"` // 表前缀 EnabledTables []string `json:"enabled_tables"` // 启用缓存的表 DisabledTables []string `json:"disabled_tables"` // 禁用缓存的表 // 查询配置 MaxCacheSize int `json:"max_cache_size"` // 单次查询最大缓存记录数 CacheComplexSQL bool `json:"cache_complex_sql"` // 是否缓存复杂SQL // 高级特性 EnableStats bool `json:"enable_stats"` // 启用统计 EnableWarmup bool `json:"enable_warmup"` // 启用预热 PenetrationGuard bool `json:"penetration_guard"` // 缓存穿透保护 BloomFilter bool `json:"bloom_filter"` // 布隆过滤器 // 失效策略 AutoInvalidate bool `json:"auto_invalidate"` // 自动失效 InvalidateDelay time.Duration `json:"invalidate_delay"` // 延迟失效时间 } // DefaultCacheConfig 默认缓存配置 func DefaultCacheConfig() CacheConfig { return CacheConfig{ DefaultTTL: 30 * time.Minute, TablePrefix: "gorm_cache", MaxCacheSize: 1000, CacheComplexSQL: false, EnableStats: true, EnableWarmup: false, PenetrationGuard: true, BloomFilter: false, AutoInvalidate: true, // 增加延迟失效时间,减少频繁的缓存失效操作 InvalidateDelay: 500 * time.Millisecond, } } // NewGormCachePlugin 创建GORM缓存插件 func NewGormCachePlugin(cache interfaces.CacheService, logger *zap.Logger, config ...CacheConfig) *GormCachePlugin { cfg := DefaultCacheConfig() if len(config) > 0 { cfg = config[0] } return &GormCachePlugin{ cache: cache, logger: logger, config: cfg, invalidationQueue: make(map[string]*time.Timer), } } // Name 插件名称 func (p *GormCachePlugin) Name() string { return "gorm-cache-plugin" } // Initialize 初始化插件 func (p *GormCachePlugin) Initialize(db *gorm.DB) error { p.logger.Info("初始化GORM缓存插件", zap.Duration("default_ttl", p.config.DefaultTTL), zap.Bool("auto_invalidate", p.config.AutoInvalidate), zap.Bool("penetration_guard", p.config.PenetrationGuard), zap.Duration("invalidate_delay", p.config.InvalidateDelay), ) // 注册回调函数 return p.registerCallbacks(db) } // Shutdown 关闭插件,清理资源 func (p *GormCachePlugin) Shutdown() { p.queueMutex.Lock() defer p.queueMutex.Unlock() // 停止所有定时器 for table, timer := range p.invalidationQueue { timer.Stop() p.logger.Debug("停止缓存失效定时器", zap.String("table", table)) } // 清空队列 p.invalidationQueue = make(map[string]*time.Timer) p.logger.Info("GORM缓存插件已关闭") } // registerCallbacks 注册GORM回调 func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error { // Query回调 - 查询时检查缓存 db.Callback().Query().Before("gorm:query").Register("cache:before_query", p.beforeQuery) db.Callback().Query().After("gorm:query").Register("cache:after_query", p.afterQuery) // Create回调 - 创建时失效缓存 db.Callback().Create().After("gorm:create").Register("cache:after_create", p.afterCreate) // Update回调 - 更新时失效缓存 db.Callback().Update().After("gorm:update").Register("cache:after_update", p.afterUpdate) // Delete回调 - 删除时失效缓存 db.Callback().Delete().After("gorm:delete").Register("cache:after_delete", p.afterDelete) return nil } // ================ 查询回调 ================ // beforeQuery 查询前回调 func (p *GormCachePlugin) beforeQuery(db *gorm.DB) { // 检查是否启用缓存 if !p.shouldCache(db) { p.logger.Debug("跳过缓存", zap.String("table", db.Statement.Table)) return } ctx := db.Statement.Context if ctx == nil { ctx = context.Background() } // 生成缓存键 cacheKey := p.generateCacheKey(db) // 从缓存获取结果 var cachedResult CachedResult if err := p.cache.Get(ctx, cacheKey, &cachedResult); err == nil { p.logger.Debug("缓存命中", zap.String("cache_key", cacheKey), zap.String("table", db.Statement.Table), ) // 恢复查询结果 if err := p.restoreFromCache(db, &cachedResult); err == nil { // 设置标记,跳过实际查询 db.Statement.Set("cache:hit", true) db.Statement.Set("cache:key", cacheKey) // 更新统计 if p.config.EnableStats { p.updateStats("hit", db.Statement.Table) } return } else { p.logger.Warn("缓存数据恢复失败,将执行数据库查询", zap.String("cache_key", cacheKey), zap.Error(err)) } } else { p.logger.Debug("缓存未命中", zap.String("cache_key", cacheKey), zap.Error(err)) } // 缓存未命中,设置标记 db.Statement.Set("cache:miss", true) db.Statement.Set("cache:key", cacheKey) if p.config.EnableStats { p.updateStats("miss", db.Statement.Table) } } // afterQuery 查询后回调 func (p *GormCachePlugin) afterQuery(db *gorm.DB) { // 检查是否缓存未命中 if _, ok := db.Statement.Get("cache:miss"); !ok { return } // 检查查询是否成功 if db.Error != nil { return } ctx := db.Statement.Context if ctx == nil { ctx = context.Background() } cacheKey, _ := db.Statement.Get("cache:key") // 将查询结果保存到缓存 if err := p.saveToCache(ctx, cacheKey.(string), db); err != nil { p.logger.Warn("保存查询结果到缓存失败", zap.String("cache_key", cacheKey.(string)), zap.Error(err), ) } } // ================ CUD回调 ================ // afterCreate 创建后回调 func (p *GormCachePlugin) afterCreate(db *gorm.DB) { if !p.config.AutoInvalidate || db.Error != nil { return } // 只对启用缓存的表执行失效操作 if p.shouldInvalidateTable(db.Statement.Table) { p.invalidateTableCache(db.Statement.Context, db.Statement.Table) } } // afterUpdate 更新后回调 func (p *GormCachePlugin) afterUpdate(db *gorm.DB) { if !p.config.AutoInvalidate || db.Error != nil { return } // 只对启用缓存的表执行失效操作 if p.shouldInvalidateTable(db.Statement.Table) { p.invalidateTableCache(db.Statement.Context, db.Statement.Table) } } // afterDelete 删除后回调 func (p *GormCachePlugin) afterDelete(db *gorm.DB) { if !p.config.AutoInvalidate || db.Error != nil { return } // 只对启用缓存的表执行失效操作 if p.shouldInvalidateTable(db.Statement.Table) { p.invalidateTableCache(db.Statement.Context, db.Statement.Table) } } // ================ 缓存管理方法 ================ // shouldInvalidateTable 判断是否应该对表执行缓存失效操作 func (p *GormCachePlugin) shouldInvalidateTable(table string) bool { // 使用全局缓存配置管理器进行智能决策 if GlobalCacheConfigManager != nil { return GlobalCacheConfigManager.IsTableCacheEnabled(table) } // 如果全局管理器未初始化,使用本地配置 // 检查表是否在禁用列表中 for _, disabledTable := range p.config.DisabledTables { if disabledTable == table { return false } } // 如果配置了启用列表,只对启用列表中的表执行失效操作 if len(p.config.EnabledTables) > 0 { for _, enabledTable := range p.config.EnabledTables { if enabledTable == table { return true } } return false } // 如果没有配置启用列表,默认对所有表执行失效操作(除了禁用列表中的表) return true } // shouldCache 判断是否应该缓存 func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool { // 检查是否手动禁用缓存 if value, ok := db.Statement.Get("cache:disabled"); ok && value.(bool) { return false } // 检查是否手动启用缓存 if value, ok := db.Statement.Get("cache:enabled"); ok && value.(bool) { return true } // 使用全局缓存配置管理器进行智能决策 if GlobalCacheConfigManager != nil { return GlobalCacheConfigManager.IsTableCacheEnabled(db.Statement.Table) } // 如果全局管理器未初始化,使用本地配置 // 检查表是否在禁用列表中 for _, disabledTable := range p.config.DisabledTables { if disabledTable == db.Statement.Table { return false } } // 检查表是否在启用列表中(如果配置了启用列表) if len(p.config.EnabledTables) > 0 { for _, table := range p.config.EnabledTables { if table == db.Statement.Table { return true } } return false } // 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表) return true } // isComplexQuery 判断是否为复杂查询 func (p *GormCachePlugin) isComplexQuery(db *gorm.DB) bool { sql := db.Statement.SQL.String() // 检查是否包含复杂操作 complexKeywords := []string{ "JOIN", "UNION", "SUBQUERY", "GROUP BY", "HAVING", "WINDOW", "RECURSIVE", } upperSQL := strings.ToUpper(sql) for _, keyword := range complexKeywords { if strings.Contains(upperSQL, keyword) { return true } } return false } // generateCacheKey 生成缓存键 func (p *GormCachePlugin) generateCacheKey(db *gorm.DB) string { // 构建缓存键的组成部分 keyParts := []string{ p.config.TablePrefix, db.Statement.Table, } // 添加SQL语句hash sqlHash := p.hashSQL(db.Statement.SQL.String(), db.Statement.Vars) keyParts = append(keyParts, sqlHash) return strings.Join(keyParts, ":") } // hashSQL 对SQL语句和参数进行hash func (p *GormCachePlugin) hashSQL(sql string, vars []interface{}) string { // 修复:改进SQL hash生成,确保唯一性 combined := sql for _, v := range vars { // 使用更精确的格式化,避免类型信息丢失 switch val := v.(type) { case string: combined += ":" + val case int, int32, int64: combined += fmt.Sprintf(":%d", val) case float32, float64: combined += fmt.Sprintf(":%f", val) case bool: combined += fmt.Sprintf(":%t", val) default: // 对于复杂类型,使用JSON序列化 if jsonData, err := json.Marshal(v); err == nil { combined += ":" + string(jsonData) } else { combined += fmt.Sprintf(":%v", v) } } } // 计算MD5 hash hasher := md5.New() hasher.Write([]byte(combined)) return hex.EncodeToString(hasher.Sum(nil)) } // CachedResult 缓存结果结构 type CachedResult struct { Data interface{} `json:"data"` RowCount int64 `json:"row_count"` Timestamp time.Time `json:"timestamp"` } // saveToCache 保存结果到缓存 func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db *gorm.DB) error { // 检查结果大小限制 if db.Statement.RowsAffected > int64(p.config.MaxCacheSize) { p.logger.Debug("查询结果过大,跳过缓存", zap.String("cache_key", cacheKey), zap.Int64("rows", db.Statement.RowsAffected), ) return nil } // 获取查询结果 dest := db.Statement.Dest if dest == nil { return fmt.Errorf("查询结果为空") } // 修复:改进缓存数据保存逻辑 var dataToCache interface{} // 如果dest是切片,需要特殊处理 destValue := reflect.ValueOf(dest) if destValue.Kind() == reflect.Ptr && destValue.Elem().Kind() == reflect.Slice { // 对于切片,直接使用原始数据 dataToCache = destValue.Elem().Interface() } else { // 对于单个对象,也直接使用原始数据 dataToCache = dest } // 构建缓存结果 result := CachedResult{ Data: dataToCache, RowCount: db.Statement.RowsAffected, Timestamp: time.Now(), } // 获取TTL ttl := p.getTTL(db) // 保存到缓存 if err := p.cache.Set(ctx, cacheKey, result, ttl); err != nil { p.logger.Error("保存查询结果到缓存失败", zap.String("cache_key", cacheKey), zap.Error(err), ) return fmt.Errorf("保存到缓存失败: %w", err) } p.logger.Debug("查询结果已缓存", zap.String("cache_key", cacheKey), zap.Int64("rows", db.Statement.RowsAffected), zap.Duration("ttl", ttl), ) return nil } // restoreFromCache 从缓存恢复结果 func (p *GormCachePlugin) restoreFromCache(db *gorm.DB, cachedResult *CachedResult) error { if cachedResult.Data == nil { return fmt.Errorf("缓存数据为空") } // 反序列化到目标对象 destValue := reflect.ValueOf(db.Statement.Dest) if destValue.Kind() != reflect.Ptr || destValue.IsNil() { return fmt.Errorf("目标对象必须是指针") } // 修复:改进缓存数据恢复逻辑 cachedValue := reflect.ValueOf(cachedResult.Data) // 如果类型完全匹配,直接赋值 if cachedValue.Type().AssignableTo(destValue.Elem().Type()) { destValue.Elem().Set(cachedValue) } else { // 尝试JSON转换 jsonData, err := json.Marshal(cachedResult.Data) if err != nil { p.logger.Error("序列化缓存数据失败", zap.Error(err)) return fmt.Errorf("缓存数据类型转换失败: %w", err) } if err := json.Unmarshal(jsonData, db.Statement.Dest); err != nil { p.logger.Error("反序列化缓存数据失败", zap.String("json_data", string(jsonData)), zap.Error(err)) return fmt.Errorf("JSON反序列化失败: %w", err) } } // 设置影响行数 db.Statement.RowsAffected = cachedResult.RowCount p.logger.Debug("从缓存恢复数据成功", zap.Int64("rows", cachedResult.RowCount), zap.Time("timestamp", cachedResult.Timestamp)) return nil } // getTTL 获取TTL func (p *GormCachePlugin) getTTL(db *gorm.DB) time.Duration { // 检查是否设置了自定义TTL if value, ok := db.Statement.Get("cache:ttl"); ok { if ttl, ok := value.(time.Duration); ok { return ttl } } return p.config.DefaultTTL } // invalidateTableCache 失效表相关缓存 func (p *GormCachePlugin) invalidateTableCache(ctx context.Context, table string) { // 使用去重机制,避免重复的缓存失效操作 p.queueMutex.Lock() defer p.queueMutex.Unlock() // 如果已经有相同的失效操作在队列中,取消之前的定时器 if timer, exists := p.invalidationQueue[table]; exists { timer.Stop() delete(p.invalidationQueue, table) } // 创建独立的上下文,避免受到原始请求上下文的影响 // 设置合理的超时时间,避免缓存失效操作阻塞 cacheCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) // 创建新的定时器 timer := time.AfterFunc(p.config.InvalidateDelay, func() { // 执行缓存失效 p.doInvalidateTableCache(cacheCtx, table) // 清理定时器引用 p.queueMutex.Lock() delete(p.invalidationQueue, table) p.queueMutex.Unlock() // 取消上下文 cancel() }) // 将定时器加入队列 p.invalidationQueue[table] = timer p.logger.Debug("缓存失效操作已加入队列", zap.String("table", table), zap.Duration("delay", p.config.InvalidateDelay), ) } // doInvalidateTableCache 执行缓存失效 func (p *GormCachePlugin) doInvalidateTableCache(ctx context.Context, table string) { pattern := fmt.Sprintf("%s:%s:*", p.config.TablePrefix, table) // 添加重试机制,提高缓存失效的可靠性 maxRetries := 3 for attempt := 1; attempt <= maxRetries; attempt++ { if err := p.cache.DeletePattern(ctx, pattern); err != nil { if attempt < maxRetries { p.logger.Warn("缓存失效失败,准备重试", zap.String("table", table), zap.String("pattern", pattern), zap.Int("attempt", attempt), zap.Error(err), ) // 短暂延迟后重试 time.Sleep(time.Duration(attempt) * 100 * time.Millisecond) continue } p.logger.Warn("失效表缓存失败", zap.String("table", table), zap.String("pattern", pattern), zap.Int("attempts", maxRetries), zap.Error(err), ) return } // 成功删除,记录日志并退出 p.logger.Debug("表缓存已失效", zap.String("table", table), zap.String("pattern", pattern), ) return } } // updateStats 更新统计信息 func (p *GormCachePlugin) updateStats(operation, table string) { // 这里可以接入Prometheus等监控系统 p.logger.Debug("缓存统计", zap.String("operation", operation), zap.String("table", table), ) } // ================ 高级功能 ================ // WarmupCache 预热缓存 func (p *GormCachePlugin) WarmupCache(ctx context.Context, db *gorm.DB, queries []string) error { if !p.config.EnableWarmup { return fmt.Errorf("缓存预热未启用") } for _, query := range queries { if err := db.Raw(query).Error; err != nil { p.logger.Warn("缓存预热失败", zap.String("query", query), zap.Error(err), ) } } return nil } // GetCacheStats 获取缓存统计 func (p *GormCachePlugin) GetCacheStats(ctx context.Context) (map[string]interface{}, error) { stats, err := p.cache.Stats(ctx) if err != nil { return nil, err } return map[string]interface{}{ "hits": stats.Hits, "misses": stats.Misses, "keys": stats.Keys, "memory": stats.Memory, "connections": stats.Connections, "config": p.config, }, nil } // SetCacheEnabled 设置缓存启用状态 func (p *GormCachePlugin) SetCacheEnabled(db *gorm.DB, enabled bool) *gorm.DB { return db.Set("cache:enabled", enabled) } // SetCacheDisabled 设置缓存禁用状态 func (p *GormCachePlugin) SetCacheDisabled(db *gorm.DB, disabled bool) *gorm.DB { return db.Set("cache:disabled", disabled) } // SetCacheTTL 设置缓存TTL func (p *GormCachePlugin) SetCacheTTL(db *gorm.DB, ttl time.Duration) *gorm.DB { return db.Set("cache:ttl", ttl) }