temp
This commit is contained in:
		
							
								
								
									
										246
									
								
								internal/shared/database/base_repository.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										246
									
								
								internal/shared/database/base_repository.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,246 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
|  | ||||
| 	"tyapi-server/internal/shared/interfaces" | ||||
|  | ||||
| 	"go.uber.org/zap" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // BaseRepositoryImpl 基础仓储实现 | ||||
| // 提供统一的数据库连接、事务处理和通用辅助方法 | ||||
| type BaseRepositoryImpl struct { | ||||
| 	db     *gorm.DB | ||||
| 	logger *zap.Logger | ||||
| } | ||||
|  | ||||
| // NewBaseRepositoryImpl 创建基础仓储实现 | ||||
| func NewBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger) *BaseRepositoryImpl { | ||||
| 	return &BaseRepositoryImpl{ | ||||
| 		db:     db, | ||||
| 		logger: logger, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ================ 核心工具方法 ================ | ||||
|  | ||||
| // GetDB 获取数据库连接,优先使用事务 | ||||
| // 这是Repository层统一的数据库连接获取方法 | ||||
| func (r *BaseRepositoryImpl) GetDB(ctx context.Context) *gorm.DB { | ||||
| 	if tx, ok := GetTx(ctx); ok { | ||||
| 		return tx.WithContext(ctx) | ||||
| 	} | ||||
| 	return r.db.WithContext(ctx) | ||||
| } | ||||
|  | ||||
| // GetLogger 获取日志记录器 | ||||
| func (r *BaseRepositoryImpl) GetLogger() *zap.Logger { | ||||
| 	return r.logger | ||||
| } | ||||
|  | ||||
| // WithTx 使用事务创建新的Repository实例 | ||||
| func (r *BaseRepositoryImpl) WithTx(tx *gorm.DB) *BaseRepositoryImpl { | ||||
| 	return &BaseRepositoryImpl{ | ||||
| 		db:     tx, | ||||
| 		logger: r.logger, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ExecuteInTransaction 在事务中执行函数 | ||||
| func (r *BaseRepositoryImpl) ExecuteInTransaction(ctx context.Context, fn func(*gorm.DB) error) error { | ||||
| 	db := r.GetDB(ctx) | ||||
|  | ||||
| 	// 如果已经在事务中,直接执行 | ||||
| 	if _, ok := GetTx(ctx); ok { | ||||
| 		return fn(db) | ||||
| 	} | ||||
|  | ||||
| 	// 否则开启新事务 | ||||
| 	return db.Transaction(fn) | ||||
| } | ||||
|  | ||||
| // IsInTransaction 检查当前是否在事务中 | ||||
| func (r *BaseRepositoryImpl) IsInTransaction(ctx context.Context) bool { | ||||
| 	_, ok := GetTx(ctx) | ||||
| 	return ok | ||||
| } | ||||
|  | ||||
| // ================ 通用查询辅助方法 ================ | ||||
|  | ||||
| // FindWhere 根据条件查找实体列表 | ||||
| func (r *BaseRepositoryImpl) FindWhere(ctx context.Context, entities interface{}, condition string, args ...interface{}) error { | ||||
| 	return r.GetDB(ctx).Where(condition, args...).Find(entities).Error | ||||
| } | ||||
|  | ||||
| // FindOne 根据条件查找单个实体 | ||||
| func (r *BaseRepositoryImpl) FindOne(ctx context.Context, entity interface{}, condition string, args ...interface{}) error { | ||||
| 	return r.GetDB(ctx).Where(condition, args...).First(entity).Error | ||||
| } | ||||
|  | ||||
| // CountWhere 根据条件统计数量 | ||||
| func (r *BaseRepositoryImpl) CountWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (int64, error) { | ||||
| 	var count int64 | ||||
| 	err := r.GetDB(ctx).Model(entity).Where(condition, args...).Count(&count).Error | ||||
| 	return count, err | ||||
| } | ||||
|  | ||||
| // ExistsWhere 根据条件检查是否存在 | ||||
| func (r *BaseRepositoryImpl) ExistsWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (bool, error) { | ||||
| 	count, err := r.CountWhere(ctx, entity, condition, args...) | ||||
| 	return count > 0, err | ||||
| } | ||||
|  | ||||
| // ================ CRUD辅助方法 ================ | ||||
|  | ||||
| // CreateEntity 创建实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) CreateEntity(ctx context.Context, entity interface{}) error { | ||||
| 	return r.GetDB(ctx).Create(entity).Error | ||||
| } | ||||
|  | ||||
| // GetEntityByID 根据ID获取实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) GetEntityByID(ctx context.Context, id string, entity interface{}) error { | ||||
| 	return r.GetDB(ctx).Where("id = ?", id).First(entity).Error | ||||
| } | ||||
|  | ||||
| // UpdateEntity 更新实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) UpdateEntity(ctx context.Context, entity interface{}) error { | ||||
| 	return r.GetDB(ctx).Save(entity).Error | ||||
| } | ||||
|  | ||||
| // DeleteEntity 删除实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) DeleteEntity(ctx context.Context, id string, entity interface{}) error { | ||||
| 	return r.GetDB(ctx).Delete(entity, "id = ?", id).Error | ||||
| } | ||||
|  | ||||
| // ExistsEntity 检查实体是否存在(辅助方法) | ||||
| func (r *BaseRepositoryImpl) ExistsEntity(ctx context.Context, id string, entity interface{}) (bool, error) { | ||||
| 	var count int64 | ||||
| 	err := r.GetDB(ctx).Model(entity).Where("id = ?", id).Count(&count).Error | ||||
| 	return count > 0, err | ||||
| } | ||||
|  | ||||
| // ================ 批量操作辅助方法 ================ | ||||
|  | ||||
| // CreateBatchEntity 批量创建实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) CreateBatchEntity(ctx context.Context, entities interface{}) error { | ||||
| 	return r.GetDB(ctx).Create(entities).Error | ||||
| } | ||||
|  | ||||
| // GetEntitiesByIDs 根据ID列表获取实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) GetEntitiesByIDs(ctx context.Context, ids []string, entities interface{}) error { | ||||
| 	return r.GetDB(ctx).Where("id IN ?", ids).Find(entities).Error | ||||
| } | ||||
|  | ||||
| // UpdateBatchEntity 批量更新实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) UpdateBatchEntity(ctx context.Context, entities interface{}) error { | ||||
| 	return r.GetDB(ctx).Save(entities).Error | ||||
| } | ||||
|  | ||||
| // DeleteBatchEntity 批量删除实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) DeleteBatchEntity(ctx context.Context, ids []string, entity interface{}) error { | ||||
| 	return r.GetDB(ctx).Delete(entity, "id IN ?", ids).Error | ||||
| } | ||||
|  | ||||
| // ================ 软删除辅助方法 ================ | ||||
|  | ||||
| // SoftDeleteEntity 软删除实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) SoftDeleteEntity(ctx context.Context, id string, entity interface{}) error { | ||||
| 	return r.GetDB(ctx).Delete(entity, "id = ?", id).Error | ||||
| } | ||||
|  | ||||
| // RestoreEntity 恢复软删除的实体(辅助方法) | ||||
| func (r *BaseRepositoryImpl) RestoreEntity(ctx context.Context, id string, entity interface{}) error { | ||||
| 	return r.GetDB(ctx).Unscoped().Model(entity).Where("id = ?", id).Update("deleted_at", nil).Error | ||||
| } | ||||
|  | ||||
| // ================ 高级查询辅助方法 ================ | ||||
|  | ||||
| // ListWithOptions 获取实体列表(支持ListOptions,辅助方法) | ||||
| func (r *BaseRepositoryImpl) ListWithOptions(ctx context.Context, entity interface{}, entities interface{}, options interfaces.ListOptions) error { | ||||
| 	query := r.GetDB(ctx).Model(entity) | ||||
|  | ||||
| 	// 应用筛选条件 | ||||
| 	if options.Filters != nil { | ||||
| 		for key, value := range options.Filters { | ||||
| 			query = query.Where(key+" = ?", value) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 应用搜索条件(基础实现,具体Repository应该重写) | ||||
| 	if options.Search != "" { | ||||
| 		query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%") | ||||
| 	} | ||||
|  | ||||
| 	// 应用预加载 | ||||
| 	for _, include := range options.Include { | ||||
| 		query = query.Preload(include) | ||||
| 	} | ||||
|  | ||||
| 	// 应用排序 | ||||
| 	if options.Sort != "" { | ||||
| 		order := "ASC" | ||||
| 		if options.Order == "desc" || options.Order == "DESC" { | ||||
| 			order = "DESC" | ||||
| 		} | ||||
| 		query = query.Order(options.Sort + " " + order) | ||||
| 	} else { | ||||
| 		// 默认按创建时间倒序 | ||||
| 		query = query.Order("created_at DESC") | ||||
| 	} | ||||
|  | ||||
| 	// 应用分页 | ||||
| 	if options.Page > 0 && options.PageSize > 0 { | ||||
| 		offset := (options.Page - 1) * options.PageSize | ||||
| 		query = query.Offset(offset).Limit(options.PageSize) | ||||
| 	} | ||||
|  | ||||
| 	return query.Find(entities).Error | ||||
| } | ||||
|  | ||||
| // CountWithOptions 统计实体数量(支持CountOptions,辅助方法) | ||||
| func (r *BaseRepositoryImpl) CountWithOptions(ctx context.Context, entity interface{}, options interfaces.CountOptions) (int64, error) { | ||||
| 	var count int64 | ||||
| 	query := r.GetDB(ctx).Model(entity) | ||||
|  | ||||
| 	// 应用筛选条件 | ||||
| 	if options.Filters != nil { | ||||
| 		for key, value := range options.Filters { | ||||
| 			query = query.Where(key+" = ?", value) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 应用搜索条件(基础实现,具体Repository应该重写) | ||||
| 	if options.Search != "" { | ||||
| 		query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%") | ||||
| 	} | ||||
|  | ||||
| 	err := query.Count(&count).Error | ||||
| 	return count, err | ||||
| } | ||||
|  | ||||
| // ================ 常用查询模式 ================ | ||||
|  | ||||
| // FindByField 根据单个字段查找实体列表 | ||||
| func (r *BaseRepositoryImpl) FindByField(ctx context.Context, entities interface{}, field string, value interface{}) error { | ||||
| 	return r.GetDB(ctx).Where(field+" = ?", value).Find(entities).Error | ||||
| } | ||||
|  | ||||
| // FindOneByField 根据单个字段查找单个实体 | ||||
| func (r *BaseRepositoryImpl) FindOneByField(ctx context.Context, entity interface{}, field string, value interface{}) error { | ||||
| 	return r.GetDB(ctx).Where(field+" = ?", value).First(entity).Error | ||||
| } | ||||
|  | ||||
| // CountByField 根据单个字段统计数量 | ||||
| func (r *BaseRepositoryImpl) CountByField(ctx context.Context, entity interface{}, field string, value interface{}) (int64, error) { | ||||
| 	var count int64 | ||||
| 	err := r.GetDB(ctx).Model(entity).Where(field+" = ?", value).Count(&count).Error | ||||
| 	return count, err | ||||
| } | ||||
|  | ||||
| // ExistsByField 根据单个字段检查是否存在 | ||||
| func (r *BaseRepositoryImpl) ExistsByField(ctx context.Context, entity interface{}, field string, value interface{}) (bool, error) { | ||||
| 	count, err := r.CountByField(ctx, entity, field, value) | ||||
| 	return count > 0, err | ||||
| } | ||||
							
								
								
									
										367
									
								
								internal/shared/database/cached_base_repository.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										367
									
								
								internal/shared/database/cached_base_repository.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,367 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.uber.org/zap" | ||||
| 	"gorm.io/gorm" | ||||
|  | ||||
| 	"tyapi-server/internal/shared/interfaces" | ||||
| ) | ||||
|  | ||||
| // CachedBaseRepositoryImpl 支持缓存的基础仓储实现 | ||||
| // 在BaseRepositoryImpl基础上增加智能缓存管理 | ||||
| type CachedBaseRepositoryImpl struct { | ||||
| 	*BaseRepositoryImpl | ||||
| 	tableName string | ||||
| } | ||||
|  | ||||
| // NewCachedBaseRepositoryImpl 创建支持缓存的基础仓储实现 | ||||
| func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName string) *CachedBaseRepositoryImpl { | ||||
| 	return &CachedBaseRepositoryImpl{ | ||||
| 		BaseRepositoryImpl: NewBaseRepositoryImpl(db, logger), | ||||
| 		tableName:          tableName, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ================ 智能缓存方法 ================ | ||||
|  | ||||
| // GetWithCache 带缓存的单条查询 | ||||
| func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl) | ||||
| 	 | ||||
| 	return db.Where(where, args...).First(dest).Error | ||||
| } | ||||
|  | ||||
| // FindWithCache 带缓存的多条查询 | ||||
| func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl) | ||||
| 	 | ||||
| 	return db.Where(where, args...).Find(dest).Error | ||||
| } | ||||
|  | ||||
| // CountWithCache 带缓存的计数查询 | ||||
| func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl). | ||||
| 		Model(entity) | ||||
| 	 | ||||
| 	return db.Where(where, args...).Count(count).Error | ||||
| } | ||||
|  | ||||
| // ListWithCache 带缓存的列表查询 | ||||
| func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", ttl) | ||||
|  | ||||
| 	// 应用where条件 | ||||
| 	if options.Where != "" { | ||||
| 		db = db.Where(options.Where, options.Args...) | ||||
| 	} | ||||
|  | ||||
| 	// 应用预加载 | ||||
| 	for _, preload := range options.Preloads { | ||||
| 		db = db.Preload(preload) | ||||
| 	} | ||||
|  | ||||
| 	// 应用排序 | ||||
| 	if options.Order != "" { | ||||
| 		db = db.Order(options.Order) | ||||
| 	} | ||||
|  | ||||
| 	// 应用分页 | ||||
| 	if options.Limit > 0 { | ||||
| 		db = db.Limit(options.Limit) | ||||
| 	} | ||||
| 	if options.Offset > 0 { | ||||
| 		db = db.Offset(options.Offset) | ||||
| 	} | ||||
|  | ||||
| 	return db.Find(dest).Error | ||||
| } | ||||
|  | ||||
| // CacheListOptions 缓存列表查询选项 | ||||
| type CacheListOptions struct { | ||||
| 	Where     string        `json:"where"` | ||||
| 	Args      []interface{} `json:"args"` | ||||
| 	Order     string        `json:"order"` | ||||
| 	Limit     int           `json:"limit"` | ||||
| 	Offset    int           `json:"offset"` | ||||
| 	Preloads  []string      `json:"preloads"` | ||||
| } | ||||
|  | ||||
| // ================ 缓存控制方法 ================ | ||||
|  | ||||
| // WithCache 启用缓存 | ||||
| func (r *CachedBaseRepositoryImpl) WithCache(ttl time.Duration) *CachedBaseRepositoryImpl { | ||||
| 	// 创建新实例避免状态污染 | ||||
| 	return &CachedBaseRepositoryImpl{ | ||||
| 		BaseRepositoryImpl: &BaseRepositoryImpl{ | ||||
| 			db:     r.db.Set("cache:enabled", true).Set("cache:ttl", ttl), | ||||
| 			logger: r.logger, | ||||
| 		}, | ||||
| 		tableName: r.tableName, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithoutCache 禁用缓存 | ||||
| func (r *CachedBaseRepositoryImpl) WithoutCache() *CachedBaseRepositoryImpl { | ||||
| 	return &CachedBaseRepositoryImpl{ | ||||
| 		BaseRepositoryImpl: &BaseRepositoryImpl{ | ||||
| 			db:     r.db.Set("cache:disabled", true), | ||||
| 			logger: r.logger, | ||||
| 		}, | ||||
| 		tableName: r.tableName, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // WithShortCache 短期缓存(5分钟) | ||||
| func (r *CachedBaseRepositoryImpl) WithShortCache() *CachedBaseRepositoryImpl { | ||||
| 	return r.WithCache(5 * time.Minute) | ||||
| } | ||||
|  | ||||
| // WithMediumCache 中期缓存(30分钟) | ||||
| func (r *CachedBaseRepositoryImpl) WithMediumCache() *CachedBaseRepositoryImpl { | ||||
| 	return r.WithCache(30 * time.Minute) | ||||
| } | ||||
|  | ||||
| // WithLongCache 长期缓存(2小时) | ||||
| func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl { | ||||
| 	return r.WithCache(2 * time.Hour) | ||||
| } | ||||
|  | ||||
| // ================ 智能查询方法 ================ | ||||
|  | ||||
| // SmartGetByID 智能ID查询(自动缓存) | ||||
| func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error { | ||||
| 	return r.GetWithCache(ctx, dest, 30*time.Minute, "id = ?", id) | ||||
| } | ||||
|  | ||||
| // SmartGetByField 智能字段查询(自动缓存) | ||||
| func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest interface{}, field string, value interface{}, ttl ...time.Duration) error { | ||||
| 	cacheTTL := 15 * time.Minute | ||||
| 	if len(ttl) > 0 { | ||||
| 		cacheTTL = ttl[0] | ||||
| 	} | ||||
| 	 | ||||
| 	return r.GetWithCache(ctx, dest, cacheTTL, field+" = ?", value) | ||||
| } | ||||
|  | ||||
| // SmartList 智能列表查询(根据查询复杂度自动选择缓存策略) | ||||
| func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface{}, options interfaces.ListOptions) error { | ||||
| 	// 根据查询复杂度决定缓存策略 | ||||
| 	cacheTTL := r.calculateCacheTTL(options) | ||||
| 	useCache := r.shouldUseCache(options) | ||||
|  | ||||
| 	db := r.GetDB(ctx) | ||||
| 	if useCache { | ||||
| 		db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL) | ||||
| 	} else { | ||||
| 		db = db.Set("cache:disabled", true) | ||||
| 	} | ||||
|  | ||||
| 	// 应用筛选条件 | ||||
| 	if options.Filters != nil { | ||||
| 		for key, value := range options.Filters { | ||||
| 			db = db.Where(key+" = ?", value) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 应用搜索条件 | ||||
| 	if options.Search != "" { | ||||
| 		// 这里应该由具体Repository实现搜索逻辑 | ||||
| 		r.logger.Debug("搜索查询默认禁用缓存", zap.String("search", options.Search)) | ||||
| 		db = db.Set("cache:disabled", true) | ||||
| 	} | ||||
|  | ||||
| 	// 应用预加载 | ||||
| 	for _, include := range options.Include { | ||||
| 		db = db.Preload(include) | ||||
| 	} | ||||
|  | ||||
| 	// 应用排序 | ||||
| 	if options.Sort != "" { | ||||
| 		order := "ASC" | ||||
| 		if options.Order == "desc" || options.Order == "DESC" { | ||||
| 			order = "DESC" | ||||
| 		} | ||||
| 		db = db.Order(options.Sort + " " + order) | ||||
| 	} else { | ||||
| 		db = db.Order("created_at DESC") | ||||
| 	} | ||||
|  | ||||
| 	// 应用分页 | ||||
| 	if options.Page > 0 && options.PageSize > 0 { | ||||
| 		offset := (options.Page - 1) * options.PageSize | ||||
| 		db = db.Offset(offset).Limit(options.PageSize) | ||||
| 	} | ||||
|  | ||||
| 	return db.Find(dest).Error | ||||
| } | ||||
|  | ||||
| // calculateCacheTTL 计算缓存TTL | ||||
| func (r *CachedBaseRepositoryImpl) calculateCacheTTL(options interfaces.ListOptions) time.Duration { | ||||
| 	// 基础TTL | ||||
| 	baseTTL := 15 * time.Minute | ||||
|  | ||||
| 	// 如果有搜索,缩短TTL | ||||
| 	if options.Search != "" { | ||||
| 		return 2 * time.Minute | ||||
| 	} | ||||
|  | ||||
| 	// 如果有复杂筛选,缩短TTL | ||||
| 	if len(options.Filters) > 3 { | ||||
| 		return 5 * time.Minute | ||||
| 	} | ||||
|  | ||||
| 	// 如果是简单查询,延长TTL | ||||
| 	if len(options.Filters) == 0 && options.Search == "" { | ||||
| 		return 30 * time.Minute | ||||
| 	} | ||||
|  | ||||
| 	return baseTTL | ||||
| } | ||||
|  | ||||
| // shouldUseCache 判断是否应该使用缓存 | ||||
| func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions) bool { | ||||
| 	// 如果有搜索,不使用缓存(搜索结果变化频繁) | ||||
| 	if options.Search != "" { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 如果筛选条件过多,不使用缓存 | ||||
| 	if len(options.Filters) > 5 { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 如果分页页数过大,不使用缓存 | ||||
| 	if options.Page > 10 { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // ================ 缓存预热方法 ================ | ||||
|  | ||||
| // WarmupCommonQueries 预热常用查询 | ||||
| func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error { | ||||
| 	r.logger.Info("开始预热缓存",  | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.Int("queries", len(queries)), | ||||
| 	) | ||||
|  | ||||
| 	for _, query := range queries { | ||||
| 		if err := r.executeWarmupQuery(ctx, query); err != nil { | ||||
| 			r.logger.Warn("缓存预热失败",  | ||||
| 				zap.String("query", query.Name), | ||||
| 				zap.Error(err), | ||||
| 			) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // WarmupQuery 预热查询定义 | ||||
| type WarmupQuery struct { | ||||
| 	Name     string        `json:"name"` | ||||
| 	SQL      string        `json:"sql"` | ||||
| 	Args     []interface{} `json:"args"` | ||||
| 	TTL      time.Duration `json:"ttl"` | ||||
| 	Dest     interface{}   `json:"dest"` | ||||
| } | ||||
|  | ||||
| // executeWarmupQuery 执行预热查询 | ||||
| func (r *CachedBaseRepositoryImpl) executeWarmupQuery(ctx context.Context, query WarmupQuery) error { | ||||
| 	db := r.GetDB(ctx). | ||||
| 		Set("cache:enabled", true). | ||||
| 		Set("cache:ttl", query.TTL) | ||||
|  | ||||
| 	if query.SQL != "" { | ||||
| 		return db.Raw(query.SQL, query.Args...).Scan(query.Dest).Error | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ================ 高级缓存特性 ================ | ||||
|  | ||||
| // GetOrCreate 获取或创建(带缓存) | ||||
| func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interface{}, where string, args []interface{}, createFn func() interface{}) error { | ||||
| 	// 先尝试从缓存获取 | ||||
| 	if err := r.GetWithCache(ctx, dest, 15*time.Minute, where, args...); err == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// 缓存未命中,尝试从数据库获取 | ||||
| 	if err := r.GetDB(ctx).Where(where, args...).First(dest).Error; err == nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	// 数据库也没有,创建新记录 | ||||
| 	if createFn != nil { | ||||
| 		newEntity := createFn() | ||||
| 		if err := r.CreateEntity(ctx, newEntity); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		 | ||||
| 		// 将新创建的实体复制到dest | ||||
| 		// 这里需要反射或其他方式复制 | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return gorm.ErrRecordNotFound | ||||
| } | ||||
|  | ||||
| // BatchGetWithCache 批量获取(带缓存) | ||||
| func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids []string, dest interface{}, ttl time.Duration) error { | ||||
| 	if len(ids) == 0 { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	return r.FindWithCache(ctx, dest, ttl, "id IN ?", ids) | ||||
| } | ||||
|  | ||||
| // RefreshCache 刷新缓存 | ||||
| func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error { | ||||
| 	r.logger.Info("刷新缓存",  | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.String("pattern", pattern), | ||||
| 	) | ||||
|  | ||||
| 	// 这里需要调用缓存服务的删除模式方法 | ||||
| 	// 具体实现取决于你的CacheService接口 | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ================ 缓存统计方法 ================ | ||||
|  | ||||
| // GetCacheInfo 获取缓存信息 | ||||
| func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} { | ||||
| 	return map[string]interface{}{ | ||||
| 		"table_name":       r.tableName, | ||||
| 		"cache_enabled":    true, | ||||
| 		"default_ttl":      "30m", | ||||
| 		"cache_patterns": []string{ | ||||
| 			fmt.Sprintf("gorm_cache:%s:*", r.tableName), | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // LogCacheOperation 记录缓存操作 | ||||
| func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) { | ||||
| 	r.logger.Debug("缓存操作",  | ||||
| 		zap.String("table", r.tableName), | ||||
| 		zap.String("operation", operation), | ||||
| 		zap.String("details", details), | ||||
| 	) | ||||
| }  | ||||
							
								
								
									
										301
									
								
								internal/shared/database/transaction.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										301
									
								
								internal/shared/database/transaction.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,301 @@ | ||||
| package database | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"errors" | ||||
| 	"strings" | ||||
| 	"time" | ||||
|  | ||||
| 	"go.uber.org/zap" | ||||
| 	"gorm.io/gorm" | ||||
| ) | ||||
|  | ||||
| // 自定义错误类型 | ||||
| var ( | ||||
| 	ErrTransactionRollback = errors.New("事务回滚失败") | ||||
| 	ErrTransactionCommit   = errors.New("事务提交失败") | ||||
| ) | ||||
|  | ||||
| // 定义context key | ||||
| type txKey struct{} | ||||
|  | ||||
| // WithTx 将事务对象存储到context中 | ||||
| func WithTx(ctx context.Context, tx *gorm.DB) context.Context { | ||||
| 	return context.WithValue(ctx, txKey{}, tx) | ||||
| } | ||||
|  | ||||
| // GetTx 从context中获取事务对象 | ||||
| func GetTx(ctx context.Context) (*gorm.DB, bool) { | ||||
| 	tx, ok := ctx.Value(txKey{}).(*gorm.DB) | ||||
| 	return tx, ok | ||||
| } | ||||
|  | ||||
| // TransactionManager 事务管理器 | ||||
| type TransactionManager struct { | ||||
| 	db     *gorm.DB | ||||
| 	logger *zap.Logger | ||||
| } | ||||
|  | ||||
| // NewTransactionManager 创建事务管理器 | ||||
| func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager { | ||||
| 	return &TransactionManager{ | ||||
| 		db:     db, | ||||
| 		logger: logger, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ExecuteInTx 在事务中执行函数(推荐使用) | ||||
| // 自动处理事务的开启、提交和回滚 | ||||
| func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error { | ||||
| 	// 检查是否已经在事务中 | ||||
| 	if _, ok := GetTx(ctx); ok { | ||||
| 		// 如果已经在事务中,直接执行函数,避免嵌套事务 | ||||
| 		return fn(ctx) | ||||
| 	} | ||||
|  | ||||
| 	tx := tm.db.Begin() | ||||
| 	if tx.Error != nil { | ||||
| 		return tx.Error | ||||
| 	} | ||||
|  | ||||
| 	// 创建带事务的context | ||||
| 	txCtx := WithTx(ctx, tx) | ||||
|  | ||||
| 	// 执行函数 | ||||
| 	if err := fn(txCtx); err != nil { | ||||
| 		// 回滚事务 | ||||
| 		if rbErr := tx.Rollback().Error; rbErr != nil { | ||||
| 			tm.logger.Error("事务回滚失败", | ||||
| 				zap.Error(err), | ||||
| 				zap.Error(rbErr), | ||||
| 			) | ||||
| 			return errors.Join(err, ErrTransactionRollback, rbErr) | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 提交事务 | ||||
| 	if err := tx.Commit().Error; err != nil { | ||||
| 		tm.logger.Error("事务提交失败", zap.Error(err)) | ||||
| 		return errors.Join(ErrTransactionCommit, err) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // ExecuteInTxWithTimeout 在事务中执行函数(带超时) | ||||
| func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error { | ||||
| 	ctx, cancel := context.WithTimeout(ctx, timeout) | ||||
| 	defer cancel() | ||||
|  | ||||
| 	return tm.ExecuteInTx(ctx, fn) | ||||
| } | ||||
|  | ||||
| // BeginTx 开始事务(手动管理) | ||||
| func (tm *TransactionManager) BeginTx() *gorm.DB { | ||||
| 	return tm.db.Begin() | ||||
| } | ||||
|  | ||||
| // TxWrapper 事务包装器(手动管理) | ||||
| type TxWrapper struct { | ||||
| 	tx *gorm.DB | ||||
| } | ||||
|  | ||||
| // NewTxWrapper 创建事务包装器 | ||||
| func (tm *TransactionManager) NewTxWrapper() *TxWrapper { | ||||
| 	return &TxWrapper{ | ||||
| 		tx: tm.BeginTx(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Commit 提交事务 | ||||
| func (tx *TxWrapper) Commit() error { | ||||
| 	return tx.tx.Commit().Error | ||||
| } | ||||
|  | ||||
| // Rollback 回滚事务 | ||||
| func (tx *TxWrapper) Rollback() error { | ||||
| 	return tx.tx.Rollback().Error | ||||
| } | ||||
|  | ||||
| // GetDB 获取事务数据库实例 | ||||
| func (tx *TxWrapper) GetDB() *gorm.DB { | ||||
| 	return tx.tx | ||||
| } | ||||
|  | ||||
| // WithTx 在事务中执行函数(兼容旧接口) | ||||
| func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error { | ||||
| 	tx := tm.BeginTx() | ||||
| 	defer func() { | ||||
| 		if r := recover(); r != nil { | ||||
| 			tx.Rollback() | ||||
| 			panic(r) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	if err := fn(tx); err != nil { | ||||
| 		tx.Rollback() | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return tx.Commit().Error | ||||
| } | ||||
|  | ||||
| // TransactionOptions 事务选项 | ||||
| type TransactionOptions struct { | ||||
| 	Timeout  time.Duration | ||||
| 	ReadOnly bool // 是否只读事务 | ||||
| } | ||||
|  | ||||
| // ExecuteInTxWithOptions 在事务中执行函数(带选项) | ||||
| func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error { | ||||
| 	// 设置事务选项 | ||||
| 	tx := tm.db.Begin() | ||||
| 	if tx.Error != nil { | ||||
| 		return tx.Error | ||||
| 	} | ||||
|  | ||||
| 	// 设置只读事务 | ||||
| 	if options != nil && options.ReadOnly { | ||||
| 		tx = tx.Session(&gorm.Session{}) | ||||
| 		// 注意:GORM的只读事务需要数据库支持,这里只是标记 | ||||
| 	} | ||||
|  | ||||
| 	// 创建带事务的context | ||||
| 	txCtx := WithTx(ctx, tx) | ||||
|  | ||||
| 	// 设置超时 | ||||
| 	if options != nil && options.Timeout > 0 { | ||||
| 		var cancel context.CancelFunc | ||||
| 		txCtx, cancel = context.WithTimeout(txCtx, options.Timeout) | ||||
| 		defer cancel() | ||||
| 	} | ||||
|  | ||||
| 	// 执行函数 | ||||
| 	if err := fn(txCtx); err != nil { | ||||
| 		// 回滚事务 | ||||
| 		if rbErr := tx.Rollback().Error; rbErr != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 提交事务 | ||||
| 	return tx.Commit().Error | ||||
| } | ||||
|  | ||||
| // TransactionStats 事务统计信息 | ||||
| type TransactionStats struct { | ||||
| 	TotalTransactions      int64 | ||||
| 	SuccessfulTransactions int64 | ||||
| 	FailedTransactions     int64 | ||||
| 	AverageDuration        time.Duration | ||||
| } | ||||
|  | ||||
| // GetStats 获取事务统计信息(预留接口) | ||||
| func (tm *TransactionManager) GetStats() *TransactionStats { | ||||
| 	// TODO: 实现事务统计 | ||||
| 	return &TransactionStats{} | ||||
| } | ||||
|  | ||||
| // RetryableTransactionOptions 可重试事务选项 | ||||
| type RetryableTransactionOptions struct { | ||||
| 	MaxRetries   int           // 最大重试次数 | ||||
| 	RetryDelay   time.Duration // 重试延迟 | ||||
| 	RetryBackoff float64       // 退避倍数 | ||||
| } | ||||
|  | ||||
| // DefaultRetryableOptions 默认重试选项 | ||||
| func DefaultRetryableOptions() *RetryableTransactionOptions { | ||||
| 	return &RetryableTransactionOptions{ | ||||
| 		MaxRetries:   3, | ||||
| 		RetryDelay:   100 * time.Millisecond, | ||||
| 		RetryBackoff: 2.0, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ExecuteInTxWithRetry 在事务中执行函数(支持重试) | ||||
| // 适用于处理死锁等临时性错误 | ||||
| func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error { | ||||
| 	if options == nil { | ||||
| 		options = DefaultRetryableOptions() | ||||
| 	} | ||||
|  | ||||
| 	var lastErr error | ||||
| 	delay := options.RetryDelay | ||||
|  | ||||
| 	for attempt := 0; attempt <= options.MaxRetries; attempt++ { | ||||
| 		// 检查上下文是否已取消 | ||||
| 		if ctx.Err() != nil { | ||||
| 			return ctx.Err() | ||||
| 		} | ||||
|  | ||||
| 		err := tm.ExecuteInTx(ctx, fn) | ||||
| 		if err == nil { | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		// 检查是否是可重试的错误(死锁、连接错误等) | ||||
| 		if !isRetryableError(err) { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		lastErr = err | ||||
|  | ||||
| 		// 如果不是最后一次尝试,等待后重试 | ||||
| 		if attempt < options.MaxRetries { | ||||
| 			tm.logger.Warn("事务执行失败,准备重试", | ||||
| 				zap.Int("attempt", attempt+1), | ||||
| 				zap.Int("max_retries", options.MaxRetries), | ||||
| 				zap.Duration("delay", delay), | ||||
| 				zap.Error(err), | ||||
| 			) | ||||
|  | ||||
| 			select { | ||||
| 			case <-time.After(delay): | ||||
| 				delay = time.Duration(float64(delay) * options.RetryBackoff) | ||||
| 			case <-ctx.Done(): | ||||
| 				return ctx.Err() | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	tm.logger.Error("事务执行失败,已超过最大重试次数", | ||||
| 		zap.Int("max_retries", options.MaxRetries), | ||||
| 		zap.Error(lastErr), | ||||
| 	) | ||||
|  | ||||
| 	return lastErr | ||||
| } | ||||
|  | ||||
| // isRetryableError 判断是否是可重试的错误 | ||||
| func isRetryableError(err error) bool { | ||||
| 	if err == nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	errStr := err.Error() | ||||
|  | ||||
| 	// MySQL 死锁错误 | ||||
| 	if contains(errStr, "Deadlock found") { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// MySQL 锁等待超时 | ||||
| 	if contains(errStr, "Lock wait timeout exceeded") { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// 连接错误 | ||||
| 	if contains(errStr, "connection") { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// 可以根据需要添加更多的可重试错误类型 | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| // contains 检查字符串是否包含子字符串(不区分大小写) | ||||
| func contains(s, substr string) bool { | ||||
| 	return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) | ||||
| } | ||||
		Reference in New Issue
	
	Block a user