367 lines
9.9 KiB
Go
367 lines
9.9 KiB
Go
package database
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"time"
|
||
|
||
"go.uber.org/zap"
|
||
"gorm.io/gorm"
|
||
|
||
"tyapi-server/internal/shared/interfaces"
|
||
)
|
||
|
||
// CachedBaseRepositoryImpl 支持缓存的基础仓储实现
|
||
// 在BaseRepositoryImpl基础上增加智能缓存管理
|
||
type CachedBaseRepositoryImpl struct {
|
||
*BaseRepositoryImpl
|
||
tableName string
|
||
}
|
||
|
||
// NewCachedBaseRepositoryImpl 创建支持缓存的基础仓储实现
|
||
func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName string) *CachedBaseRepositoryImpl {
|
||
return &CachedBaseRepositoryImpl{
|
||
BaseRepositoryImpl: NewBaseRepositoryImpl(db, logger),
|
||
tableName: tableName,
|
||
}
|
||
}
|
||
|
||
// ================ 智能缓存方法 ================
|
||
|
||
// GetWithCache 带缓存的单条查询
|
||
func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
|
||
db := r.GetDB(ctx).
|
||
Set("cache:enabled", true).
|
||
Set("cache:ttl", ttl)
|
||
|
||
return db.Where(where, args...).First(dest).Error
|
||
}
|
||
|
||
// FindWithCache 带缓存的多条查询
|
||
func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
|
||
db := r.GetDB(ctx).
|
||
Set("cache:enabled", true).
|
||
Set("cache:ttl", ttl)
|
||
|
||
return db.Where(where, args...).Find(dest).Error
|
||
}
|
||
|
||
// CountWithCache 带缓存的计数查询
|
||
func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error {
|
||
db := r.GetDB(ctx).
|
||
Set("cache:enabled", true).
|
||
Set("cache:ttl", ttl).
|
||
Model(entity)
|
||
|
||
return db.Where(where, args...).Count(count).Error
|
||
}
|
||
|
||
// ListWithCache 带缓存的列表查询
|
||
func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error {
|
||
db := r.GetDB(ctx).
|
||
Set("cache:enabled", true).
|
||
Set("cache:ttl", ttl)
|
||
|
||
// 应用where条件
|
||
if options.Where != "" {
|
||
db = db.Where(options.Where, options.Args...)
|
||
}
|
||
|
||
// 应用预加载
|
||
for _, preload := range options.Preloads {
|
||
db = db.Preload(preload)
|
||
}
|
||
|
||
// 应用排序
|
||
if options.Order != "" {
|
||
db = db.Order(options.Order)
|
||
}
|
||
|
||
// 应用分页
|
||
if options.Limit > 0 {
|
||
db = db.Limit(options.Limit)
|
||
}
|
||
if options.Offset > 0 {
|
||
db = db.Offset(options.Offset)
|
||
}
|
||
|
||
return db.Find(dest).Error
|
||
}
|
||
|
||
// CacheListOptions 缓存列表查询选项
|
||
type CacheListOptions struct {
|
||
Where string `json:"where"`
|
||
Args []interface{} `json:"args"`
|
||
Order string `json:"order"`
|
||
Limit int `json:"limit"`
|
||
Offset int `json:"offset"`
|
||
Preloads []string `json:"preloads"`
|
||
}
|
||
|
||
// ================ 缓存控制方法 ================
|
||
|
||
// WithCache 启用缓存
|
||
func (r *CachedBaseRepositoryImpl) WithCache(ttl time.Duration) *CachedBaseRepositoryImpl {
|
||
// 创建新实例避免状态污染
|
||
return &CachedBaseRepositoryImpl{
|
||
BaseRepositoryImpl: &BaseRepositoryImpl{
|
||
db: r.db.Set("cache:enabled", true).Set("cache:ttl", ttl),
|
||
logger: r.logger,
|
||
},
|
||
tableName: r.tableName,
|
||
}
|
||
}
|
||
|
||
// WithoutCache 禁用缓存
|
||
func (r *CachedBaseRepositoryImpl) WithoutCache() *CachedBaseRepositoryImpl {
|
||
return &CachedBaseRepositoryImpl{
|
||
BaseRepositoryImpl: &BaseRepositoryImpl{
|
||
db: r.db.Set("cache:disabled", true),
|
||
logger: r.logger,
|
||
},
|
||
tableName: r.tableName,
|
||
}
|
||
}
|
||
|
||
// WithShortCache 短期缓存(5分钟)
|
||
func (r *CachedBaseRepositoryImpl) WithShortCache() *CachedBaseRepositoryImpl {
|
||
return r.WithCache(5 * time.Minute)
|
||
}
|
||
|
||
// WithMediumCache 中期缓存(30分钟)
|
||
func (r *CachedBaseRepositoryImpl) WithMediumCache() *CachedBaseRepositoryImpl {
|
||
return r.WithCache(30 * time.Minute)
|
||
}
|
||
|
||
// WithLongCache 长期缓存(2小时)
|
||
func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl {
|
||
return r.WithCache(2 * time.Hour)
|
||
}
|
||
|
||
// ================ 智能查询方法 ================
|
||
|
||
// SmartGetByID 智能ID查询(自动缓存)
|
||
func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error {
|
||
return r.GetWithCache(ctx, dest, 30*time.Minute, "id = ?", id)
|
||
}
|
||
|
||
// SmartGetByField 智能字段查询(自动缓存)
|
||
func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest interface{}, field string, value interface{}, ttl ...time.Duration) error {
|
||
cacheTTL := 15 * time.Minute
|
||
if len(ttl) > 0 {
|
||
cacheTTL = ttl[0]
|
||
}
|
||
|
||
return r.GetWithCache(ctx, dest, cacheTTL, field+" = ?", value)
|
||
}
|
||
|
||
// SmartList 智能列表查询(根据查询复杂度自动选择缓存策略)
|
||
func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface{}, options interfaces.ListOptions) error {
|
||
// 根据查询复杂度决定缓存策略
|
||
cacheTTL := r.calculateCacheTTL(options)
|
||
useCache := r.shouldUseCache(options)
|
||
|
||
db := r.GetDB(ctx)
|
||
if useCache {
|
||
db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL)
|
||
} else {
|
||
db = db.Set("cache:disabled", true)
|
||
}
|
||
|
||
// 应用筛选条件
|
||
if options.Filters != nil {
|
||
for key, value := range options.Filters {
|
||
db = db.Where(key+" = ?", value)
|
||
}
|
||
}
|
||
|
||
// 应用搜索条件
|
||
if options.Search != "" {
|
||
// 这里应该由具体Repository实现搜索逻辑
|
||
r.logger.Debug("搜索查询默认禁用缓存", zap.String("search", options.Search))
|
||
db = db.Set("cache:disabled", true)
|
||
}
|
||
|
||
// 应用预加载
|
||
for _, include := range options.Include {
|
||
db = db.Preload(include)
|
||
}
|
||
|
||
// 应用排序
|
||
if options.Sort != "" {
|
||
order := "ASC"
|
||
if options.Order == "desc" || options.Order == "DESC" {
|
||
order = "DESC"
|
||
}
|
||
db = db.Order(options.Sort + " " + order)
|
||
} else {
|
||
db = db.Order("created_at DESC")
|
||
}
|
||
|
||
// 应用分页
|
||
if options.Page > 0 && options.PageSize > 0 {
|
||
offset := (options.Page - 1) * options.PageSize
|
||
db = db.Offset(offset).Limit(options.PageSize)
|
||
}
|
||
|
||
return db.Find(dest).Error
|
||
}
|
||
|
||
// calculateCacheTTL 计算缓存TTL
|
||
func (r *CachedBaseRepositoryImpl) calculateCacheTTL(options interfaces.ListOptions) time.Duration {
|
||
// 基础TTL
|
||
baseTTL := 15 * time.Minute
|
||
|
||
// 如果有搜索,缩短TTL
|
||
if options.Search != "" {
|
||
return 2 * time.Minute
|
||
}
|
||
|
||
// 如果有复杂筛选,缩短TTL
|
||
if len(options.Filters) > 3 {
|
||
return 5 * time.Minute
|
||
}
|
||
|
||
// 如果是简单查询,延长TTL
|
||
if len(options.Filters) == 0 && options.Search == "" {
|
||
return 30 * time.Minute
|
||
}
|
||
|
||
return baseTTL
|
||
}
|
||
|
||
// shouldUseCache 判断是否应该使用缓存
|
||
func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions) bool {
|
||
// 如果有搜索,不使用缓存(搜索结果变化频繁)
|
||
if options.Search != "" {
|
||
return false
|
||
}
|
||
|
||
// 如果筛选条件过多,不使用缓存
|
||
if len(options.Filters) > 5 {
|
||
return false
|
||
}
|
||
|
||
// 如果分页页数过大,不使用缓存
|
||
if options.Page > 10 {
|
||
return false
|
||
}
|
||
|
||
return true
|
||
}
|
||
|
||
// ================ 缓存预热方法 ================
|
||
|
||
// WarmupCommonQueries 预热常用查询
|
||
func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error {
|
||
r.logger.Info("开始预热缓存",
|
||
zap.String("table", r.tableName),
|
||
zap.Int("queries", len(queries)),
|
||
)
|
||
|
||
for _, query := range queries {
|
||
if err := r.executeWarmupQuery(ctx, query); err != nil {
|
||
r.logger.Warn("缓存预热失败",
|
||
zap.String("query", query.Name),
|
||
zap.Error(err),
|
||
)
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// WarmupQuery 预热查询定义
|
||
type WarmupQuery struct {
|
||
Name string `json:"name"`
|
||
SQL string `json:"sql"`
|
||
Args []interface{} `json:"args"`
|
||
TTL time.Duration `json:"ttl"`
|
||
Dest interface{} `json:"dest"`
|
||
}
|
||
|
||
// executeWarmupQuery 执行预热查询
|
||
func (r *CachedBaseRepositoryImpl) executeWarmupQuery(ctx context.Context, query WarmupQuery) error {
|
||
db := r.GetDB(ctx).
|
||
Set("cache:enabled", true).
|
||
Set("cache:ttl", query.TTL)
|
||
|
||
if query.SQL != "" {
|
||
return db.Raw(query.SQL, query.Args...).Scan(query.Dest).Error
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ================ 高级缓存特性 ================
|
||
|
||
// GetOrCreate 获取或创建(带缓存)
|
||
func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interface{}, where string, args []interface{}, createFn func() interface{}) error {
|
||
// 先尝试从缓存获取
|
||
if err := r.GetWithCache(ctx, dest, 15*time.Minute, where, args...); err == nil {
|
||
return nil
|
||
}
|
||
|
||
// 缓存未命中,尝试从数据库获取
|
||
if err := r.GetDB(ctx).Where(where, args...).First(dest).Error; err == nil {
|
||
return nil
|
||
}
|
||
|
||
// 数据库也没有,创建新记录
|
||
if createFn != nil {
|
||
newEntity := createFn()
|
||
if err := r.CreateEntity(ctx, newEntity); err != nil {
|
||
return err
|
||
}
|
||
|
||
// 将新创建的实体复制到dest
|
||
// 这里需要反射或其他方式复制
|
||
return nil
|
||
}
|
||
|
||
return gorm.ErrRecordNotFound
|
||
}
|
||
|
||
// BatchGetWithCache 批量获取(带缓存)
|
||
func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids []string, dest interface{}, ttl time.Duration) error {
|
||
if len(ids) == 0 {
|
||
return nil
|
||
}
|
||
|
||
return r.FindWithCache(ctx, dest, ttl, "id IN ?", ids)
|
||
}
|
||
|
||
// RefreshCache 刷新缓存
|
||
func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error {
|
||
r.logger.Info("刷新缓存",
|
||
zap.String("table", r.tableName),
|
||
zap.String("pattern", pattern),
|
||
)
|
||
|
||
// 这里需要调用缓存服务的删除模式方法
|
||
// 具体实现取决于你的CacheService接口
|
||
return nil
|
||
}
|
||
|
||
// ================ 缓存统计方法 ================
|
||
|
||
// GetCacheInfo 获取缓存信息
|
||
func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} {
|
||
return map[string]interface{}{
|
||
"table_name": r.tableName,
|
||
"cache_enabled": true,
|
||
"default_ttl": "30m",
|
||
"cache_patterns": []string{
|
||
fmt.Sprintf("gorm_cache:%s:*", r.tableName),
|
||
},
|
||
}
|
||
}
|
||
|
||
// LogCacheOperation 记录缓存操作
|
||
func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) {
|
||
r.logger.Debug("缓存操作",
|
||
zap.String("table", r.tableName),
|
||
zap.String("operation", operation),
|
||
zap.String("details", details),
|
||
)
|
||
} |