Files
tyapi-server/internal/shared/database/cached_base_repository.go
2025-07-20 20:53:26 +08:00

367 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package database
import (
	"context"
	"errors"
	"fmt"
	"reflect"
	"sort"
	"strings"
	"time"

	"go.uber.org/zap"
	"gorm.io/gorm"

	"tyapi-server/internal/shared/interfaces"
)
// CachedBaseRepositoryImpl is a base repository with cache support.
// It embeds BaseRepositoryImpl and adds helpers that tag gorm queries with the
// "cache:enabled" / "cache:ttl" / "cache:disabled" settings.
// NOTE(review): those settings are presumably consumed by a gorm cache plugin
// registered elsewhere; nothing in this type reads them itself.
type CachedBaseRepositoryImpl struct {
// Embedded base repository; supplies GetDB, CreateEntity and the db/logger fields used here.
*BaseRepositoryImpl
// tableName is used in log fields and in the key pattern reported by GetCacheInfo.
tableName string
}
// NewCachedBaseRepositoryImpl builds a cache-aware repository for tableName on
// top of a new BaseRepositoryImpl created from db and logger.
func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName string) *CachedBaseRepositoryImpl {
	base := NewBaseRepositoryImpl(db, logger)
	repo := &CachedBaseRepositoryImpl{tableName: tableName}
	repo.BaseRepositoryImpl = base
	return repo
}
// ================ 智能缓存方法 ================
// GetWithCache loads the first record matching where/args into dest.
// The query is tagged with "cache:enabled"=true and "cache:ttl"=ttl so that the
// cache layer (presumably a gorm plugin registered elsewhere) may serve it.
// The gorm error (e.g. gorm.ErrRecordNotFound) is returned unchanged.
func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
	query := r.GetDB(ctx)
	query = query.Set("cache:enabled", true)
	query = query.Set("cache:ttl", ttl)
	query = query.Where(where, args...)
	return query.First(dest).Error
}
// FindWithCache loads every record matching where/args into dest (a slice
// pointer, as gorm Find expects), tagging the query with "cache:enabled" and
// "cache:ttl"=ttl.
func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
	return r.GetDB(ctx).
		Set("cache:enabled", true).
		Set("cache:ttl", ttl).
		Where(where, args...).
		Find(dest).
		Error
}
// CountWithCache writes into count the number of rows of entity's model that
// match where/args. The query is tagged with "cache:enabled" and "cache:ttl"=ttl.
func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error {
	base := r.GetDB(ctx)
	flagged := base.Set("cache:enabled", true).Set("cache:ttl", ttl)
	scoped := flagged.Model(entity).Where(where, args...)
	return scoped.Count(count).Error
}
// ListWithCache runs a cache-tagged list query described by options and stores
// the rows in dest. Each option is applied only when set: Where (with Args),
// Preloads (in slice order), Order, Limit (> 0) and Offset (> 0).
func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error {
	q := r.GetDB(ctx).Set("cache:enabled", true)
	q = q.Set("cache:ttl", ttl)

	if where := options.Where; where != "" {
		q = q.Where(where, options.Args...)
	}
	for i := range options.Preloads {
		q = q.Preload(options.Preloads[i])
	}
	if order := options.Order; order != "" {
		q = q.Order(order)
	}
	if limit := options.Limit; limit > 0 {
		q = q.Limit(limit)
	}
	if offset := options.Offset; offset > 0 {
		q = q.Offset(offset)
	}

	return q.Find(dest).Error
}
// CacheListOptions holds the query options consumed by ListWithCache.
type CacheListOptions struct {
Where string `json:"where"` // raw WHERE clause; applied only when non-empty
Args []interface{} `json:"args"` // bind arguments for Where
Order string `json:"order"` // ORDER BY expression; applied only when non-empty
Limit int `json:"limit"` // applied only when > 0
Offset int `json:"offset"` // applied only when > 0
Preloads []string `json:"preloads"` // association names passed to gorm Preload, in order
}
// ================ 缓存控制方法 ================
// WithCache returns a new repository whose underlying *gorm.DB carries the
// "cache:enabled"=true and "cache:ttl"=ttl settings. The receiver is not modified.
//
// The chained handle is wrapped in a fresh gorm Session before being stored:
// a *gorm.DB returned by a chain method such as Set is not safe to reuse,
// because later chain calls would mutate its shared Statement and leak
// conditions from one query into the next. A Session-wrapped handle clones the
// statement (settings included) on every subsequent chain call.
//
// NOTE(review): only the db and logger fields of BaseRepositoryImpl are carried
// over; confirm BaseRepositoryImpl has no other state that must be copied.
func (r *CachedBaseRepositoryImpl) WithCache(ttl time.Duration) *CachedBaseRepositoryImpl {
	cachedDB := r.db.
		Set("cache:enabled", true).
		Set("cache:ttl", ttl).
		Session(&gorm.Session{})
	return &CachedBaseRepositoryImpl{
		BaseRepositoryImpl: &BaseRepositoryImpl{
			db:     cachedDB,
			logger: r.logger,
		},
		tableName: r.tableName,
	}
}
// WithoutCache returns a new repository whose underlying *gorm.DB carries the
// "cache:disabled"=true setting. The receiver is not modified.
//
// As in WithCache, the chained handle is wrapped in a fresh gorm Session so it
// is safe to reuse across queries; otherwise each chain call made on the stored
// handle would mutate one shared Statement.
//
// NOTE(review): a "cache:enabled" setting already present on r.db (e.g. after
// WithCache) is not cleared; which flag wins depends on the cache plugin.
func (r *CachedBaseRepositoryImpl) WithoutCache() *CachedBaseRepositoryImpl {
	uncachedDB := r.db.
		Set("cache:disabled", true).
		Session(&gorm.Session{})
	return &CachedBaseRepositoryImpl{
		BaseRepositoryImpl: &BaseRepositoryImpl{
			db:     uncachedDB,
			logger: r.logger,
		},
		tableName: r.tableName,
	}
}
// WithShortCache is WithCache with a 5-minute TTL.
func (r *CachedBaseRepositoryImpl) WithShortCache() *CachedBaseRepositoryImpl {
	const shortTTL = 5 * time.Minute
	return r.WithCache(shortTTL)
}

// WithMediumCache is WithCache with a 30-minute TTL.
func (r *CachedBaseRepositoryImpl) WithMediumCache() *CachedBaseRepositoryImpl {
	const mediumTTL = 30 * time.Minute
	return r.WithCache(mediumTTL)
}

// WithLongCache is WithCache with a 2-hour TTL.
func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl {
	const longTTL = 2 * time.Hour
	return r.WithCache(longTTL)
}
// ================ 智能查询方法 ================
// SmartGetByID loads the record whose "id" column equals id into dest, using
// GetWithCache with a 30-minute TTL.
func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error {
	const (
		byIDTTL       = 30 * time.Minute
		byIDCondition = "id = ?"
	)
	return r.GetWithCache(ctx, dest, byIDTTL, byIDCondition, id)
}
// SmartGetByField loads the first record where column field equals value into
// dest, via GetWithCache. The TTL defaults to 15 minutes; when ttl is supplied
// only its first element is used.
//
// SECURITY NOTE(review): field is concatenated into the SQL unescaped and must
// be a trusted column name, never raw user input.
func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest interface{}, field string, value interface{}, ttl ...time.Duration) error {
	var effectiveTTL time.Duration
	switch {
	case len(ttl) > 0:
		effectiveTTL = ttl[0]
	default:
		effectiveTTL = 15 * time.Minute
	}
	condition := field + " = ?"
	return r.GetWithCache(ctx, dest, effectiveTTL, condition, value)
}
// SmartList loads a page of records into dest according to options. It picks
// the cache policy from the query shape via shouldUseCache and calculateCacheTTL.
//
// Behavior:
//   - cacheable queries are tagged "cache:enabled"/"cache:ttl"; all others, and
//     any query with a search term, are tagged "cache:disabled";
//   - each Filters entry becomes "<key> = ?". Keys are applied in sorted order,
//     so the same filter set always yields the same SQL. Go map iteration order
//     is random, which would otherwise give varying SQL (and cache keys) for
//     identical queries;
//   - Search is NOT applied here, only logged; concrete repositories must
//     implement it;
//   - Include entries are preloaded;
//   - results are ordered by Sort, descending when Order is "desc" in any letter
//     case and ascending otherwise; without Sort the order is "created_at DESC";
//   - Page/PageSize are applied only when both are positive (Page is 1-based).
//
// SECURITY NOTE(review): filter keys and options.Sort are concatenated into the
// SQL unescaped; they must come from a trusted allow-list, never directly from
// request parameters.
func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface{}, options interfaces.ListOptions) error {
	cacheTTL := r.calculateCacheTTL(options)
	useCache := r.shouldUseCache(options)

	db := r.GetDB(ctx)
	if useCache {
		db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL)
	} else {
		db = db.Set("cache:disabled", true)
	}

	// Apply filters in deterministic (sorted) key order; ranging over a nil map is a no-op.
	filterKeys := make([]string, 0, len(options.Filters))
	for key := range options.Filters {
		filterKeys = append(filterKeys, key)
	}
	sort.Strings(filterKeys)
	for _, key := range filterKeys {
		db = db.Where(key+" = ?", options.Filters[key])
	}

	if options.Search != "" {
		// Search logic belongs to the concrete repository; search queries are never cached.
		r.logger.Debug("搜索查询默认禁用缓存", zap.String("search", options.Search))
		db = db.Set("cache:disabled", true)
	}

	for _, include := range options.Include {
		db = db.Preload(include)
	}

	if options.Sort != "" {
		direction := "ASC"
		if strings.EqualFold(options.Order, "desc") {
			direction = "DESC"
		}
		db = db.Order(options.Sort + " " + direction)
	} else {
		db = db.Order("created_at DESC")
	}

	if options.Page > 0 && options.PageSize > 0 {
		offset := (options.Page - 1) * options.PageSize
		db = db.Offset(offset).Limit(options.PageSize)
	}

	return db.Find(dest).Error
}
// calculateCacheTTL chooses a cache TTL from the shape of the list query:
//   - 2 minutes when a search term is present;
//   - 5 minutes when there are more than 3 filters;
//   - 30 minutes when there are no filters at all;
//   - 15 minutes otherwise.
func (r *CachedBaseRepositoryImpl) calculateCacheTTL(options interfaces.ListOptions) time.Duration {
	filterCount := len(options.Filters)
	switch {
	case options.Search != "":
		return 2 * time.Minute
	case filterCount > 3:
		return 5 * time.Minute
	case filterCount == 0:
		return 30 * time.Minute
	default:
		return 15 * time.Minute
	}
}
// shouldUseCache reports whether a list query should be cached. It returns
// false when any of these holds: a search term is present, there are more than
// 5 filters, or the requested page is beyond page 10.
func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions) bool {
	hasSearch := options.Search != ""
	tooManyFilters := len(options.Filters) > 5
	deepPage := options.Page > 10
	return !(hasSearch || tooManyFilters || deepPage)
}
// ================ 缓存预热方法 ================
// WarmupCommonQueries runs each query in order with caching enabled, on a
// best-effort basis. A failing query is logged at warn level and does not
// abort the remaining ones.
//
// The context is checked before every query. Once ctx is cancelled or its
// deadline passes, the loop stops and the context error is returned (wrapped).
// Without this check, every remaining query would fail with the same context
// error and each would log its own warning. In all other cases it returns nil.
func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error {
	r.logger.Info("开始预热缓存",
		zap.String("table", r.tableName),
		zap.Int("queries", len(queries)),
	)
	for _, query := range queries {
		if err := ctx.Err(); err != nil {
			return fmt.Errorf("cache warmup for table %s interrupted: %w", r.tableName, err)
		}
		if err := r.executeWarmupQuery(ctx, query); err != nil {
			// Deliberately best-effort: log and continue with the next query.
			r.logger.Warn("缓存预热失败",
				zap.String("query", query.Name),
				zap.Error(err),
			)
		}
	}
	return nil
}
// WarmupQuery describes one raw SQL query executed (with caching enabled) by WarmupCommonQueries.
type WarmupQuery struct {
Name string `json:"name"` // label used in the warning log when the query fails
SQL string `json:"sql"` // raw SQL; an empty string makes the query a no-op
Args []interface{} `json:"args"` // bind arguments for SQL
TTL time.Duration `json:"ttl"` // stored as the "cache:ttl" setting; encoding/json encodes a Duration as integer nanoseconds
Dest interface{} `json:"dest"` // scan target for the results; presumably a pointer (gorm Scan) — confirm with callers
}
// executeWarmupQuery runs a single warmup query tagged with "cache:enabled"=true
// and "cache:ttl"=query.TTL, scanning the rows into query.Dest.
//
// A query with empty SQL is skipped and returns nil. A query with SQL but a nil
// Dest is rejected with an error rather than handed to gorm's Scan, because
// there is nowhere to put the rows.
func (r *CachedBaseRepositoryImpl) executeWarmupQuery(ctx context.Context, query WarmupQuery) error {
	if query.SQL == "" {
		return nil
	}
	if query.Dest == nil {
		return fmt.Errorf("warmup query %q: nil Dest", query.Name)
	}
	db := r.GetDB(ctx).
		Set("cache:enabled", true).
		Set("cache:ttl", query.TTL)
	return db.Raw(query.SQL, query.Args...).Scan(query.Dest).Error
}
// ================ 高级缓存特性 ================
// GetOrCreate loads the record matching where/args into dest. If none exists,
// it creates one via createFn.
//
// Lookup order:
//  1. a cached query (15-minute TTL). On any error, including a failure in the
//     cache layer itself, it falls through to step 2;
//  2. an uncached query ("cache:disabled"), so a stale cached miss cannot
//     trigger a duplicate insert. A not-found result continues to step 3. Any
//     other error (connection failure, bad SQL, ...) is returned as-is and is
//     NOT treated as "missing".
//
// When nothing is found:
//   - if createFn is nil, it returns gorm.ErrRecordNotFound;
//   - if createFn returns nil, it returns an error;
//   - otherwise the entity is persisted with CreateEntity. The created entity is
//     then copied into dest when the types are compatible (dest is *T and
//     createFn returns *T or T). If they are not compatible, dest is reloaded
//     from the database with the same condition.
//
// NOTE(review): the read-then-insert sequence is not atomic; concurrent callers
// may both insert. Rely on a unique constraint if that matters.
func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interface{}, where string, args []interface{}, createFn func() interface{}) error {
	if err := r.GetWithCache(ctx, dest, 15*time.Minute, where, args...); err == nil {
		return nil
	}

	err := r.GetDB(ctx).Set("cache:disabled", true).Where(where, args...).First(dest).Error
	switch {
	case err == nil:
		return nil
	case !errors.Is(err, gorm.ErrRecordNotFound):
		return err
	}

	if createFn == nil {
		return gorm.ErrRecordNotFound
	}
	newEntity := createFn()
	if newEntity == nil {
		return fmt.Errorf("get or create: createFn returned a nil entity")
	}
	if err := r.CreateEntity(ctx, newEntity); err != nil {
		return err
	}

	if assignCreatedEntity(dest, newEntity) {
		return nil
	}
	// Types are not directly assignable: read the freshly created row back.
	return r.GetDB(ctx).Set("cache:disabled", true).Where(where, args...).First(dest).Error
}

// assignCreatedEntity copies src (a T or a non-nil *T) into the value dest points
// to. It reports false and leaves dest untouched when dest is not a non-nil
// pointer or the types are not assignable.
func assignCreatedEntity(dest, src interface{}) bool {
	dv := reflect.ValueOf(dest)
	if dv.Kind() != reflect.Ptr || dv.IsNil() {
		return false
	}
	sv := reflect.ValueOf(src)
	if sv.Kind() == reflect.Ptr {
		if sv.IsNil() {
			return false
		}
		sv = sv.Elem()
	}
	target := dv.Elem()
	if !sv.Type().AssignableTo(target.Type()) {
		return false
	}
	target.Set(sv)
	return true
}
// BatchGetWithCache loads all records whose id is in ids into dest, using
// FindWithCache with the given ttl. An empty ids slice is a no-op that returns
// nil without querying.
func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids []string, dest interface{}, ttl time.Duration) error {
	if len(ids) > 0 {
		return r.FindWithCache(ctx, dest, ttl, "id IN ?", ids)
	}
	return nil
}
// RefreshCache is currently a placeholder. It only logs the table and pattern
// at info level and always returns nil; no cache entries are removed.
// TODO: call the cache service's delete-by-pattern operation once it is wired in.
func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error {
	tableField := zap.String("table", r.tableName)
	patternField := zap.String("pattern", pattern)
	r.logger.Info("刷新缓存", tableField, patternField)
	return nil
}
// ================ 缓存统计方法 ================
// GetCacheInfo returns a descriptive map for this repository. It contains
// table_name, cache_enabled, default_ttl and cache_patterns (a single
// "gorm_cache:<table>:*" pattern). Note that cache_enabled is always true and
// default_ttl is always the fixed string "30m"; neither reflects this
// instance's actual settings.
func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} {
	keyPattern := fmt.Sprintf("gorm_cache:%s:*", r.tableName)

	info := make(map[string]interface{}, 4)
	info["table_name"] = r.tableName
	info["cache_enabled"] = true
	info["default_ttl"] = "30m"
	info["cache_patterns"] = []string{keyPattern}
	return info
}
// LogCacheOperation writes a debug-level log entry describing a cache operation
// on this repository's table.
func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) {
	tableField := zap.String("table", r.tableName)
	operationField := zap.String("operation", operation)
	detailsField := zap.String("details", details)
	r.logger.Debug("缓存操作", tableField, operationField, detailsField)
}