Files
tyapi-server/internal/shared/cache/gorm_cache_plugin.go
2025-07-31 15:41:00 +08:00

656 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package cache
import (
"context"
"crypto/md5"
"encoding/hex"
"encoding/json"
"fmt"
"reflect"
"strings"
"sync"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/shared/interfaces"
)
// GormCachePlugin GORM缓存插件
type GormCachePlugin struct {
cache interfaces.CacheService
logger *zap.Logger
config CacheConfig
// 缓存失效去重机制
invalidationQueue map[string]*time.Timer
queueMutex sync.RWMutex
}
// CacheConfig 缓存配置
type CacheConfig struct {
// 基础配置
DefaultTTL time.Duration `json:"default_ttl"` // 默认TTL
TablePrefix string `json:"table_prefix"` // 表前缀
EnabledTables []string `json:"enabled_tables"` // 启用缓存的表
DisabledTables []string `json:"disabled_tables"` // 禁用缓存的表
// 查询配置
MaxCacheSize int `json:"max_cache_size"` // 单次查询最大缓存记录数
CacheComplexSQL bool `json:"cache_complex_sql"` // 是否缓存复杂SQL
// 高级特性
EnableStats bool `json:"enable_stats"` // 启用统计
EnableWarmup bool `json:"enable_warmup"` // 启用预热
PenetrationGuard bool `json:"penetration_guard"` // 缓存穿透保护
BloomFilter bool `json:"bloom_filter"` // 布隆过滤器
// 失效策略
AutoInvalidate bool `json:"auto_invalidate"` // 自动失效
InvalidateDelay time.Duration `json:"invalidate_delay"` // 延迟失效时间
}
// DefaultCacheConfig 默认缓存配置
func DefaultCacheConfig() CacheConfig {
return CacheConfig{
DefaultTTL: 30 * time.Minute,
TablePrefix: "gorm_cache",
MaxCacheSize: 1000,
CacheComplexSQL: false,
EnableStats: true,
EnableWarmup: false,
PenetrationGuard: true,
BloomFilter: false,
AutoInvalidate: true,
// 增加延迟失效时间,减少频繁的缓存失效操作
InvalidateDelay: 500 * time.Millisecond,
}
}
// NewGormCachePlugin 创建GORM缓存插件
func NewGormCachePlugin(cache interfaces.CacheService, logger *zap.Logger, config ...CacheConfig) *GormCachePlugin {
cfg := DefaultCacheConfig()
if len(config) > 0 {
cfg = config[0]
}
return &GormCachePlugin{
cache: cache,
logger: logger,
config: cfg,
invalidationQueue: make(map[string]*time.Timer),
}
}
// Name 插件名称
func (p *GormCachePlugin) Name() string {
return "gorm-cache-plugin"
}
// Initialize 初始化插件
func (p *GormCachePlugin) Initialize(db *gorm.DB) error {
p.logger.Info("初始化GORM缓存插件",
zap.Duration("default_ttl", p.config.DefaultTTL),
zap.Bool("auto_invalidate", p.config.AutoInvalidate),
zap.Bool("penetration_guard", p.config.PenetrationGuard),
zap.Duration("invalidate_delay", p.config.InvalidateDelay),
)
// 注册回调函数
return p.registerCallbacks(db)
}
// Shutdown 关闭插件,清理资源
func (p *GormCachePlugin) Shutdown() {
p.queueMutex.Lock()
defer p.queueMutex.Unlock()
// 停止所有定时器
for table, timer := range p.invalidationQueue {
timer.Stop()
p.logger.Debug("停止缓存失效定时器", zap.String("table", table))
}
// 清空队列
p.invalidationQueue = make(map[string]*time.Timer)
p.logger.Info("GORM缓存插件已关闭")
}
// registerCallbacks 注册GORM回调
func (p *GormCachePlugin) registerCallbacks(db *gorm.DB) error {
// Query回调 - 查询时检查缓存
db.Callback().Query().Before("gorm:query").Register("cache:before_query", p.beforeQuery)
db.Callback().Query().After("gorm:query").Register("cache:after_query", p.afterQuery)
// Create回调 - 创建时失效缓存
db.Callback().Create().After("gorm:create").Register("cache:after_create", p.afterCreate)
// Update回调 - 更新时失效缓存
db.Callback().Update().After("gorm:update").Register("cache:after_update", p.afterUpdate)
// Delete回调 - 删除时失效缓存
db.Callback().Delete().After("gorm:delete").Register("cache:after_delete", p.afterDelete)
return nil
}
// ================ 查询回调 ================
// beforeQuery 查询前回调
func (p *GormCachePlugin) beforeQuery(db *gorm.DB) {
// 检查是否启用缓存
if !p.shouldCache(db) {
p.logger.Debug("跳过缓存", zap.String("table", db.Statement.Table))
return
}
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}
// 生成缓存键
cacheKey := p.generateCacheKey(db)
// 从缓存获取结果
var cachedResult CachedResult
if err := p.cache.Get(ctx, cacheKey, &cachedResult); err == nil {
p.logger.Debug("缓存命中",
zap.String("cache_key", cacheKey),
zap.String("table", db.Statement.Table),
)
// 恢复查询结果
if err := p.restoreFromCache(db, &cachedResult); err == nil {
// 设置标记,跳过实际查询
db.Statement.Set("cache:hit", true)
db.Statement.Set("cache:key", cacheKey)
// 更新统计
if p.config.EnableStats {
p.updateStats("hit", db.Statement.Table)
}
return
} else {
p.logger.Warn("缓存数据恢复失败,将执行数据库查询",
zap.String("cache_key", cacheKey),
zap.Error(err))
}
} else {
p.logger.Debug("缓存未命中",
zap.String("cache_key", cacheKey),
zap.Error(err))
}
// 缓存未命中,设置标记
db.Statement.Set("cache:miss", true)
db.Statement.Set("cache:key", cacheKey)
if p.config.EnableStats {
p.updateStats("miss", db.Statement.Table)
}
}
// afterQuery 查询后回调
func (p *GormCachePlugin) afterQuery(db *gorm.DB) {
// 检查是否缓存未命中
if _, ok := db.Statement.Get("cache:miss"); !ok {
return
}
// 检查查询是否成功
if db.Error != nil {
return
}
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}
cacheKey, _ := db.Statement.Get("cache:key")
// 将查询结果保存到缓存
if err := p.saveToCache(ctx, cacheKey.(string), db); err != nil {
p.logger.Warn("保存查询结果到缓存失败",
zap.String("cache_key", cacheKey.(string)),
zap.Error(err),
)
}
}
// ================ CUD回调 ================
// afterCreate 创建后回调
func (p *GormCachePlugin) afterCreate(db *gorm.DB) {
if !p.config.AutoInvalidate || db.Error != nil {
return
}
// 只对启用缓存的表执行失效操作
if p.shouldInvalidateTable(db.Statement.Table) {
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
}
// afterUpdate 更新后回调
func (p *GormCachePlugin) afterUpdate(db *gorm.DB) {
if !p.config.AutoInvalidate || db.Error != nil {
return
}
// 只对启用缓存的表执行失效操作
if p.shouldInvalidateTable(db.Statement.Table) {
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
}
// afterDelete 删除后回调
func (p *GormCachePlugin) afterDelete(db *gorm.DB) {
if !p.config.AutoInvalidate || db.Error != nil {
return
}
// 只对启用缓存的表执行失效操作
if p.shouldInvalidateTable(db.Statement.Table) {
p.invalidateTableCache(db.Statement.Context, db.Statement.Table)
}
}
// ================ 缓存管理方法 ================
// shouldInvalidateTable 判断是否应该对表执行缓存失效操作
func (p *GormCachePlugin) shouldInvalidateTable(table string) bool {
// 使用全局缓存配置管理器进行智能决策
if GlobalCacheConfigManager != nil {
return GlobalCacheConfigManager.IsTableCacheEnabled(table)
}
// 如果全局管理器未初始化,使用本地配置
// 检查表是否在禁用列表中
for _, disabledTable := range p.config.DisabledTables {
if disabledTable == table {
return false
}
}
// 如果配置了启用列表,只对启用列表中的表执行失效操作
if len(p.config.EnabledTables) > 0 {
for _, enabledTable := range p.config.EnabledTables {
if enabledTable == table {
return true
}
}
return false
}
// 如果没有配置启用列表,默认对所有表执行失效操作(除了禁用列表中的表)
return true
}
// shouldCache 判断是否应该缓存
func (p *GormCachePlugin) shouldCache(db *gorm.DB) bool {
// 检查是否手动禁用缓存
if value, ok := db.Statement.Get("cache:disabled"); ok && value.(bool) {
return false
}
// 检查是否手动启用缓存
if value, ok := db.Statement.Get("cache:enabled"); ok && value.(bool) {
return true
}
// 使用全局缓存配置管理器进行智能决策
if GlobalCacheConfigManager != nil {
return GlobalCacheConfigManager.IsTableCacheEnabled(db.Statement.Table)
}
// 如果全局管理器未初始化,使用本地配置
// 检查表是否在禁用列表中
for _, disabledTable := range p.config.DisabledTables {
if disabledTable == db.Statement.Table {
return false
}
}
// 检查表是否在启用列表中(如果配置了启用列表)
if len(p.config.EnabledTables) > 0 {
for _, table := range p.config.EnabledTables {
if table == db.Statement.Table {
return true
}
}
return false
}
// 如果没有配置启用列表,默认对所有表启用缓存(除了禁用列表中的表)
return true
}
// isComplexQuery 判断是否为复杂查询
func (p *GormCachePlugin) isComplexQuery(db *gorm.DB) bool {
sql := db.Statement.SQL.String()
// 检查是否包含复杂操作
complexKeywords := []string{
"JOIN", "UNION", "SUBQUERY", "GROUP BY",
"HAVING", "WINDOW", "RECURSIVE",
}
upperSQL := strings.ToUpper(sql)
for _, keyword := range complexKeywords {
if strings.Contains(upperSQL, keyword) {
return true
}
}
return false
}
// generateCacheKey 生成缓存键
func (p *GormCachePlugin) generateCacheKey(db *gorm.DB) string {
// 构建缓存键的组成部分
keyParts := []string{
p.config.TablePrefix,
db.Statement.Table,
}
// 添加SQL语句hash
sqlHash := p.hashSQL(db.Statement.SQL.String(), db.Statement.Vars)
keyParts = append(keyParts, sqlHash)
return strings.Join(keyParts, ":")
}
// hashSQL 对SQL语句和参数进行hash
func (p *GormCachePlugin) hashSQL(sql string, vars []interface{}) string {
// 修复改进SQL hash生成确保唯一性
combined := sql
for _, v := range vars {
// 使用更精确的格式化,避免类型信息丢失
switch val := v.(type) {
case string:
combined += ":" + val
case int, int32, int64:
combined += fmt.Sprintf(":%d", val)
case float32, float64:
combined += fmt.Sprintf(":%f", val)
case bool:
combined += fmt.Sprintf(":%t", val)
default:
// 对于复杂类型使用JSON序列化
if jsonData, err := json.Marshal(v); err == nil {
combined += ":" + string(jsonData)
} else {
combined += fmt.Sprintf(":%v", v)
}
}
}
// 计算MD5 hash
hasher := md5.New()
hasher.Write([]byte(combined))
return hex.EncodeToString(hasher.Sum(nil))
}
// CachedResult 缓存结果结构
type CachedResult struct {
Data interface{} `json:"data"`
RowCount int64 `json:"row_count"`
Timestamp time.Time `json:"timestamp"`
}
// saveToCache 保存结果到缓存
func (p *GormCachePlugin) saveToCache(ctx context.Context, cacheKey string, db *gorm.DB) error {
// 检查结果大小限制
if db.Statement.RowsAffected > int64(p.config.MaxCacheSize) {
p.logger.Debug("查询结果过大,跳过缓存",
zap.String("cache_key", cacheKey),
zap.Int64("rows", db.Statement.RowsAffected),
)
return nil
}
// 获取查询结果
dest := db.Statement.Dest
if dest == nil {
return fmt.Errorf("查询结果为空")
}
// 修复:改进缓存数据保存逻辑
var dataToCache interface{}
// 如果dest是切片需要特殊处理
destValue := reflect.ValueOf(dest)
if destValue.Kind() == reflect.Ptr && destValue.Elem().Kind() == reflect.Slice {
// 对于切片,直接使用原始数据
dataToCache = destValue.Elem().Interface()
} else {
// 对于单个对象,也直接使用原始数据
dataToCache = dest
}
// 构建缓存结果
result := CachedResult{
Data: dataToCache,
RowCount: db.Statement.RowsAffected,
Timestamp: time.Now(),
}
// 获取TTL
ttl := p.getTTL(db)
// 保存到缓存
if err := p.cache.Set(ctx, cacheKey, result, ttl); err != nil {
p.logger.Error("保存查询结果到缓存失败",
zap.String("cache_key", cacheKey),
zap.Error(err),
)
return fmt.Errorf("保存到缓存失败: %w", err)
}
p.logger.Debug("查询结果已缓存",
zap.String("cache_key", cacheKey),
zap.Int64("rows", db.Statement.RowsAffected),
zap.Duration("ttl", ttl),
)
return nil
}
// restoreFromCache 从缓存恢复结果
func (p *GormCachePlugin) restoreFromCache(db *gorm.DB, cachedResult *CachedResult) error {
if cachedResult.Data == nil {
return fmt.Errorf("缓存数据为空")
}
// 反序列化到目标对象
destValue := reflect.ValueOf(db.Statement.Dest)
if destValue.Kind() != reflect.Ptr || destValue.IsNil() {
return fmt.Errorf("目标对象必须是指针")
}
// 修复:改进缓存数据恢复逻辑
cachedValue := reflect.ValueOf(cachedResult.Data)
// 如果类型完全匹配,直接赋值
if cachedValue.Type().AssignableTo(destValue.Elem().Type()) {
destValue.Elem().Set(cachedValue)
} else {
// 尝试JSON转换
jsonData, err := json.Marshal(cachedResult.Data)
if err != nil {
p.logger.Error("序列化缓存数据失败", zap.Error(err))
return fmt.Errorf("缓存数据类型转换失败: %w", err)
}
if err := json.Unmarshal(jsonData, db.Statement.Dest); err != nil {
p.logger.Error("反序列化缓存数据失败",
zap.String("json_data", string(jsonData)),
zap.Error(err))
return fmt.Errorf("JSON反序列化失败: %w", err)
}
}
// 设置影响行数
db.Statement.RowsAffected = cachedResult.RowCount
p.logger.Debug("从缓存恢复数据成功",
zap.Int64("rows", cachedResult.RowCount),
zap.Time("timestamp", cachedResult.Timestamp))
return nil
}
// getTTL 获取TTL
func (p *GormCachePlugin) getTTL(db *gorm.DB) time.Duration {
// 检查是否设置了自定义TTL
if value, ok := db.Statement.Get("cache:ttl"); ok {
if ttl, ok := value.(time.Duration); ok {
return ttl
}
}
return p.config.DefaultTTL
}
// invalidateTableCache 失效表相关缓存
func (p *GormCachePlugin) invalidateTableCache(ctx context.Context, table string) {
// 使用去重机制,避免重复的缓存失效操作
p.queueMutex.Lock()
defer p.queueMutex.Unlock()
// 如果已经有相同的失效操作在队列中,取消之前的定时器
if timer, exists := p.invalidationQueue[table]; exists {
timer.Stop()
delete(p.invalidationQueue, table)
}
// 创建独立的上下文,避免受到原始请求上下文的影响
// 设置合理的超时时间,避免缓存失效操作阻塞
cacheCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
// 创建新的定时器
timer := time.AfterFunc(p.config.InvalidateDelay, func() {
// 执行缓存失效
p.doInvalidateTableCache(cacheCtx, table)
// 清理定时器引用
p.queueMutex.Lock()
delete(p.invalidationQueue, table)
p.queueMutex.Unlock()
// 取消上下文
cancel()
})
// 将定时器加入队列
p.invalidationQueue[table] = timer
p.logger.Debug("缓存失效操作已加入队列",
zap.String("table", table),
zap.Duration("delay", p.config.InvalidateDelay),
)
}
// doInvalidateTableCache 执行缓存失效
func (p *GormCachePlugin) doInvalidateTableCache(ctx context.Context, table string) {
pattern := fmt.Sprintf("%s:%s:*", p.config.TablePrefix, table)
// 添加重试机制,提高缓存失效的可靠性
maxRetries := 3
for attempt := 1; attempt <= maxRetries; attempt++ {
if err := p.cache.DeletePattern(ctx, pattern); err != nil {
if attempt < maxRetries {
p.logger.Warn("缓存失效失败,准备重试",
zap.String("table", table),
zap.String("pattern", pattern),
zap.Int("attempt", attempt),
zap.Error(err),
)
// 短暂延迟后重试
time.Sleep(time.Duration(attempt) * 100 * time.Millisecond)
continue
}
p.logger.Warn("失效表缓存失败",
zap.String("table", table),
zap.String("pattern", pattern),
zap.Int("attempts", maxRetries),
zap.Error(err),
)
return
}
// 成功删除,记录日志并退出
p.logger.Debug("表缓存已失效",
zap.String("table", table),
zap.String("pattern", pattern),
)
return
}
}
// updateStats 更新统计信息
func (p *GormCachePlugin) updateStats(operation, table string) {
// 这里可以接入Prometheus等监控系统
p.logger.Debug("缓存统计",
zap.String("operation", operation),
zap.String("table", table),
)
}
// ================ 高级功能 ================
// WarmupCache 预热缓存
func (p *GormCachePlugin) WarmupCache(ctx context.Context, db *gorm.DB, queries []string) error {
if !p.config.EnableWarmup {
return fmt.Errorf("缓存预热未启用")
}
for _, query := range queries {
if err := db.Raw(query).Error; err != nil {
p.logger.Warn("缓存预热失败",
zap.String("query", query),
zap.Error(err),
)
}
}
return nil
}
// GetCacheStats 获取缓存统计
func (p *GormCachePlugin) GetCacheStats(ctx context.Context) (map[string]interface{}, error) {
stats, err := p.cache.Stats(ctx)
if err != nil {
return nil, err
}
return map[string]interface{}{
"hits": stats.Hits,
"misses": stats.Misses,
"keys": stats.Keys,
"memory": stats.Memory,
"connections": stats.Connections,
"config": p.config,
}, nil
}
// SetCacheEnabled 设置缓存启用状态
func (p *GormCachePlugin) SetCacheEnabled(db *gorm.DB, enabled bool) *gorm.DB {
return db.Set("cache:enabled", enabled)
}
// SetCacheDisabled 设置缓存禁用状态
func (p *GormCachePlugin) SetCacheDisabled(db *gorm.DB, disabled bool) *gorm.DB {
return db.Set("cache:disabled", disabled)
}
// SetCacheTTL 设置缓存TTL
func (p *GormCachePlugin) SetCacheTTL(db *gorm.DB, ttl time.Duration) *gorm.DB {
return db.Set("cache:ttl", ttl)
}