package database import ( "context" "fmt" "time" "go.uber.org/zap" "gorm.io/gorm" "tyapi-server/internal/shared/interfaces" ) // CachedBaseRepositoryImpl 支持缓存的基础仓储实现 // 在BaseRepositoryImpl基础上增加智能缓存管理 type CachedBaseRepositoryImpl struct { *BaseRepositoryImpl tableName string } // NewCachedBaseRepositoryImpl 创建支持缓存的基础仓储实现 func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName string) *CachedBaseRepositoryImpl { return &CachedBaseRepositoryImpl{ BaseRepositoryImpl: NewBaseRepositoryImpl(db, logger), tableName: tableName, } } // ================ 智能缓存方法 ================ // GetWithCache 带缓存的单条查询 func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error { db := r.GetDB(ctx). Set("cache:enabled", true). Set("cache:ttl", ttl) return db.Where(where, args...).First(dest).Error } // FindWithCache 带缓存的多条查询 func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error { db := r.GetDB(ctx). Set("cache:enabled", true). Set("cache:ttl", ttl) return db.Where(where, args...).Find(dest).Error } // CountWithCache 带缓存的计数查询 func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error { db := r.GetDB(ctx). Set("cache:enabled", true). Set("cache:ttl", ttl). Model(entity) return db.Where(where, args...).Count(count).Error } // ListWithCache 带缓存的列表查询 func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error { db := r.GetDB(ctx). Set("cache:enabled", true). Set("cache:ttl", ttl) // 应用where条件 if options.Where != "" { db = db.Where(options.Where, options.Args...) } // 应用预加载 for _, preload := range options.Preloads { db = db.Preload(preload) } // 应用排序 if options.Order != "" { db = db.Order(options.Order) } // 应用分页 if options.Limit > 0 { db = db.Limit(options.Limit) } if options.Offset > 0 { db = db.Offset(options.Offset) } return db.Find(dest).Error } // CacheListOptions 缓存列表查询选项 type CacheListOptions struct { Where string `json:"where"` Args []interface{} `json:"args"` Order string `json:"order"` Limit int `json:"limit"` Offset int `json:"offset"` Preloads []string `json:"preloads"` } // ================ 缓存控制方法 ================ // WithCache 启用缓存 func (r *CachedBaseRepositoryImpl) WithCache(ttl time.Duration) *CachedBaseRepositoryImpl { // 创建新实例避免状态污染 return &CachedBaseRepositoryImpl{ BaseRepositoryImpl: &BaseRepositoryImpl{ db: r.db.Set("cache:enabled", true).Set("cache:ttl", ttl), logger: r.logger, }, tableName: r.tableName, } } // WithoutCache 禁用缓存 func (r *CachedBaseRepositoryImpl) WithoutCache() *CachedBaseRepositoryImpl { return &CachedBaseRepositoryImpl{ BaseRepositoryImpl: &BaseRepositoryImpl{ db: r.db.Set("cache:disabled", true), logger: r.logger, }, tableName: r.tableName, } } // WithShortCache 短期缓存(5分钟) func (r *CachedBaseRepositoryImpl) WithShortCache() *CachedBaseRepositoryImpl { return r.WithCache(5 * time.Minute) } // WithMediumCache 中期缓存(30分钟) func (r *CachedBaseRepositoryImpl) WithMediumCache() *CachedBaseRepositoryImpl { return r.WithCache(30 * time.Minute) } // WithLongCache 长期缓存(2小时) func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl { return r.WithCache(2 * time.Hour) } // ================ 智能查询方法 ================ // SmartGetByID 智能ID查询(自动缓存) func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error { return r.GetWithCache(ctx, dest, 30*time.Minute, "id = ?", id) } // SmartGetByField 智能字段查询(自动缓存) func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest interface{}, field string, value interface{}, ttl ...time.Duration) error { cacheTTL := 15 * time.Minute if len(ttl) > 0 { cacheTTL = ttl[0] } return r.GetWithCache(ctx, dest, cacheTTL, field+" = ?", value) } // SmartList 智能列表查询(根据查询复杂度自动选择缓存策略) func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface{}, options interfaces.ListOptions) error { // 根据查询复杂度决定缓存策略 cacheTTL := r.calculateCacheTTL(options) useCache := r.shouldUseCache(options) db := r.GetDB(ctx) if useCache { db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL) } else { db = db.Set("cache:disabled", true) } // 应用筛选条件 if options.Filters != nil { for key, value := range options.Filters { db = db.Where(key+" = ?", value) } } // 应用搜索条件 if options.Search != "" { // 这里应该由具体Repository实现搜索逻辑 r.logger.Debug("搜索查询默认禁用缓存", zap.String("search", options.Search)) db = db.Set("cache:disabled", true) } // 应用预加载 for _, include := range options.Include { db = db.Preload(include) } // 应用排序 if options.Sort != "" { order := "ASC" if options.Order == "desc" || options.Order == "DESC" { order = "DESC" } db = db.Order(options.Sort + " " + order) } else { db = db.Order("created_at DESC") } // 应用分页 if options.Page > 0 && options.PageSize > 0 { offset := (options.Page - 1) * options.PageSize db = db.Offset(offset).Limit(options.PageSize) } return db.Find(dest).Error } // calculateCacheTTL 计算缓存TTL func (r *CachedBaseRepositoryImpl) calculateCacheTTL(options interfaces.ListOptions) time.Duration { // 基础TTL baseTTL := 15 * time.Minute // 如果有搜索,缩短TTL if options.Search != "" { return 2 * time.Minute } // 如果有复杂筛选,缩短TTL if len(options.Filters) > 3 { return 5 * time.Minute } // 如果是简单查询,延长TTL if len(options.Filters) == 0 && options.Search == "" { return 30 * time.Minute } return baseTTL } // shouldUseCache 判断是否应该使用缓存 func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions) bool { // 如果有搜索,不使用缓存(搜索结果变化频繁) if options.Search != "" { return false } // 如果筛选条件过多,不使用缓存 if len(options.Filters) > 5 { return false } // 如果分页页数过大,不使用缓存 if options.Page > 10 { return false } return true } // ================ 缓存预热方法 ================ // WarmupCommonQueries 预热常用查询 func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error { r.logger.Info("开始预热缓存", zap.String("table", r.tableName), zap.Int("queries", len(queries)), ) for _, query := range queries { if err := r.executeWarmupQuery(ctx, query); err != nil { r.logger.Warn("缓存预热失败", zap.String("query", query.Name), zap.Error(err), ) } } return nil } // WarmupQuery 预热查询定义 type WarmupQuery struct { Name string `json:"name"` SQL string `json:"sql"` Args []interface{} `json:"args"` TTL time.Duration `json:"ttl"` Dest interface{} `json:"dest"` } // executeWarmupQuery 执行预热查询 func (r *CachedBaseRepositoryImpl) executeWarmupQuery(ctx context.Context, query WarmupQuery) error { db := r.GetDB(ctx). Set("cache:enabled", true). Set("cache:ttl", query.TTL) if query.SQL != "" { return db.Raw(query.SQL, query.Args...).Scan(query.Dest).Error } return nil } // ================ 高级缓存特性 ================ // GetOrCreate 获取或创建(带缓存) func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interface{}, where string, args []interface{}, createFn func() interface{}) error { // 先尝试从缓存获取 if err := r.GetWithCache(ctx, dest, 15*time.Minute, where, args...); err == nil { return nil } // 缓存未命中,尝试从数据库获取 if err := r.GetDB(ctx).Where(where, args...).First(dest).Error; err == nil { return nil } // 数据库也没有,创建新记录 if createFn != nil { newEntity := createFn() if err := r.CreateEntity(ctx, newEntity); err != nil { return err } // 将新创建的实体复制到dest // 这里需要反射或其他方式复制 return nil } return gorm.ErrRecordNotFound } // BatchGetWithCache 批量获取(带缓存) func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids []string, dest interface{}, ttl time.Duration) error { if len(ids) == 0 { return nil } return r.FindWithCache(ctx, dest, ttl, "id IN ?", ids) } // RefreshCache 刷新缓存 func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error { r.logger.Info("刷新缓存", zap.String("table", r.tableName), zap.String("pattern", pattern), ) // 这里需要调用缓存服务的删除模式方法 // 具体实现取决于你的CacheService接口 return nil } // ================ 缓存统计方法 ================ // GetCacheInfo 获取缓存信息 func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} { return map[string]interface{}{ "table_name": r.tableName, "cache_enabled": true, "default_ttl": "30m", "cache_patterns": []string{ fmt.Sprintf("gorm_cache:%s:*", r.tableName), }, } } // LogCacheOperation 记录缓存操作 func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) { r.logger.Debug("缓存操作", zap.String("table", r.tableName), zap.String("operation", operation), zap.String("details", details), ) }