This commit is contained in:
2025-07-20 20:53:26 +08:00
parent 83bf9aea7d
commit 8ad1d7288e
158 changed files with 18156 additions and 13188 deletions

View File

@@ -0,0 +1,246 @@
package database
import (
"context"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// BaseRepositoryImpl 基础仓储实现
// 提供统一的数据库连接、事务处理和通用辅助方法
type BaseRepositoryImpl struct {
db *gorm.DB
logger *zap.Logger
}
// NewBaseRepositoryImpl 创建基础仓储实现
func NewBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger) *BaseRepositoryImpl {
return &BaseRepositoryImpl{
db: db,
logger: logger,
}
}
// ================ 核心工具方法 ================
// GetDB 获取数据库连接,优先使用事务
// 这是Repository层统一的数据库连接获取方法
func (r *BaseRepositoryImpl) GetDB(ctx context.Context) *gorm.DB {
if tx, ok := GetTx(ctx); ok {
return tx.WithContext(ctx)
}
return r.db.WithContext(ctx)
}
// GetLogger 获取日志记录器
func (r *BaseRepositoryImpl) GetLogger() *zap.Logger {
return r.logger
}
// WithTx 使用事务创建新的Repository实例
func (r *BaseRepositoryImpl) WithTx(tx *gorm.DB) *BaseRepositoryImpl {
return &BaseRepositoryImpl{
db: tx,
logger: r.logger,
}
}
// ExecuteInTransaction 在事务中执行函数
func (r *BaseRepositoryImpl) ExecuteInTransaction(ctx context.Context, fn func(*gorm.DB) error) error {
db := r.GetDB(ctx)
// 如果已经在事务中,直接执行
if _, ok := GetTx(ctx); ok {
return fn(db)
}
// 否则开启新事务
return db.Transaction(fn)
}
// IsInTransaction 检查当前是否在事务中
func (r *BaseRepositoryImpl) IsInTransaction(ctx context.Context) bool {
_, ok := GetTx(ctx)
return ok
}
// ================ 通用查询辅助方法 ================
// FindWhere 根据条件查找实体列表
func (r *BaseRepositoryImpl) FindWhere(ctx context.Context, entities interface{}, condition string, args ...interface{}) error {
return r.GetDB(ctx).Where(condition, args...).Find(entities).Error
}
// FindOne 根据条件查找单个实体
func (r *BaseRepositoryImpl) FindOne(ctx context.Context, entity interface{}, condition string, args ...interface{}) error {
return r.GetDB(ctx).Where(condition, args...).First(entity).Error
}
// CountWhere 根据条件统计数量
func (r *BaseRepositoryImpl) CountWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(entity).Where(condition, args...).Count(&count).Error
return count, err
}
// ExistsWhere 根据条件检查是否存在
func (r *BaseRepositoryImpl) ExistsWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (bool, error) {
count, err := r.CountWhere(ctx, entity, condition, args...)
return count > 0, err
}
// ================ CRUD辅助方法 ================
// CreateEntity 创建实体(辅助方法)
func (r *BaseRepositoryImpl) CreateEntity(ctx context.Context, entity interface{}) error {
return r.GetDB(ctx).Create(entity).Error
}
// GetEntityByID 根据ID获取实体辅助方法
func (r *BaseRepositoryImpl) GetEntityByID(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Where("id = ?", id).First(entity).Error
}
// UpdateEntity 更新实体(辅助方法)
func (r *BaseRepositoryImpl) UpdateEntity(ctx context.Context, entity interface{}) error {
return r.GetDB(ctx).Save(entity).Error
}
// DeleteEntity 删除实体(辅助方法)
func (r *BaseRepositoryImpl) DeleteEntity(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Delete(entity, "id = ?", id).Error
}
// ExistsEntity 检查实体是否存在(辅助方法)
func (r *BaseRepositoryImpl) ExistsEntity(ctx context.Context, id string, entity interface{}) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(entity).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// ================ 批量操作辅助方法 ================
// CreateBatchEntity 批量创建实体(辅助方法)
func (r *BaseRepositoryImpl) CreateBatchEntity(ctx context.Context, entities interface{}) error {
return r.GetDB(ctx).Create(entities).Error
}
// GetEntitiesByIDs 根据ID列表获取实体辅助方法
func (r *BaseRepositoryImpl) GetEntitiesByIDs(ctx context.Context, ids []string, entities interface{}) error {
return r.GetDB(ctx).Where("id IN ?", ids).Find(entities).Error
}
// UpdateBatchEntity 批量更新实体(辅助方法)
func (r *BaseRepositoryImpl) UpdateBatchEntity(ctx context.Context, entities interface{}) error {
return r.GetDB(ctx).Save(entities).Error
}
// DeleteBatchEntity 批量删除实体(辅助方法)
func (r *BaseRepositoryImpl) DeleteBatchEntity(ctx context.Context, ids []string, entity interface{}) error {
return r.GetDB(ctx).Delete(entity, "id IN ?", ids).Error
}
// ================ 软删除辅助方法 ================
// SoftDeleteEntity 软删除实体(辅助方法)
func (r *BaseRepositoryImpl) SoftDeleteEntity(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Delete(entity, "id = ?", id).Error
}
// RestoreEntity 恢复软删除的实体(辅助方法)
func (r *BaseRepositoryImpl) RestoreEntity(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Unscoped().Model(entity).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 高级查询辅助方法 ================
// ListWithOptions 获取实体列表支持ListOptions辅助方法
func (r *BaseRepositoryImpl) ListWithOptions(ctx context.Context, entity interface{}, entities interface{}, options interfaces.ListOptions) error {
query := r.GetDB(ctx).Model(entity)
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件基础实现具体Repository应该重写
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用预加载
for _, include := range options.Include {
query = query.Preload(include)
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
// 默认按创建时间倒序
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return query.Find(entities).Error
}
// CountWithOptions 统计实体数量支持CountOptions辅助方法
func (r *BaseRepositoryImpl) CountWithOptions(ctx context.Context, entity interface{}, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(entity)
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件基础实现具体Repository应该重写
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// ================ 常用查询模式 ================
// FindByField 根据单个字段查找实体列表
func (r *BaseRepositoryImpl) FindByField(ctx context.Context, entities interface{}, field string, value interface{}) error {
return r.GetDB(ctx).Where(field+" = ?", value).Find(entities).Error
}
// FindOneByField 根据单个字段查找单个实体
func (r *BaseRepositoryImpl) FindOneByField(ctx context.Context, entity interface{}, field string, value interface{}) error {
return r.GetDB(ctx).Where(field+" = ?", value).First(entity).Error
}
// CountByField 根据单个字段统计数量
func (r *BaseRepositoryImpl) CountByField(ctx context.Context, entity interface{}, field string, value interface{}) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(entity).Where(field+" = ?", value).Count(&count).Error
return count, err
}
// ExistsByField 根据单个字段检查是否存在
func (r *BaseRepositoryImpl) ExistsByField(ctx context.Context, entity interface{}, field string, value interface{}) (bool, error) {
count, err := r.CountByField(ctx, entity, field, value)
return count > 0, err
}