temp
This commit is contained in:
246
internal/shared/database/base_repository.go
Normal file
246
internal/shared/database/base_repository.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// BaseRepositoryImpl 基础仓储实现
|
||||
// 提供统一的数据库连接、事务处理和通用辅助方法
|
||||
type BaseRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewBaseRepositoryImpl 创建基础仓储实现
|
||||
func NewBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger) *BaseRepositoryImpl {
|
||||
return &BaseRepositoryImpl{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 核心工具方法 ================
|
||||
|
||||
// GetDB 获取数据库连接,优先使用事务
|
||||
// 这是Repository层统一的数据库连接获取方法
|
||||
func (r *BaseRepositoryImpl) GetDB(ctx context.Context) *gorm.DB {
|
||||
if tx, ok := GetTx(ctx); ok {
|
||||
return tx.WithContext(ctx)
|
||||
}
|
||||
return r.db.WithContext(ctx)
|
||||
}
|
||||
|
||||
// GetLogger 获取日志记录器
|
||||
func (r *BaseRepositoryImpl) GetLogger() *zap.Logger {
|
||||
return r.logger
|
||||
}
|
||||
|
||||
// WithTx 使用事务创建新的Repository实例
|
||||
func (r *BaseRepositoryImpl) WithTx(tx *gorm.DB) *BaseRepositoryImpl {
|
||||
return &BaseRepositoryImpl{
|
||||
db: tx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInTransaction 在事务中执行函数
|
||||
func (r *BaseRepositoryImpl) ExecuteInTransaction(ctx context.Context, fn func(*gorm.DB) error) error {
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 如果已经在事务中,直接执行
|
||||
if _, ok := GetTx(ctx); ok {
|
||||
return fn(db)
|
||||
}
|
||||
|
||||
// 否则开启新事务
|
||||
return db.Transaction(fn)
|
||||
}
|
||||
|
||||
// IsInTransaction 检查当前是否在事务中
|
||||
func (r *BaseRepositoryImpl) IsInTransaction(ctx context.Context) bool {
|
||||
_, ok := GetTx(ctx)
|
||||
return ok
|
||||
}
|
||||
|
||||
// ================ 通用查询辅助方法 ================
|
||||
|
||||
// FindWhere 根据条件查找实体列表
|
||||
func (r *BaseRepositoryImpl) FindWhere(ctx context.Context, entities interface{}, condition string, args ...interface{}) error {
|
||||
return r.GetDB(ctx).Where(condition, args...).Find(entities).Error
|
||||
}
|
||||
|
||||
// FindOne 根据条件查找单个实体
|
||||
func (r *BaseRepositoryImpl) FindOne(ctx context.Context, entity interface{}, condition string, args ...interface{}) error {
|
||||
return r.GetDB(ctx).Where(condition, args...).First(entity).Error
|
||||
}
|
||||
|
||||
// CountWhere 根据条件统计数量
|
||||
func (r *BaseRepositoryImpl) CountWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(entity).Where(condition, args...).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ExistsWhere 根据条件检查是否存在
|
||||
func (r *BaseRepositoryImpl) ExistsWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (bool, error) {
|
||||
count, err := r.CountWhere(ctx, entity, condition, args...)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ================ CRUD辅助方法 ================
|
||||
|
||||
// CreateEntity 创建实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) CreateEntity(ctx context.Context, entity interface{}) error {
|
||||
return r.GetDB(ctx).Create(entity).Error
|
||||
}
|
||||
|
||||
// GetEntityByID 根据ID获取实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) GetEntityByID(ctx context.Context, id string, entity interface{}) error {
|
||||
return r.GetDB(ctx).Where("id = ?", id).First(entity).Error
|
||||
}
|
||||
|
||||
// UpdateEntity 更新实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) UpdateEntity(ctx context.Context, entity interface{}) error {
|
||||
return r.GetDB(ctx).Save(entity).Error
|
||||
}
|
||||
|
||||
// DeleteEntity 删除实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) DeleteEntity(ctx context.Context, id string, entity interface{}) error {
|
||||
return r.GetDB(ctx).Delete(entity, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// ExistsEntity 检查实体是否存在(辅助方法)
|
||||
func (r *BaseRepositoryImpl) ExistsEntity(ctx context.Context, id string, entity interface{}) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(entity).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ================ 批量操作辅助方法 ================
|
||||
|
||||
// CreateBatchEntity 批量创建实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) CreateBatchEntity(ctx context.Context, entities interface{}) error {
|
||||
return r.GetDB(ctx).Create(entities).Error
|
||||
}
|
||||
|
||||
// GetEntitiesByIDs 根据ID列表获取实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) GetEntitiesByIDs(ctx context.Context, ids []string, entities interface{}) error {
|
||||
return r.GetDB(ctx).Where("id IN ?", ids).Find(entities).Error
|
||||
}
|
||||
|
||||
// UpdateBatchEntity 批量更新实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) UpdateBatchEntity(ctx context.Context, entities interface{}) error {
|
||||
return r.GetDB(ctx).Save(entities).Error
|
||||
}
|
||||
|
||||
// DeleteBatchEntity 批量删除实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) DeleteBatchEntity(ctx context.Context, ids []string, entity interface{}) error {
|
||||
return r.GetDB(ctx).Delete(entity, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// ================ 软删除辅助方法 ================
|
||||
|
||||
// SoftDeleteEntity 软删除实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) SoftDeleteEntity(ctx context.Context, id string, entity interface{}) error {
|
||||
return r.GetDB(ctx).Delete(entity, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// RestoreEntity 恢复软删除的实体(辅助方法)
|
||||
func (r *BaseRepositoryImpl) RestoreEntity(ctx context.Context, id string, entity interface{}) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(entity).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 高级查询辅助方法 ================
|
||||
|
||||
// ListWithOptions 获取实体列表(支持ListOptions,辅助方法)
|
||||
func (r *BaseRepositoryImpl) ListWithOptions(ctx context.Context, entity interface{}, entities interface{}, options interfaces.ListOptions) error {
|
||||
query := r.GetDB(ctx).Model(entity)
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件(基础实现,具体Repository应该重写)
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用预加载
|
||||
for _, include := range options.Include {
|
||||
query = query.Preload(include)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return query.Find(entities).Error
|
||||
}
|
||||
|
||||
// CountWithOptions 统计实体数量(支持CountOptions,辅助方法)
|
||||
func (r *BaseRepositoryImpl) CountWithOptions(ctx context.Context, entity interface{}, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(entity)
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件(基础实现,具体Repository应该重写)
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ================ 常用查询模式 ================
|
||||
|
||||
// FindByField 根据单个字段查找实体列表
|
||||
func (r *BaseRepositoryImpl) FindByField(ctx context.Context, entities interface{}, field string, value interface{}) error {
|
||||
return r.GetDB(ctx).Where(field+" = ?", value).Find(entities).Error
|
||||
}
|
||||
|
||||
// FindOneByField 根据单个字段查找单个实体
|
||||
func (r *BaseRepositoryImpl) FindOneByField(ctx context.Context, entity interface{}, field string, value interface{}) error {
|
||||
return r.GetDB(ctx).Where(field+" = ?", value).First(entity).Error
|
||||
}
|
||||
|
||||
// CountByField 根据单个字段统计数量
|
||||
func (r *BaseRepositoryImpl) CountByField(ctx context.Context, entity interface{}, field string, value interface{}) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(entity).Where(field+" = ?", value).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ExistsByField 根据单个字段检查是否存在
|
||||
func (r *BaseRepositoryImpl) ExistsByField(ctx context.Context, entity interface{}, field string, value interface{}) (bool, error) {
|
||||
count, err := r.CountByField(ctx, entity, field, value)
|
||||
return count > 0, err
|
||||
}
|
||||
367
internal/shared/database/cached_base_repository.go
Normal file
367
internal/shared/database/cached_base_repository.go
Normal file
@@ -0,0 +1,367 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// CachedBaseRepositoryImpl 支持缓存的基础仓储实现
|
||||
// 在BaseRepositoryImpl基础上增加智能缓存管理
|
||||
type CachedBaseRepositoryImpl struct {
|
||||
*BaseRepositoryImpl
|
||||
tableName string
|
||||
}
|
||||
|
||||
// NewCachedBaseRepositoryImpl 创建支持缓存的基础仓储实现
|
||||
func NewCachedBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger, tableName string) *CachedBaseRepositoryImpl {
|
||||
return &CachedBaseRepositoryImpl{
|
||||
BaseRepositoryImpl: NewBaseRepositoryImpl(db, logger),
|
||||
tableName: tableName,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 智能缓存方法 ================
|
||||
|
||||
// GetWithCache 带缓存的单条查询
|
||||
func (r *CachedBaseRepositoryImpl) GetWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl)
|
||||
|
||||
return db.Where(where, args...).First(dest).Error
|
||||
}
|
||||
|
||||
// FindWithCache 带缓存的多条查询
|
||||
func (r *CachedBaseRepositoryImpl) FindWithCache(ctx context.Context, dest interface{}, ttl time.Duration, where string, args ...interface{}) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl)
|
||||
|
||||
return db.Where(where, args...).Find(dest).Error
|
||||
}
|
||||
|
||||
// CountWithCache 带缓存的计数查询
|
||||
func (r *CachedBaseRepositoryImpl) CountWithCache(ctx context.Context, count *int64, ttl time.Duration, entity interface{}, where string, args ...interface{}) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl).
|
||||
Model(entity)
|
||||
|
||||
return db.Where(where, args...).Count(count).Error
|
||||
}
|
||||
|
||||
// ListWithCache 带缓存的列表查询
|
||||
func (r *CachedBaseRepositoryImpl) ListWithCache(ctx context.Context, dest interface{}, ttl time.Duration, options CacheListOptions) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", ttl)
|
||||
|
||||
// 应用where条件
|
||||
if options.Where != "" {
|
||||
db = db.Where(options.Where, options.Args...)
|
||||
}
|
||||
|
||||
// 应用预加载
|
||||
for _, preload := range options.Preloads {
|
||||
db = db.Preload(preload)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Order != "" {
|
||||
db = db.Order(options.Order)
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Limit > 0 {
|
||||
db = db.Limit(options.Limit)
|
||||
}
|
||||
if options.Offset > 0 {
|
||||
db = db.Offset(options.Offset)
|
||||
}
|
||||
|
||||
return db.Find(dest).Error
|
||||
}
|
||||
|
||||
// CacheListOptions 缓存列表查询选项
|
||||
type CacheListOptions struct {
|
||||
Where string `json:"where"`
|
||||
Args []interface{} `json:"args"`
|
||||
Order string `json:"order"`
|
||||
Limit int `json:"limit"`
|
||||
Offset int `json:"offset"`
|
||||
Preloads []string `json:"preloads"`
|
||||
}
|
||||
|
||||
// ================ 缓存控制方法 ================
|
||||
|
||||
// WithCache 启用缓存
|
||||
func (r *CachedBaseRepositoryImpl) WithCache(ttl time.Duration) *CachedBaseRepositoryImpl {
|
||||
// 创建新实例避免状态污染
|
||||
return &CachedBaseRepositoryImpl{
|
||||
BaseRepositoryImpl: &BaseRepositoryImpl{
|
||||
db: r.db.Set("cache:enabled", true).Set("cache:ttl", ttl),
|
||||
logger: r.logger,
|
||||
},
|
||||
tableName: r.tableName,
|
||||
}
|
||||
}
|
||||
|
||||
// WithoutCache 禁用缓存
|
||||
func (r *CachedBaseRepositoryImpl) WithoutCache() *CachedBaseRepositoryImpl {
|
||||
return &CachedBaseRepositoryImpl{
|
||||
BaseRepositoryImpl: &BaseRepositoryImpl{
|
||||
db: r.db.Set("cache:disabled", true),
|
||||
logger: r.logger,
|
||||
},
|
||||
tableName: r.tableName,
|
||||
}
|
||||
}
|
||||
|
||||
// WithShortCache 短期缓存(5分钟)
|
||||
func (r *CachedBaseRepositoryImpl) WithShortCache() *CachedBaseRepositoryImpl {
|
||||
return r.WithCache(5 * time.Minute)
|
||||
}
|
||||
|
||||
// WithMediumCache 中期缓存(30分钟)
|
||||
func (r *CachedBaseRepositoryImpl) WithMediumCache() *CachedBaseRepositoryImpl {
|
||||
return r.WithCache(30 * time.Minute)
|
||||
}
|
||||
|
||||
// WithLongCache 长期缓存(2小时)
|
||||
func (r *CachedBaseRepositoryImpl) WithLongCache() *CachedBaseRepositoryImpl {
|
||||
return r.WithCache(2 * time.Hour)
|
||||
}
|
||||
|
||||
// ================ 智能查询方法 ================
|
||||
|
||||
// SmartGetByID 智能ID查询(自动缓存)
|
||||
func (r *CachedBaseRepositoryImpl) SmartGetByID(ctx context.Context, id string, dest interface{}) error {
|
||||
return r.GetWithCache(ctx, dest, 30*time.Minute, "id = ?", id)
|
||||
}
|
||||
|
||||
// SmartGetByField 智能字段查询(自动缓存)
|
||||
func (r *CachedBaseRepositoryImpl) SmartGetByField(ctx context.Context, dest interface{}, field string, value interface{}, ttl ...time.Duration) error {
|
||||
cacheTTL := 15 * time.Minute
|
||||
if len(ttl) > 0 {
|
||||
cacheTTL = ttl[0]
|
||||
}
|
||||
|
||||
return r.GetWithCache(ctx, dest, cacheTTL, field+" = ?", value)
|
||||
}
|
||||
|
||||
// SmartList 智能列表查询(根据查询复杂度自动选择缓存策略)
|
||||
func (r *CachedBaseRepositoryImpl) SmartList(ctx context.Context, dest interface{}, options interfaces.ListOptions) error {
|
||||
// 根据查询复杂度决定缓存策略
|
||||
cacheTTL := r.calculateCacheTTL(options)
|
||||
useCache := r.shouldUseCache(options)
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
if useCache {
|
||||
db = db.Set("cache:enabled", true).Set("cache:ttl", cacheTTL)
|
||||
} else {
|
||||
db = db.Set("cache:disabled", true)
|
||||
}
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
db = db.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
// 这里应该由具体Repository实现搜索逻辑
|
||||
r.logger.Debug("搜索查询默认禁用缓存", zap.String("search", options.Search))
|
||||
db = db.Set("cache:disabled", true)
|
||||
}
|
||||
|
||||
// 应用预加载
|
||||
for _, include := range options.Include {
|
||||
db = db.Preload(include)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
db = db.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
db = db.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
db = db.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return db.Find(dest).Error
|
||||
}
|
||||
|
||||
// calculateCacheTTL 计算缓存TTL
|
||||
func (r *CachedBaseRepositoryImpl) calculateCacheTTL(options interfaces.ListOptions) time.Duration {
|
||||
// 基础TTL
|
||||
baseTTL := 15 * time.Minute
|
||||
|
||||
// 如果有搜索,缩短TTL
|
||||
if options.Search != "" {
|
||||
return 2 * time.Minute
|
||||
}
|
||||
|
||||
// 如果有复杂筛选,缩短TTL
|
||||
if len(options.Filters) > 3 {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
|
||||
// 如果是简单查询,延长TTL
|
||||
if len(options.Filters) == 0 && options.Search == "" {
|
||||
return 30 * time.Minute
|
||||
}
|
||||
|
||||
return baseTTL
|
||||
}
|
||||
|
||||
// shouldUseCache 判断是否应该使用缓存
|
||||
func (r *CachedBaseRepositoryImpl) shouldUseCache(options interfaces.ListOptions) bool {
|
||||
// 如果有搜索,不使用缓存(搜索结果变化频繁)
|
||||
if options.Search != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果筛选条件过多,不使用缓存
|
||||
if len(options.Filters) > 5 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 如果分页页数过大,不使用缓存
|
||||
if options.Page > 10 {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// ================ 缓存预热方法 ================
|
||||
|
||||
// WarmupCommonQueries 预热常用查询
|
||||
func (r *CachedBaseRepositoryImpl) WarmupCommonQueries(ctx context.Context, queries []WarmupQuery) error {
|
||||
r.logger.Info("开始预热缓存",
|
||||
zap.String("table", r.tableName),
|
||||
zap.Int("queries", len(queries)),
|
||||
)
|
||||
|
||||
for _, query := range queries {
|
||||
if err := r.executeWarmupQuery(ctx, query); err != nil {
|
||||
r.logger.Warn("缓存预热失败",
|
||||
zap.String("query", query.Name),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WarmupQuery 预热查询定义
|
||||
type WarmupQuery struct {
|
||||
Name string `json:"name"`
|
||||
SQL string `json:"sql"`
|
||||
Args []interface{} `json:"args"`
|
||||
TTL time.Duration `json:"ttl"`
|
||||
Dest interface{} `json:"dest"`
|
||||
}
|
||||
|
||||
// executeWarmupQuery 执行预热查询
|
||||
func (r *CachedBaseRepositoryImpl) executeWarmupQuery(ctx context.Context, query WarmupQuery) error {
|
||||
db := r.GetDB(ctx).
|
||||
Set("cache:enabled", true).
|
||||
Set("cache:ttl", query.TTL)
|
||||
|
||||
if query.SQL != "" {
|
||||
return db.Raw(query.SQL, query.Args...).Scan(query.Dest).Error
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 高级缓存特性 ================
|
||||
|
||||
// GetOrCreate 获取或创建(带缓存)
|
||||
func (r *CachedBaseRepositoryImpl) GetOrCreate(ctx context.Context, dest interface{}, where string, args []interface{}, createFn func() interface{}) error {
|
||||
// 先尝试从缓存获取
|
||||
if err := r.GetWithCache(ctx, dest, 15*time.Minute, where, args...); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 缓存未命中,尝试从数据库获取
|
||||
if err := r.GetDB(ctx).Where(where, args...).First(dest).Error; err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 数据库也没有,创建新记录
|
||||
if createFn != nil {
|
||||
newEntity := createFn()
|
||||
if err := r.CreateEntity(ctx, newEntity); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 将新创建的实体复制到dest
|
||||
// 这里需要反射或其他方式复制
|
||||
return nil
|
||||
}
|
||||
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
// BatchGetWithCache 批量获取(带缓存)
|
||||
func (r *CachedBaseRepositoryImpl) BatchGetWithCache(ctx context.Context, ids []string, dest interface{}, ttl time.Duration) error {
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.FindWithCache(ctx, dest, ttl, "id IN ?", ids)
|
||||
}
|
||||
|
||||
// RefreshCache 刷新缓存
|
||||
func (r *CachedBaseRepositoryImpl) RefreshCache(ctx context.Context, pattern string) error {
|
||||
r.logger.Info("刷新缓存",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("pattern", pattern),
|
||||
)
|
||||
|
||||
// 这里需要调用缓存服务的删除模式方法
|
||||
// 具体实现取决于你的CacheService接口
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 缓存统计方法 ================
|
||||
|
||||
// GetCacheInfo 获取缓存信息
|
||||
func (r *CachedBaseRepositoryImpl) GetCacheInfo() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"table_name": r.tableName,
|
||||
"cache_enabled": true,
|
||||
"default_ttl": "30m",
|
||||
"cache_patterns": []string{
|
||||
fmt.Sprintf("gorm_cache:%s:*", r.tableName),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// LogCacheOperation 记录缓存操作
|
||||
func (r *CachedBaseRepositoryImpl) LogCacheOperation(operation, details string) {
|
||||
r.logger.Debug("缓存操作",
|
||||
zap.String("table", r.tableName),
|
||||
zap.String("operation", operation),
|
||||
zap.String("details", details),
|
||||
)
|
||||
}
|
||||
301
internal/shared/database/transaction.go
Normal file
301
internal/shared/database/transaction.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 自定义错误类型
|
||||
var (
|
||||
ErrTransactionRollback = errors.New("事务回滚失败")
|
||||
ErrTransactionCommit = errors.New("事务提交失败")
|
||||
)
|
||||
|
||||
// 定义context key
|
||||
type txKey struct{}
|
||||
|
||||
// WithTx 将事务对象存储到context中
|
||||
func WithTx(ctx context.Context, tx *gorm.DB) context.Context {
|
||||
return context.WithValue(ctx, txKey{}, tx)
|
||||
}
|
||||
|
||||
// GetTx 从context中获取事务对象
|
||||
func GetTx(ctx context.Context) (*gorm.DB, bool) {
|
||||
tx, ok := ctx.Value(txKey{}).(*gorm.DB)
|
||||
return tx, ok
|
||||
}
|
||||
|
||||
// TransactionManager 事务管理器
|
||||
type TransactionManager struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTransactionManager 创建事务管理器
|
||||
func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager {
|
||||
return &TransactionManager{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInTx 在事务中执行函数(推荐使用)
|
||||
// 自动处理事务的开启、提交和回滚
|
||||
func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error {
|
||||
// 检查是否已经在事务中
|
||||
if _, ok := GetTx(ctx); ok {
|
||||
// 如果已经在事务中,直接执行函数,避免嵌套事务
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
tx := tm.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// 创建带事务的context
|
||||
txCtx := WithTx(ctx, tx)
|
||||
|
||||
// 执行函数
|
||||
if err := fn(txCtx); err != nil {
|
||||
// 回滚事务
|
||||
if rbErr := tx.Rollback().Error; rbErr != nil {
|
||||
tm.logger.Error("事务回滚失败",
|
||||
zap.Error(err),
|
||||
zap.Error(rbErr),
|
||||
)
|
||||
return errors.Join(err, ErrTransactionRollback, rbErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
tm.logger.Error("事务提交失败", zap.Error(err))
|
||||
return errors.Join(ErrTransactionCommit, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteInTxWithTimeout 在事务中执行函数(带超时)
|
||||
func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
return tm.ExecuteInTx(ctx, fn)
|
||||
}
|
||||
|
||||
// BeginTx 开始事务(手动管理)
|
||||
func (tm *TransactionManager) BeginTx() *gorm.DB {
|
||||
return tm.db.Begin()
|
||||
}
|
||||
|
||||
// TxWrapper 事务包装器(手动管理)
|
||||
type TxWrapper struct {
|
||||
tx *gorm.DB
|
||||
}
|
||||
|
||||
// NewTxWrapper 创建事务包装器
|
||||
func (tm *TransactionManager) NewTxWrapper() *TxWrapper {
|
||||
return &TxWrapper{
|
||||
tx: tm.BeginTx(),
|
||||
}
|
||||
}
|
||||
|
||||
// Commit 提交事务
|
||||
func (tx *TxWrapper) Commit() error {
|
||||
return tx.tx.Commit().Error
|
||||
}
|
||||
|
||||
// Rollback 回滚事务
|
||||
func (tx *TxWrapper) Rollback() error {
|
||||
return tx.tx.Rollback().Error
|
||||
}
|
||||
|
||||
// GetDB 获取事务数据库实例
|
||||
func (tx *TxWrapper) GetDB() *gorm.DB {
|
||||
return tx.tx
|
||||
}
|
||||
|
||||
// WithTx 在事务中执行函数(兼容旧接口)
|
||||
func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error {
|
||||
tx := tm.BeginTx()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// TransactionOptions 事务选项
|
||||
type TransactionOptions struct {
|
||||
Timeout time.Duration
|
||||
ReadOnly bool // 是否只读事务
|
||||
}
|
||||
|
||||
// ExecuteInTxWithOptions 在事务中执行函数(带选项)
|
||||
func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error {
|
||||
// 设置事务选项
|
||||
tx := tm.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// 设置只读事务
|
||||
if options != nil && options.ReadOnly {
|
||||
tx = tx.Session(&gorm.Session{})
|
||||
// 注意:GORM的只读事务需要数据库支持,这里只是标记
|
||||
}
|
||||
|
||||
// 创建带事务的context
|
||||
txCtx := WithTx(ctx, tx)
|
||||
|
||||
// 设置超时
|
||||
if options != nil && options.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
txCtx, cancel = context.WithTimeout(txCtx, options.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// 执行函数
|
||||
if err := fn(txCtx); err != nil {
|
||||
// 回滚事务
|
||||
if rbErr := tx.Rollback().Error; rbErr != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// TransactionStats 事务统计信息
|
||||
type TransactionStats struct {
|
||||
TotalTransactions int64
|
||||
SuccessfulTransactions int64
|
||||
FailedTransactions int64
|
||||
AverageDuration time.Duration
|
||||
}
|
||||
|
||||
// GetStats 获取事务统计信息(预留接口)
|
||||
func (tm *TransactionManager) GetStats() *TransactionStats {
|
||||
// TODO: 实现事务统计
|
||||
return &TransactionStats{}
|
||||
}
|
||||
|
||||
// RetryableTransactionOptions 可重试事务选项
|
||||
type RetryableTransactionOptions struct {
|
||||
MaxRetries int // 最大重试次数
|
||||
RetryDelay time.Duration // 重试延迟
|
||||
RetryBackoff float64 // 退避倍数
|
||||
}
|
||||
|
||||
// DefaultRetryableOptions 默认重试选项
|
||||
func DefaultRetryableOptions() *RetryableTransactionOptions {
|
||||
return &RetryableTransactionOptions{
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
RetryBackoff: 2.0,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInTxWithRetry 在事务中执行函数(支持重试)
|
||||
// 适用于处理死锁等临时性错误
|
||||
func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error {
|
||||
if options == nil {
|
||||
options = DefaultRetryableOptions()
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
delay := options.RetryDelay
|
||||
|
||||
for attempt := 0; attempt <= options.MaxRetries; attempt++ {
|
||||
// 检查上下文是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
err := tm.ExecuteInTx(ctx, fn)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否是可重试的错误(死锁、连接错误等)
|
||||
if !isRetryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// 如果不是最后一次尝试,等待后重试
|
||||
if attempt < options.MaxRetries {
|
||||
tm.logger.Warn("事务执行失败,准备重试",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("max_retries", options.MaxRetries),
|
||||
zap.Duration("delay", delay),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
delay = time.Duration(float64(delay) * options.RetryBackoff)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tm.logger.Error("事务执行失败,已超过最大重试次数",
|
||||
zap.Int("max_retries", options.MaxRetries),
|
||||
zap.Error(lastErr),
|
||||
)
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// isRetryableError 判断是否是可重试的错误
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
// MySQL 死锁错误
|
||||
if contains(errStr, "Deadlock found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// MySQL 锁等待超时
|
||||
if contains(errStr, "Lock wait timeout exceeded") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 连接错误
|
||||
if contains(errStr, "connection") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 可以根据需要添加更多的可重试错误类型
|
||||
return false
|
||||
}
|
||||
|
||||
// contains 检查字符串是否包含子字符串(不区分大小写)
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
|
||||
}
|
||||
Reference in New Issue
Block a user