Files
tyapi-server/internal/shared/database/base_repository.go
2025-07-20 20:53:26 +08:00

247 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package database
import (
"context"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// BaseRepositoryImpl 基础仓储实现
// 提供统一的数据库连接、事务处理和通用辅助方法
type BaseRepositoryImpl struct {
db *gorm.DB
logger *zap.Logger
}
// NewBaseRepositoryImpl 创建基础仓储实现
func NewBaseRepositoryImpl(db *gorm.DB, logger *zap.Logger) *BaseRepositoryImpl {
return &BaseRepositoryImpl{
db: db,
logger: logger,
}
}
// ================ 核心工具方法 ================
// GetDB 获取数据库连接,优先使用事务
// 这是Repository层统一的数据库连接获取方法
func (r *BaseRepositoryImpl) GetDB(ctx context.Context) *gorm.DB {
if tx, ok := GetTx(ctx); ok {
return tx.WithContext(ctx)
}
return r.db.WithContext(ctx)
}
// GetLogger 获取日志记录器
func (r *BaseRepositoryImpl) GetLogger() *zap.Logger {
return r.logger
}
// WithTx 使用事务创建新的Repository实例
func (r *BaseRepositoryImpl) WithTx(tx *gorm.DB) *BaseRepositoryImpl {
return &BaseRepositoryImpl{
db: tx,
logger: r.logger,
}
}
// ExecuteInTransaction 在事务中执行函数
func (r *BaseRepositoryImpl) ExecuteInTransaction(ctx context.Context, fn func(*gorm.DB) error) error {
db := r.GetDB(ctx)
// 如果已经在事务中,直接执行
if _, ok := GetTx(ctx); ok {
return fn(db)
}
// 否则开启新事务
return db.Transaction(fn)
}
// IsInTransaction 检查当前是否在事务中
func (r *BaseRepositoryImpl) IsInTransaction(ctx context.Context) bool {
_, ok := GetTx(ctx)
return ok
}
// ================ 通用查询辅助方法 ================
// FindWhere 根据条件查找实体列表
func (r *BaseRepositoryImpl) FindWhere(ctx context.Context, entities interface{}, condition string, args ...interface{}) error {
return r.GetDB(ctx).Where(condition, args...).Find(entities).Error
}
// FindOne 根据条件查找单个实体
func (r *BaseRepositoryImpl) FindOne(ctx context.Context, entity interface{}, condition string, args ...interface{}) error {
return r.GetDB(ctx).Where(condition, args...).First(entity).Error
}
// CountWhere 根据条件统计数量
func (r *BaseRepositoryImpl) CountWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(entity).Where(condition, args...).Count(&count).Error
return count, err
}
// ExistsWhere 根据条件检查是否存在
func (r *BaseRepositoryImpl) ExistsWhere(ctx context.Context, entity interface{}, condition string, args ...interface{}) (bool, error) {
count, err := r.CountWhere(ctx, entity, condition, args...)
return count > 0, err
}
// ================ CRUD辅助方法 ================
// CreateEntity 创建实体(辅助方法)
func (r *BaseRepositoryImpl) CreateEntity(ctx context.Context, entity interface{}) error {
return r.GetDB(ctx).Create(entity).Error
}
// GetEntityByID 根据ID获取实体辅助方法
func (r *BaseRepositoryImpl) GetEntityByID(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Where("id = ?", id).First(entity).Error
}
// UpdateEntity 更新实体(辅助方法)
func (r *BaseRepositoryImpl) UpdateEntity(ctx context.Context, entity interface{}) error {
return r.GetDB(ctx).Save(entity).Error
}
// DeleteEntity 删除实体(辅助方法)
func (r *BaseRepositoryImpl) DeleteEntity(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Delete(entity, "id = ?", id).Error
}
// ExistsEntity 检查实体是否存在(辅助方法)
func (r *BaseRepositoryImpl) ExistsEntity(ctx context.Context, id string, entity interface{}) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(entity).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// ================ 批量操作辅助方法 ================
// CreateBatchEntity 批量创建实体(辅助方法)
func (r *BaseRepositoryImpl) CreateBatchEntity(ctx context.Context, entities interface{}) error {
return r.GetDB(ctx).Create(entities).Error
}
// GetEntitiesByIDs 根据ID列表获取实体辅助方法
func (r *BaseRepositoryImpl) GetEntitiesByIDs(ctx context.Context, ids []string, entities interface{}) error {
return r.GetDB(ctx).Where("id IN ?", ids).Find(entities).Error
}
// UpdateBatchEntity 批量更新实体(辅助方法)
func (r *BaseRepositoryImpl) UpdateBatchEntity(ctx context.Context, entities interface{}) error {
return r.GetDB(ctx).Save(entities).Error
}
// DeleteBatchEntity 批量删除实体(辅助方法)
func (r *BaseRepositoryImpl) DeleteBatchEntity(ctx context.Context, ids []string, entity interface{}) error {
return r.GetDB(ctx).Delete(entity, "id IN ?", ids).Error
}
// ================ 软删除辅助方法 ================
// SoftDeleteEntity 软删除实体(辅助方法)
func (r *BaseRepositoryImpl) SoftDeleteEntity(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Delete(entity, "id = ?", id).Error
}
// RestoreEntity 恢复软删除的实体(辅助方法)
func (r *BaseRepositoryImpl) RestoreEntity(ctx context.Context, id string, entity interface{}) error {
return r.GetDB(ctx).Unscoped().Model(entity).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 高级查询辅助方法 ================
// ListWithOptions 获取实体列表支持ListOptions辅助方法
func (r *BaseRepositoryImpl) ListWithOptions(ctx context.Context, entity interface{}, entities interface{}, options interfaces.ListOptions) error {
query := r.GetDB(ctx).Model(entity)
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件基础实现具体Repository应该重写
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用预加载
for _, include := range options.Include {
query = query.Preload(include)
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
// 默认按创建时间倒序
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return query.Find(entities).Error
}
// CountWithOptions 统计实体数量支持CountOptions辅助方法
func (r *BaseRepositoryImpl) CountWithOptions(ctx context.Context, entity interface{}, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(entity)
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件基础实现具体Repository应该重写
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// ================ 常用查询模式 ================
// FindByField 根据单个字段查找实体列表
func (r *BaseRepositoryImpl) FindByField(ctx context.Context, entities interface{}, field string, value interface{}) error {
return r.GetDB(ctx).Where(field+" = ?", value).Find(entities).Error
}
// FindOneByField 根据单个字段查找单个实体
func (r *BaseRepositoryImpl) FindOneByField(ctx context.Context, entity interface{}, field string, value interface{}) error {
return r.GetDB(ctx).Where(field+" = ?", value).First(entity).Error
}
// CountByField 根据单个字段统计数量
func (r *BaseRepositoryImpl) CountByField(ctx context.Context, entity interface{}, field string, value interface{}) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(entity).Where(field+" = ?", value).Count(&count).Error
return count, err
}
// ExistsByField 根据单个字段检查是否存在
func (r *BaseRepositoryImpl) ExistsByField(ctx context.Context, entity interface{}, field string, value interface{}) (bool, error) {
count, err := r.CountByField(ctx, entity, field, value)
return count > 0, err
}