247 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			247 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package database | |||
|  | 
 | |||
|  | import ( | |||
|  | 	"context" | |||
|  | 
 | |||
|  | 	"tyapi-server/internal/shared/interfaces" | |||
|  | 
 | |||
|  | 	"go.uber.org/zap" | |||
|  | 	"gorm.io/gorm" | |||
|  | ) | |||
|  | 
 | |||
|  | // BaseRepositoryImpl 基础仓储实现 | |||
|  | // 提供统一的数据库连接、事务处理和通用辅助方法 | |||
|  | type BaseRepositoryImpl struct { | |||
|  | 	db     *gorm.DB | |||
|  | 	logger *zap.Logger | |||
|  | } | |||
|  | 
 | |||
|  | // NewBaseRepositoryImpl 创建基础仓储实现 | |||
|  | func NewBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger) *BaseRepositoryImpl { | |||
|  | 	return &BaseRepositoryImpl{ | |||
|  | 		db:     db, | |||
|  | 		logger: logger, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // ================ 核心工具方法 ================ | |||
|  | 
 | |||
|  | // GetDB 获取数据库连接,优先使用事务 | |||
|  | // 这是Repository层统一的数据库连接获取方法 | |||
|  | func (r *BaseRepositoryImpl) GetDB(ctx context.Context) *gorm.DB { | |||
|  | 	if tx, ok := GetTx(ctx); ok { | |||
|  | 		return tx.WithContext(ctx) | |||
|  | 	} | |||
|  | 	return r.db.WithContext(ctx) | |||
|  | } | |||
|  | 
 | |||
|  | // GetLogger 获取日志记录器 | |||
|  | func (r *BaseRepositoryImpl) GetLogger() *zap.Logger { | |||
|  | 	return r.logger | |||
|  | } | |||
|  | 
 | |||
|  | // WithTx 使用事务创建新的Repository实例 | |||
|  | func (r *BaseRepositoryImpl) WithTx(tx *gorm.DB) *BaseRepositoryImpl { | |||
|  | 	return &BaseRepositoryImpl{ | |||
|  | 		db:     tx, | |||
|  | 		logger: r.logger, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // ExecuteInTransaction 在事务中执行函数 | |||
|  | func (r *BaseRepositoryImpl) ExecuteInTransaction(ctx context.Context, fn func(*gorm.DB) error) error { | |||
|  | 	db := r.GetDB(ctx) | |||
|  | 
 | |||
|  | 	// 如果已经在事务中,直接执行 | |||
|  | 	if _, ok := GetTx(ctx); ok { | |||
|  | 		return fn(db) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 否则开启新事务 | |||
|  | 	return db.Transaction(fn) | |||
|  | } | |||
|  | 
 | |||
|  | // IsInTransaction 检查当前是否在事务中 | |||
|  | func (r *BaseRepositoryImpl) IsInTransaction(ctx context.Context) bool { | |||
|  | 	_, ok := GetTx(ctx) | |||
|  | 	return ok | |||
|  | } | |||
|  | 
 | |||
|  | // ================ 通用查询辅助方法 ================ | |||
|  | 
 | |||
|  | // FindWhere 根据条件查找实体列表 | |||
|  | func (r *BaseRepositoryImpl) FindWhere(ctx context.Context, entities interface{}, condition string, args ...interface{}) error { | |||
|  | 	return r.GetDB(ctx).Where(condition, args...).Find(entities).Error | |||
|  | } | |||
|  | 
 | |||
|  | // FindOne 根据条件查找单个实体 | |||
|  | func (r *BaseRepositoryImpl) FindOne(ctx context.Context, entity interface{}, condition string, args ...interface{}) error { | |||
|  | 	return r.GetDB(ctx).Where(condition, args...).First(entity).Error | |||
|  | } | |||
|  | 
 | |||
|  | // CountWhere 根据条件统计数量 | |||
|  | func (r *BaseRepositoryImpl) CountWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (int64, error) { | |||
|  | 	var count int64 | |||
|  | 	err := r.GetDB(ctx).Model(entity).Where(condition, args...).Count(&count).Error | |||
|  | 	return count, err | |||
|  | } | |||
|  | 
 | |||
|  | // ExistsWhere 根据条件检查是否存在 | |||
|  | func (r *BaseRepositoryImpl) ExistsWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (bool, error) { | |||
|  | 	count, err := r.CountWhere(ctx, entity, condition, args...) | |||
|  | 	return count > 0, err | |||
|  | } | |||
|  | 
 | |||
|  | // ================ CRUD辅助方法 ================ | |||
|  | 
 | |||
|  | // CreateEntity 创建实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) CreateEntity(ctx context.Context, entity interface{}) error { | |||
|  | 	return r.GetDB(ctx).Create(entity).Error | |||
|  | } | |||
|  | 
 | |||
|  | // GetEntityByID 根据ID获取实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) GetEntityByID(ctx context.Context, id string, entity interface{}) error { | |||
|  | 	return r.GetDB(ctx).Where("id = ?", id).First(entity).Error | |||
|  | } | |||
|  | 
 | |||
|  | // UpdateEntity 更新实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) UpdateEntity(ctx context.Context, entity interface{}) error { | |||
|  | 	return r.GetDB(ctx).Save(entity).Error | |||
|  | } | |||
|  | 
 | |||
|  | // DeleteEntity 删除实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) DeleteEntity(ctx context.Context, id string, entity interface{}) error { | |||
|  | 	return r.GetDB(ctx).Delete(entity, "id = ?", id).Error | |||
|  | } | |||
|  | 
 | |||
|  | // ExistsEntity 检查实体是否存在(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) ExistsEntity(ctx context.Context, id string, entity interface{}) (bool, error) { | |||
|  | 	var count int64 | |||
|  | 	err := r.GetDB(ctx).Model(entity).Where("id = ?", id).Count(&count).Error | |||
|  | 	return count > 0, err | |||
|  | } | |||
|  | 
 | |||
|  | // ================ 批量操作辅助方法 ================ | |||
|  | 
 | |||
|  | // CreateBatchEntity 批量创建实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) CreateBatchEntity(ctx context.Context, entities interface{}) error { | |||
|  | 	return r.GetDB(ctx).Create(entities).Error | |||
|  | } | |||
|  | 
 | |||
|  | // GetEntitiesByIDs 根据ID列表获取实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) GetEntitiesByIDs(ctx context.Context, ids []string, entities interface{}) error { | |||
|  | 	return r.GetDB(ctx).Where("id IN ?", ids).Find(entities).Error | |||
|  | } | |||
|  | 
 | |||
|  | // UpdateBatchEntity 批量更新实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) UpdateBatchEntity(ctx context.Context, entities interface{}) error { | |||
|  | 	return r.GetDB(ctx).Save(entities).Error | |||
|  | } | |||
|  | 
 | |||
|  | // DeleteBatchEntity 批量删除实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) DeleteBatchEntity(ctx context.Context, ids []string, entity interface{}) error { | |||
|  | 	return r.GetDB(ctx).Delete(entity, "id IN ?", ids).Error | |||
|  | } | |||
|  | 
 | |||
|  | // ================ 软删除辅助方法 ================ | |||
|  | 
 | |||
|  | // SoftDeleteEntity 软删除实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) SoftDeleteEntity(ctx context.Context, id string, entity interface{}) error { | |||
|  | 	return r.GetDB(ctx).Delete(entity, "id = ?", id).Error | |||
|  | } | |||
|  | 
 | |||
|  | // RestoreEntity 恢复软删除的实体(辅助方法) | |||
|  | func (r *BaseRepositoryImpl) RestoreEntity(ctx context.Context, id string, entity interface{}) error { | |||
|  | 	return r.GetDB(ctx).Unscoped().Model(entity).Where("id = ?", id).Update("deleted_at", nil).Error | |||
|  | } | |||
|  | 
 | |||
|  | // ================ 高级查询辅助方法 ================ | |||
|  | 
 | |||
|  | // ListWithOptions 获取实体列表(支持ListOptions,辅助方法) | |||
|  | func (r *BaseRepositoryImpl) ListWithOptions(ctx context.Context, entity interface{}, entities interface{}, options interfaces.ListOptions) error { | |||
|  | 	query := r.GetDB(ctx).Model(entity) | |||
|  | 
 | |||
|  | 	// 应用筛选条件 | |||
|  | 	if options.Filters != nil { | |||
|  | 		for key, value := range options.Filters { | |||
|  | 			query = query.Where(key+" = ?", value) | |||
|  | 		} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 应用搜索条件(基础实现,具体Repository应该重写) | |||
|  | 	if options.Search != "" { | |||
|  | 		query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 应用预加载 | |||
|  | 	for _, include := range options.Include { | |||
|  | 		query = query.Preload(include) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 应用排序 | |||
|  | 	if options.Sort != "" { | |||
|  | 		order := "ASC" | |||
|  | 		if options.Order == "desc" || options.Order == "DESC" { | |||
|  | 			order = "DESC" | |||
|  | 		} | |||
|  | 		query = query.Order(options.Sort + " " + order) | |||
|  | 	} else { | |||
|  | 		// 默认按创建时间倒序 | |||
|  | 		query = query.Order("created_at DESC") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 应用分页 | |||
|  | 	if options.Page > 0 && options.PageSize > 0 { | |||
|  | 		offset := (options.Page - 1) * options.PageSize | |||
|  | 		query = query.Offset(offset).Limit(options.PageSize) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	return query.Find(entities).Error | |||
|  | } | |||
|  | 
 | |||
|  | // CountWithOptions 统计实体数量(支持CountOptions,辅助方法) | |||
|  | func (r *BaseRepositoryImpl) CountWithOptions(ctx context.Context, entity interface{}, options interfaces.CountOptions) (int64, error) { | |||
|  | 	var count int64 | |||
|  | 	query := r.GetDB(ctx).Model(entity) | |||
|  | 
 | |||
|  | 	// 应用筛选条件 | |||
|  | 	if options.Filters != nil { | |||
|  | 		for key, value := range options.Filters { | |||
|  | 			query = query.Where(key+" = ?", value) | |||
|  | 		} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 应用搜索条件(基础实现,具体Repository应该重写) | |||
|  | 	if options.Search != "" { | |||
|  | 		query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	err := query.Count(&count).Error | |||
|  | 	return count, err | |||
|  | } | |||
|  | 
 | |||
|  | // ================ 常用查询模式 ================ | |||
|  | 
 | |||
|  | // FindByField 根据单个字段查找实体列表 | |||
|  | func (r *BaseRepositoryImpl) FindByField(ctx context.Context, entities interface{}, field string, value interface{}) error { | |||
|  | 	return r.GetDB(ctx).Where(field+" = ?", value).Find(entities).Error | |||
|  | } | |||
|  | 
 | |||
|  | // FindOneByField 根据单个字段查找单个实体 | |||
|  | func (r *BaseRepositoryImpl) FindOneByField(ctx context.Context, entity interface{}, field string, value interface{}) error { | |||
|  | 	return r.GetDB(ctx).Where(field+" = ?", value).First(entity).Error | |||
|  | } | |||
|  | 
 | |||
|  | // CountByField 根据单个字段统计数量 | |||
|  | func (r *BaseRepositoryImpl) CountByField(ctx context.Context, entity interface{}, field string, value interface{}) (int64, error) { | |||
|  | 	var count int64 | |||
|  | 	err := r.GetDB(ctx).Model(entity).Where(field+" = ?", value).Count(&count).Error | |||
|  | 	return count, err | |||
|  | } | |||
|  | 
 | |||
|  | // ExistsByField 根据单个字段检查是否存在 | |||
|  | func (r *BaseRepositoryImpl) ExistsByField(ctx context.Context, entity interface{}, field string, value interface{}) (bool, error) { | |||
|  | 	count, err := r.CountByField(ctx, entity, field, value) | |||
|  | 	return count > 0, err | |||
|  | } |