This commit is contained in:
2025-07-28 01:46:39 +08:00
parent b03129667a
commit 357639462a
219 changed files with 21634 additions and 8138 deletions

View File

@@ -0,0 +1,101 @@
// internal/infrastructure/database/repositories/user/gorm_contract_info_repository.go
package repositories
import (
"context"
"errors"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ContractInfosTable = "contract_infos"
)
type GormContractInfoRepository struct {
*database.CachedBaseRepositoryImpl
}
func NewGormContractInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractInfoRepository {
return &GormContractInfoRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ContractInfosTable),
}
}
func (r *GormContractInfoRepository) Save(ctx context.Context, contract *entities.ContractInfo) error {
return r.CreateEntity(ctx, contract)
}
func (r *GormContractInfoRepository) FindByID(ctx context.Context, contractID string) (*entities.ContractInfo, error) {
var contract entities.ContractInfo
err := r.SmartGetByID(ctx, contractID, &contract)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &contract, nil
}
func (r *GormContractInfoRepository) Delete(ctx context.Context, contractID string) error {
return r.DeleteEntity(ctx, contractID, &entities.ContractInfo{})
}
func (r *GormContractInfoRepository) FindByEnterpriseInfoID(ctx context.Context, enterpriseInfoID string) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("enterprise_info_id = ?", enterpriseInfoID).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("user_id = ?", userID).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) FindByContractType(ctx context.Context, enterpriseInfoID string, contractType entities.ContractType) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("enterprise_info_id = ? AND contract_type = ?", enterpriseInfoID, contractType).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) ExistsByContractFileID(ctx context.Context, contractFileID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ?", contractFileID).Count(&count).Error
return count > 0, err
}
func (r *GormContractInfoRepository) ExistsByContractFileIDExcludeID(ctx context.Context, contractFileID, excludeID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ? AND id != ?", contractFileID, excludeID).Count(&count).Error
return count > 0, err
}

View File

@@ -15,20 +15,23 @@ import (
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/domains/user/repositories/queries"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
)
const (
SMSCodesTable = "sms_codes"
)
// GormSMSCodeRepository 短信验证码GORM仓储实现无缓存确保安全性
type GormSMSCodeRepository struct {
db *gorm.DB
logger *zap.Logger
*database.CachedBaseRepositoryImpl
}
// NewGormSMSCodeRepository 创建短信验证码仓储
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
return &GormSMSCodeRepository{
db: db,
logger: logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SMSCodesTable),
}
}
@@ -39,64 +42,54 @@ var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
// Create 创建短信验证码记录(不缓存,确保安全性)
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
return smsCode, nil
err := r.GetDB(ctx).Create(&smsCode).Error
return smsCode, err
}
// GetByID 根据ID获取短信验证码
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&smsCode).Error; err != nil {
err := r.GetDB(ctx).Where("id = ?", id).First(&smsCode).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
}
r.logger.Error("获取短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
return smsCode, nil
}
// Update 更新验证码记录
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
if err := r.db.WithContext(ctx).Save(&smsCode).Error; err != nil {
r.logger.Error("更新短信验证码失败", zap.Error(err))
return err
}
return nil
return r.GetDB(ctx).Save(&smsCode).Error
}
// CreateBatch 批量创建短信验证码
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
return r.db.WithContext(ctx).Create(&smsCodes).Error
return r.GetDB(ctx).Create(&smsCodes).Error
}
// GetByIDs 根据ID列表获取短信验证码
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
return smsCodes, err
}
// UpdateBatch 批量更新短信验证码
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
return r.db.WithContext(ctx).Save(&smsCodes).Error
return r.GetDB(ctx).Save(&smsCodes).Error
}
// DeleteBatch 批量删除短信验证码
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
}
// List 获取短信验证码列表
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
query := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
@@ -139,20 +132,20 @@ func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.Lis
// Delete 删除短信验证码
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// Exists 检查短信验证码是否存在
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// Count 统计短信验证码数量
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
query := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
@@ -172,12 +165,12 @@ func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.Co
// SoftDelete 软删除短信验证码
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// Restore 恢复短信验证码
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
return r.GetDB(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
@@ -185,7 +178,7 @@ func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
// GetByPhone 根据手机号获取短信验证码
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
@@ -198,7 +191,7 @@ func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*
// GetLatestByPhone 根据手机号获取最新短信验证码
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
@@ -211,7 +204,7 @@ func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone stri
// GetValidByPhone 根据手机号获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
@@ -227,7 +220,7 @@ func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone strin
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
@@ -246,7 +239,7 @@ func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries
var total int64
// 构建查询条件
db := r.db.WithContext(ctx).Model(&entities.SMSCode{})
db := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if query.Phone != "" {
@@ -288,8 +281,8 @@ func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, co
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效期
}
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
if err := r.GetDB(ctx).Create(&smsCode).Error; err != nil {
r.GetLogger().Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
@@ -299,7 +292,7 @@ func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, co
// ValidateCode 验证验证码
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
Count(&count).Error
@@ -309,7 +302,7 @@ func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string,
// InvalidateCode 使验证码失效
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&entities.SMSCode{}).
return r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND used_at IS NULL", phone).
Update("used_at", &now).Error
}
@@ -320,7 +313,7 @@ func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone st
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
Count(&count).Error
@@ -333,7 +326,7 @@ func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone str
today := time.Now().Truncate(24 * time.Hour)
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&count).Error
@@ -348,7 +341,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
startDate := time.Now().AddDate(0, 0, -days)
// 总发送数
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, startDate).
Count(&stats.TotalSent).Error; err != nil {
@@ -356,7 +349,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
}
// 总验证数
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
Count(&stats.TotalValidated).Error; err != nil {
@@ -370,7 +363,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
// 今日发送数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&stats.TodaySent).Error; err != nil {

View File

@@ -14,242 +14,199 @@ import (
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/domains/user/repositories/queries"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
)
const (
UsersTable = "users"
UserCacheTTL = 30 * 60 // 30分钟
)
// 定义错误常量
var (
// ErrUserNotFound 用户不存在错误
ErrUserNotFound = errors.New("用户不存在")
)
// UserRepository 用户仓储实现(已移除手动缓存管理)
type GormUserRepository struct {
db *gorm.DB
logger *zap.Logger
*database.CachedBaseRepositoryImpl
}
// NewGormUserRepository 创建用户仓储
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
return &GormUserRepository{
db: db,
logger: logger,
}
}
// 确保 GormUserRepository 实现了 UserRepository 接口
var _ repositories.UserRepository = (*GormUserRepository)(nil)
// ================ Repository[T] 接口实现 ================
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
return &GormUserRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, UsersTable),
}
}
// Create 创建用户(自动缓存失效)
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
if err := r.db.WithContext(ctx).Create(&user).Error; err != nil {
r.logger.Error("创建用户失败", zap.Error(err))
err := r.CreateEntity(ctx, &user)
return user, err
}
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
var user entities.User
err := r.SmartGetByID(ctx, id, &user)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.User{}, errors.New("用户不存在")
}
return entities.User{}, err
}
r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
return user, nil
}
// GetByID 根据ID获取用户自动缓存
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id string) (entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.User{}, ErrUserNotFound
}
r.logger.Error("根据ID查询用户失败", zap.Error(err))
r.GetLogger().Error("根据ID查询用户失败", zap.Error(err))
return entities.User{}, err
}
return user, nil
}
// Update 更新用户(自动缓存失效)
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
if err := r.db.WithContext(ctx).Save(&user).Error; err != nil {
r.logger.Error("更新用户失败", zap.Error(err))
return err
}
r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
return nil
}
// CreateBatch 批量创建用户
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
r.logger.Info("批量创建用户", zap.Int("count", len(users)))
return r.db.WithContext(ctx).Create(&users).Error
}
// GetByIDs 根据ID列表获取用户
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
var users []entities.User
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error
return users, err
}
// UpdateBatch 批量更新用户
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
r.logger.Info("批量更新用户", zap.Int("count", len(users)))
return r.db.WithContext(ctx).Save(&users).Error
}
// DeleteBatch 批量删除用户
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除用户", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
}
// List 获取用户列表
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
var users []entities.User
query := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用预加载
for _, include := range options.Include {
query = query.Preload(include)
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return users, query.Find(&users).Error
}
// ================ BaseRepository 接口实现 ================
// Delete 删除用户
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除用户失败", zap.Error(err))
return err
}
r.logger.Info("用户删除成功", zap.String("user_id", id))
return nil
}
// Exists 检查用户是否存在
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
query := r.GetDB(ctx).Model(&entities.User{}).
Joins("JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
Where("enterprise_infos.unified_social_code = ?", unifiedSocialCode)
// Count 统计用户数量
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
// 如果指定了排除的用户ID则排除该用户的记录
if excludeUserID != "" {
query = query.Where("users.id != ?", excludeUserID)
}
err := query.Count(&count).Error
if err != nil {
r.GetLogger().Error("检查统一社会信用代码是否存在失败", zap.Error(err))
return false, err
}
return count > 0, nil
}
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
return r.UpdateEntity(ctx, &user)
}
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
r.GetLogger().Info("批量创建用户", zap.Int("count", len(users)))
return r.GetDB(ctx).Create(&users).Error
}
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
var users []entities.User
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&users).Error
return users, err
}
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
r.GetLogger().Info("批量更新用户", zap.Int("count", len(users)))
return r.GetDB(ctx).Save(&users).Error
}
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.GetLogger().Info("批量删除用户", zap.Strings("ids", ids))
return r.GetDB(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
}
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
var users []entities.User
err := r.SmartList(ctx, &users, options)
return users, err
}
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.User{})
}
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
return r.ExistsEntity(ctx, id, &entities.User{})
}
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.User{}).Count(&count).Error
return count, err
}
// SoftDelete 软删除用户
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.User{}, "id = ?", id).Error
}
// Restore 恢复用户
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
return r.GetDB(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
// GetByPhone 根据手机号获取用户
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if err := r.GetDB(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据手机号查询用户失败", zap.Error(err))
r.GetLogger().Error("根据手机号查询用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if err := r.GetDB(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据用户名查询用户失败", zap.Error(err))
r.GetLogger().Error("根据用户名查询用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
// GetByUserType 根据用户类型获取用户列表
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
var users []*entities.User
err := r.db.WithContext(ctx).Where("user_type = ?", userType).Find(&users).Error
err := r.GetDB(ctx).Where("user_type = ?", userType).Find(&users).Error
return users, err
}
// ListUsers 获取用户列表(带分页和筛选)
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
var users []*entities.User
var total int64
// 构建查询条件
db := r.db.WithContext(ctx).Model(&entities.User{})
// 构建查询条件,预加载企业信息
db := r.GetDB(ctx).Model(&entities.User{}).Preload("EnterpriseInfo")
// 应用筛选条件
if query.Phone != "" {
db = db.Where("phone LIKE ?", "%"+query.Phone+"%")
db = db.Where("users.phone LIKE ?", "%"+query.Phone+"%")
}
if query.UserType != "" {
db = db.Where("users.user_type = ?", query.UserType)
}
if query.IsActive != nil {
db = db.Where("users.active = ?", *query.IsActive)
}
if query.IsCertified != nil {
db = db.Where("users.is_certified = ?", *query.IsCertified)
}
if query.CompanyName != "" {
db = db.Joins("LEFT JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
}
if query.StartDate != "" {
db = db.Where("created_at >= ?", query.StartDate)
db = db.Where("users.created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
db = db.Where("created_at <= ?", query.EndDate)
db = db.Where("users.created_at <= ?", query.EndDate)
}
// 统计总数
@@ -266,10 +223,9 @@ func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListU
return users, total, nil
}
// ValidateUser 验证用户登录
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
var user entities.User
err := r.db.WithContext(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
err := r.GetDB(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
if err != nil {
return nil, err
}
@@ -277,10 +233,9 @@ func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password s
return &user, nil
}
// UpdateLastLogin 更新最后登录时间
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"last_login_at": &now,
@@ -288,40 +243,35 @@ func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string)
}).Error
}
// UpdatePassword 更新密码
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("password", newPassword).Error
}
// CheckPassword 检查密码
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).
err := r.GetDB(ctx).Model(&entities.User{}).
Where("id = ? AND password = ?", userID, password).
Count(&count).Error
return count > 0, err
}
// ActivateUser 激活用户
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("active", true).Error
}
// DeactivateUser 停用用户
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("active", false).Error
}
// UpdateLoginStats 更新登录统计
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"login_count": gorm.Expr("login_count + 1"),
@@ -329,11 +279,10 @@ func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string
}).Error
}
// GetStats 获取用户统计信息
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.db.WithContext(ctx)
db := r.GetDB(ctx)
// 总用户数
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
@@ -345,6 +294,11 @@ func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserSt
return nil, err
}
// 已认证用户数
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
@@ -359,11 +313,10 @@ func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserSt
return &stats, nil
}
// GetStatsByDateRange 获取指定日期范围的用户统计
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.db.WithContext(ctx)
db := r.GetDB(ctx)
// 指定时间范围内的注册数
if err := db.Model(&entities.User{}).