This commit is contained in:
2025-07-20 20:53:26 +08:00
parent 83bf9aea7d
commit 8ad1d7288e
158 changed files with 18156 additions and 13188 deletions

View File

@@ -2,6 +2,7 @@ package repositories
import (
"context"
"errors"
"fmt"
"time"
@@ -40,7 +41,7 @@ func (r *GormEnterpriseInfoRepository) Create(ctx context.Context, enterpriseInf
func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&enterpriseInfo).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.EnterpriseInfo{}, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
@@ -51,11 +52,6 @@ func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (
// Update 更新企业信息
func (r *GormEnterpriseInfoRepository) Update(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) error {
// 检查企业信息是否已认证完成,认证完成后不可修改
if enterpriseInfo.IsReadOnly() {
return fmt.Errorf("企业信息已认证完成,不可修改")
}
if err := r.db.WithContext(ctx).Save(&enterpriseInfo).Error; err != nil {
r.logger.Error("更新企业信息失败", zap.Error(err))
return fmt.Errorf("更新企业信息失败: %w", err)
@@ -94,7 +90,7 @@ func (r *GormEnterpriseInfoRepository) Restore(ctx context.Context, id string) e
func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&enterpriseInfo).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
@@ -107,7 +103,7 @@ func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID s
func (r *GormEnterpriseInfoRepository) GetByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string) (*entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("unified_social_code = ?", unifiedSocialCode).First(&enterpriseInfo).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))

View File

@@ -5,6 +5,7 @@ package repositories
import (
"context"
"errors"
"fmt"
"time"
@@ -17,18 +18,16 @@ import (
"tyapi-server/internal/shared/interfaces"
)
// SMSCodeRepository 短信验证码仓储
// GormSMSCodeRepository 短信验证码GORM仓储实现无缓存确保安全性
type GormSMSCodeRepository struct {
db *gorm.DB
cache interfaces.CacheService
logger *zap.Logger
}
// NewGormSMSCodeRepository 创建短信验证码仓储
func NewGormSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) repositories.SMSCodeRepository {
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
return &GormSMSCodeRepository{
db: db,
cache: cache,
logger: logger,
}
}
@@ -36,19 +35,15 @@ func NewGormSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger
// 确保 GormSMSCodeRepository 实现了 SMSCodeRepository 接口
var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
// ================ 基础CRUD操作 ================
// ================ Repository[T] 接口实现 ================
// Create 创建短信验证码记录
// Create 创建短信验证码记录(不缓存,确保安全性)
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
// 缓存验证码
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
return smsCode, nil
}
@@ -56,7 +51,7 @@ func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMS
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&smsCode).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
}
r.logger.Error("获取短信验证码失败", zap.Error(err))
@@ -69,74 +64,15 @@ func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entitie
// Update 更新验证码记录
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
if err := r.db.WithContext(ctx).Save(&smsCode).Error; err != nil {
r.logger.Error("更新验证码记录失败", zap.Error(err))
r.logger.Error("更新短信验证码失败", zap.Error(err))
return err
}
// 更新缓存
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
r.logger.Info("验证码记录更新成功", zap.String("code_id", smsCode.ID))
return nil
}
// Delete 删除短信验证码
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除短信验证码失败", zap.Error(err))
return err
}
r.logger.Info("短信验证码删除成功", zap.String("id", id))
return nil
}
// SoftDelete 软删除短信验证码
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
return r.Delete(ctx, id)
}
// Restore 恢复短信验证码
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复短信验证码失败", zap.Error(err))
return err
}
r.logger.Info("短信验证码恢复成功", zap.String("id", id))
return nil
}
// Count 统计短信验证码数量
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("phone LIKE ? OR code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// Exists 检查短信验证码是否存在
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建短信验证码
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
r.logger.Info("批量创建短信验证码", zap.Int("count", len(smsCodes)))
return r.db.WithContext(ctx).Create(&smsCodes).Error
}
@@ -149,13 +85,11 @@ func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]e
// UpdateBatch 批量更新短信验证码
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
r.logger.Info("批量更新短信验证码", zap.Int("count", len(smsCodes)))
return r.db.WithContext(ctx).Save(&smsCodes).Error
}
// DeleteBatch 批量删除短信验证码
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除短信验证码", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
}
@@ -164,24 +98,35 @@ func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.Lis
var smsCodes []entities.SMSCode
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("phone LIKE ? OR code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
}
// 应用预加载
for _, include := range options.Include {
query = query.Preload(include)
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
@@ -190,97 +135,148 @@ func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.Lis
return smsCodes, query.Find(&smsCodes).Error
}
// WithTx 使用事务
func (r *GormSMSCodeRepository) WithTx(tx interface{}) interfaces.Repository[entities.SMSCode] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormSMSCodeRepository{
db: gormTx,
cache: r.cache,
logger: r.logger,
}
}
return r
// ================ BaseRepository 接口实现 ================
// Delete 删除短信验证码
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// ================ 业务方法 ================
// Exists 检查短信验证码是否存在
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// Count 统计短信验证码数量
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// SoftDelete 软删除短信验证码
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// Restore 恢复短信验证码
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
// GetByPhone 根据手机号获取短信验证码
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if err == gorm.ErrRecordNotFound {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
r.logger.Error("根据手机号获取短信验证码失败", zap.Error(err))
return nil, err
}
return &smsCode, nil
}
// GetLatestByPhone 根据手机号获取最新短信验证码
// GetLatestByPhone 根据手机号获取最新短信验证码
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
return r.GetByPhone(ctx, phone)
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
return nil, err
}
return &smsCode, nil
}
// GetValidByPhone 根据手机号获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
return r.GetValidCode(ctx, phone, "")
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("有效的短信验证码不存在")
}
return nil, err
}
return &smsCode, nil
}
// GetValidByPhoneAndScene 根据手机号和场景获取有效的验证码
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
return r.GetValidCode(ctx, phone, scene)
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("有效的短信验证码不存在")
}
return nil, err
}
return &smsCode, nil
}
// ListSMSCodes 获取短信验证码列表(带分页和筛选)
func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries.ListSMSCodesQuery) ([]*entities.SMSCode, int64, error) {
var smsCodes []entities.SMSCode
var smsCodes []*entities.SMSCode
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.SMSCode{})
// 构建查询条件
db := r.db.WithContext(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if query.Phone != "" {
dbQuery = dbQuery.Where("phone = ?", query.Phone)
db = db.Where("phone = ?", query.Phone)
}
if query.Purpose != "" {
dbQuery = dbQuery.Where("scene = ?", query.Purpose)
db = db.Where("scene = ?", query.Purpose)
}
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
db = db.Where("used = ?", query.Status == "used")
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
db = db.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
db = db.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
if err := db.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&smsCodes).Error; err != nil {
if err := db.Offset(offset).Limit(query.PageSize).Order("created_at DESC").Find(&smsCodes).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
smsCodePtrs := make([]*entities.SMSCode, len(smsCodes))
for i := range smsCodes {
smsCodePtrs[i] = &smsCodes[i]
}
return smsCodePtrs, total, nil
return smsCodes, total, nil
}
// CreateCode 创建验证码
@@ -288,87 +284,63 @@ func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, co
smsCode := entities.SMSCode{
Phone: phone,
Code: code,
Scene: entities.SMSScene(purpose),
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟
Used: false,
Scene: entities.SMSScene(purpose), // 使用Scene字段
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效
}
return r.Create(ctx, smsCode)
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
return smsCode, nil
}
// ValidateCode 验证验证码
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
phone, code, purpose, time.Now()).
First(&smsCode).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return false, nil
}
r.logger.Error("验证验证码失败", zap.Error(err))
return false, err
}
// 标记为已使用
if err := r.MarkAsUsed(ctx, smsCode.ID); err != nil {
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
return false, err
}
return true, nil
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
Count(&count).Error
return count > 0, err
}
// InvalidateCode 使验证码失效
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
now := time.Now()
return r.db.WithContext(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND used_at IS NULL", phone).
Update("used_at", time.Now()).Error; err != nil {
r.logger.Error("使验证码失效失败", zap.Error(err))
return err
}
// 清除缓存
cacheKey := r.buildCacheKey(phone, "")
r.cache.Delete(ctx, cacheKey)
r.logger.Info("验证码已失效", zap.String("phone", phone))
return nil
Update("used_at", &now).Error
}
// CheckSendFrequency 检查发送频率
func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone string, purpose string) (bool, error) {
// 检查最近1分钟内是否已发送
// 检查1分钟内是否已发送
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
Count(&count).Error; err != nil {
r.logger.Error("检查发送频率失败", zap.Error(err))
return false, err
}
return count == 0, nil
Count(&count).Error
// 如果1分钟内已发送则返回false不允许发送
return count == 0, err
}
// GetTodaySendCount 获取今日发送
// GetTodaySendCount 获取今日发送数
func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone string) (int64, error) {
today := time.Now().Truncate(24 * time.Hour)
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&count).Error; err != nil {
r.logger.Error("获取今日发送次数失败", zap.Error(err))
return 0, err
}
return count, nil
Count(&count).Error
return count, err
}
// GetCodeStats 获取验证码统计信息
// GetCodeStats 获取验证码统计
func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string, days int) (*repositories.SMSCodeStats, error) {
var stats repositories.SMSCodeStats
@@ -406,94 +378,4 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
}
return &stats, nil
}
// GetValidCode 获取有效的验证码
func (r *GormSMSCodeRepository) GetValidCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
// 先从缓存查找
cacheKey := r.buildCacheKey(phone, scene)
var smsCode entities.SMSCode
if err := r.cache.Get(ctx, cacheKey, &smsCode); err == nil {
return &smsCode, nil
}
// 从数据库查找最新的有效验证码
if err := r.db.WithContext(ctx).
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
phone, scene, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
return &smsCode, nil
}
// MarkAsUsed 标记验证码为已使用
func (r *GormSMSCodeRepository) MarkAsUsed(ctx context.Context, id string) error {
now := time.Now()
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("id = ?", id).
Update("used_at", now).Error; err != nil {
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
return err
}
r.logger.Info("验证码已标记为使用", zap.String("code_id", id))
return nil
}
// GetRecentCode 获取最近的验证码记录(不限制有效性)
func (r *GormSMSCodeRepository) GetRecentCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
Where("phone = ? AND scene = ?", phone, scene).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
return nil, err
}
return &smsCode, nil
}
// CleanupExpired 清理过期的验证码
func (r *GormSMSCodeRepository) CleanupExpired(ctx context.Context) error {
result := r.db.WithContext(ctx).
Where("expires_at < ?", time.Now()).
Delete(&entities.SMSCode{})
if result.Error != nil {
r.logger.Error("清理过期验证码失败", zap.Error(result.Error))
return result.Error
}
if result.RowsAffected > 0 {
r.logger.Info("清理过期验证码完成", zap.Int64("count", result.RowsAffected))
}
return nil
}
// CountRecentCodes 统计最近发送的验证码数量
func (r *GormSMSCodeRepository) CountRecentCodes(ctx context.Context, phone string, scene entities.SMSScene, duration time.Duration) (int64, error) {
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?",
phone, scene, time.Now().Add(-duration)).
Count(&count).Error; err != nil {
r.logger.Error("统计最近验证码数量失败", zap.Error(err))
return 0, err
}
return count, nil
}
// buildCacheKey 构建缓存键
func (r *GormSMSCodeRepository) buildCacheKey(phone string, scene entities.SMSScene) string {
return fmt.Sprintf("sms_code:%s:%s", phone, string(scene))
}
}

View File

@@ -6,7 +6,6 @@ package repositories
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
@@ -24,18 +23,16 @@ var (
ErrUserNotFound = errors.New("用户不存在")
)
// UserRepository 用户仓储实现
// UserRepository 用户仓储实现(已移除手动缓存管理)
type GormUserRepository struct {
db *gorm.DB
cache interfaces.CacheService
logger *zap.Logger
}
// NewGormUserRepository 创建用户仓储
func NewGormUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) repositories.UserRepository {
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
return &GormUserRepository{
db: db,
cache: cache,
logger: logger,
}
}
@@ -43,34 +40,21 @@ func NewGormUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *z
// 确保 GormUserRepository 实现了 UserRepository 接口
var _ repositories.UserRepository = (*GormUserRepository)(nil)
// ================ 基础CRUD操作 ================
// ================ Repository[T] 接口实现 ================
// Create 创建用户
// Create 创建用户(自动缓存失效)
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
if err := r.db.WithContext(ctx).Create(&user).Error; err != nil {
r.logger.Error("创建用户失败", zap.Error(err))
return entities.User{}, err
}
// 清除相关缓存
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
return user, nil
}
// GetByID 根据ID获取用户
// GetByID 根据ID获取用户(自动缓存)
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
// 尝试从缓存获取
cacheKey := fmt.Sprintf("user:id:%s", id)
var userCache entities.UserCache
if err := r.cache.Get(ctx, cacheKey, &userCache); err == nil {
var user entities.User
user.FromCache(&userCache)
return user, nil
}
// 从数据库查询
var user entities.User
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -80,109 +64,20 @@ func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.U
return entities.User{}, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, user.ToCache(), 10*time.Minute)
return user, nil
}
// Update 更新用户
// Update 更新用户(自动缓存失效)
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
if err := r.db.WithContext(ctx).Save(&user).Error; err != nil {
r.logger.Error("更新用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, user.ID)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
return nil
}
// Delete 删除用户
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
// 先获取用户信息用于清除缓存
user, err := r.GetByID(ctx, id)
if err != nil {
return err
}
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户删除成功", zap.String("user_id", id))
return nil
}
// SoftDelete 软删除用户
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
// 先获取用户信息用于清除缓存
user, err := r.GetByID(ctx, id)
if err != nil {
return err
}
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("软删除用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户软删除成功", zap.String("user_id", id))
return nil
}
// Restore 恢复软删除的用户
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, id)
r.logger.Info("用户恢复成功", zap.String("user_id", id))
return nil
}
// Count 统计用户数量
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.User{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("phone LIKE ? OR nickname LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// Exists 检查用户是否存在
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建用户
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
r.logger.Info("批量创建用户", zap.Int("count", len(users)))
@@ -213,24 +108,35 @@ func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOp
var users []entities.User
query := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("phone LIKE ? OR nickname LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用预加载
for _, include := range options.Include {
query = query.Preload(include)
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
@@ -239,32 +145,61 @@ func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOp
return users, query.Find(&users).Error
}
// WithTx 使用事务
func (r *GormUserRepository) WithTx(tx interface{}) interfaces.Repository[entities.User] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormUserRepository{
db: gormTx,
cache: r.cache,
logger: r.logger,
}
// ================ BaseRepository 接口实现 ================
// Delete 删除用户
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除用户失败", zap.Error(err))
return err
}
return r
r.logger.Info("用户删除成功", zap.String("user_id", id))
return nil
}
// ================ 业务方法 ================
// Exists 检查用户是否存在
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// Count 统计用户数量
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// SoftDelete 软删除用户
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error
}
// Restore 恢复用户
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
// GetByPhone 根据手机号获取用户
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
// 尝试从缓存获取
cacheKey := fmt.Sprintf("user:phone:%s", phone)
var userCache entities.UserCache
if err := r.cache.Get(ctx, cacheKey, &userCache); err == nil {
var user entities.User
user.FromCache(&userCache)
return &user, nil
}
// 从数据库查询
var user entities.User
if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -274,64 +209,68 @@ func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*ent
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, user.ToCache(), 10*time.Minute)
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据用户名查询用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
// GetByUserType 根据用户类型获取用户列表
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
var users []*entities.User
err := r.db.WithContext(ctx).Where("user_type = ?", userType).Find(&users).Error
return users, err
}
// ListUsers 获取用户列表(带分页和筛选)
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
var users []entities.User
var users []*entities.User
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.User{})
// 构建查询条件
db := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if query.Phone != "" {
dbQuery = dbQuery.Where("phone LIKE ?", "%"+query.Phone+"%")
db = db.Where("phone LIKE ?", "%"+query.Phone+"%")
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
db = db.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
db = db.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
if err := db.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&users).Error; err != nil {
if err := db.Offset(offset).Limit(query.PageSize).Find(&users).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
userPtrs := make([]*entities.User, len(users))
for i := range users {
userPtrs[i] = &users[i]
}
return userPtrs, total, nil
return users, total, nil
}
// ValidateUser 验证用户
// ValidateUser 验证用户登录
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("手机号或密码错误")
}
r.logger.Error("验证用户失败", zap.Error(err))
err := r.db.WithContext(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
if err != nil {
return nil, err
}
@@ -340,172 +279,105 @@ func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password s
// UpdateLastLogin 更新最后登录时间
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
now := time.Now()
return r.db.WithContext(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("last_login_at", time.Now()).Error; err != nil {
r.logger.Error("更新最后登录时间失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("最后登录时间更新成功", zap.String("user_id", userID))
return nil
Updates(map[string]interface{}{
"last_login_at": &now,
"updated_at": now,
}).Error
}
// UpdatePassword 更新密码
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
return r.db.WithContext(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("password", newPassword).Error; err != nil {
r.logger.Error("更新密码失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("密码更新成功", zap.String("user_id", userID))
return nil
Update("password", newPassword).Error
}
// CheckPassword 检查密码
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("id = ? AND password = ?", userID, password).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
r.logger.Error("检查密码失败", zap.Error(err))
return false, err
}
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).
Where("id = ? AND password = ?", userID, password).
Count(&count).Error
return true, nil
return count > 0, err
}
// ActivateUser 激活用户
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
return r.db.WithContext(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("status", "ACTIVE").Error; err != nil {
r.logger.Error("激活用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("用户激活成功", zap.String("user_id", userID))
return nil
Update("active", true).Error
}
// DeactivateUser 停用用户
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
return r.db.WithContext(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("status", "INACTIVE").Error; err != nil {
r.logger.Error("停用用户失败", zap.Error(err))
return err
}
Update("active", false).Error
}
// 清除相关缓存
r.deleteCacheByID(ctx, userID)
r.logger.Info("用户停用成功", zap.String("user_id", userID))
return nil
// UpdateLoginStats 更新登录统计
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"login_count": gorm.Expr("login_count + 1"),
"last_login_at": time.Now(),
}).Error
}
// GetStats 获取用户统计信息
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.db.WithContext(ctx)
// 总用户数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
return nil, err
}
// 活跃用户数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("status = ?", "ACTIVE").Count(&stats.ActiveUsers).Error; err != nil {
if err := db.Model(&entities.User{}).Where("active = ?", true).Count(&stats.ActiveUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 今日登录数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetStatsByDateRange 根据日期范围获取用户统计信息
// GetStatsByDateRange 获取指定日期范围用户统计
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
var stats repositories.UserStats
// 总用户数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at BETWEEN ? AND ?", startDate, endDate).Count(&stats.TotalUsers).Error; err != nil {
db := r.db.WithContext(ctx)
// 指定时间范围内的注册数
if err := db.Model(&entities.User{}).
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 活跃用户
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("status = ? AND created_at BETWEEN ? AND ?", "ACTIVE", startDate, endDate).Count(&stats.ActiveUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 今日登录数
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
// 指定时间范围内的登录
if err := db.Model(&entities.User{}).
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// FindByPhone 根据手机号查找用户(兼容旧方法)
func (r *GormUserRepository) FindByPhone(ctx context.Context, phone string) (*entities.User, error) {
return r.GetByPhone(ctx, phone)
}
// ExistsByPhone 检查手机号是否存在
func (r *GormUserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
var count int64
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("phone = ?", phone).Count(&count).Error; err != nil {
r.logger.Error("检查手机号是否存在失败", zap.Error(err))
return false, err
}
return count > 0, nil
}
// 私有辅助方法
// deleteCacheByID 根据ID删除缓存
func (r *GormUserRepository) deleteCacheByID(ctx context.Context, id string) {
cacheKey := fmt.Sprintf("user:id:%s", id)
if err := r.cache.Delete(ctx, cacheKey); err != nil {
r.logger.Warn("删除用户ID缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
}
}
// deleteCacheByPhone 根据手机号删除缓存
func (r *GormUserRepository) deleteCacheByPhone(ctx context.Context, phone string) {
cacheKey := fmt.Sprintf("user:phone:%s", phone)
if err := r.cache.Delete(ctx, cacheKey); err != nil {
r.logger.Warn("删除用户手机号缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
}
}