package repositories import ( "context" "fmt" "time" "go.uber.org/zap" "gorm.io/gorm" "tyapi-server/internal/domains/user/entities" "tyapi-server/internal/shared/interfaces" ) // SMSCodeRepository 短信验证码仓储 type SMSCodeRepository struct { db *gorm.DB cache interfaces.CacheService logger *zap.Logger } // NewSMSCodeRepository 创建短信验证码仓储 func NewSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *SMSCodeRepository { return &SMSCodeRepository{ db: db, cache: cache, logger: logger, } } // Create 创建短信验证码记录 func (r *SMSCodeRepository) Create(ctx context.Context, smsCode *entities.SMSCode) error { if err := r.db.WithContext(ctx).Create(smsCode).Error; err != nil { r.logger.Error("创建短信验证码失败", zap.Error(err)) return err } // 缓存验证码 cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene) r.cache.Set(ctx, cacheKey, smsCode, 5*time.Minute) return nil } // GetValidCode 获取有效的验证码 func (r *SMSCodeRepository) GetValidCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) { // 先从缓存查找 cacheKey := r.buildCacheKey(phone, scene) var smsCode entities.SMSCode if err := r.cache.Get(ctx, cacheKey, &smsCode); err == nil { return &smsCode, nil } // 从数据库查找最新的有效验证码 if err := r.db.WithContext(ctx). Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()). Order("created_at DESC"). First(&smsCode).Error; err != nil { return nil, err } // 缓存结果 r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute) return &smsCode, nil } // MarkAsUsed 标记验证码为已使用 func (r *SMSCodeRepository) MarkAsUsed(ctx context.Context, id string) error { now := time.Now() if err := r.db.WithContext(ctx). Model(&entities.SMSCode{}). Where("id = ?", id). Update("used_at", now).Error; err != nil { r.logger.Error("标记验证码为已使用失败", zap.Error(err)) return err } r.logger.Info("验证码已标记为使用", zap.String("code_id", id)) return nil } // CleanupExpired 清理过期的验证码 func (r *SMSCodeRepository) CleanupExpired(ctx context.Context) error { result := r.db.WithContext(ctx). Where("expires_at < ?", time.Now()). Delete(&entities.SMSCode{}) if result.Error != nil { r.logger.Error("清理过期验证码失败", zap.Error(result.Error)) return result.Error } if result.RowsAffected > 0 { r.logger.Info("清理过期验证码完成", zap.Int64("count", result.RowsAffected)) } return nil } // CountRecentCodes 统计最近发送的验证码数量 func (r *SMSCodeRepository) CountRecentCodes(ctx context.Context, phone string, scene entities.SMSScene, duration time.Duration) (int64, error) { var count int64 if err := r.db.WithContext(ctx). Model(&entities.SMSCode{}). Where("phone = ? AND scene = ? AND created_at > ?", phone, scene, time.Now().Add(-duration)). Count(&count).Error; err != nil { r.logger.Error("统计最近验证码数量失败", zap.Error(err)) return 0, err } return count, nil } // buildCacheKey 构建缓存键 func (r *SMSCodeRepository) buildCacheKey(phone string, scene entities.SMSScene) string { return fmt.Sprintf("sms_code:%s:%s", phone, string(scene)) }