package services import ( "context" "fmt" "time" "github.com/google/uuid" "go.uber.org/zap" "tyapi-server/internal/config" "tyapi-server/internal/domains/user/entities" "tyapi-server/internal/domains/user/repositories" "tyapi-server/internal/shared/interfaces" "tyapi-server/internal/shared/sms" ) // SMSCodeService 短信验证码服务 type SMSCodeService struct { repo *repositories.SMSCodeRepository smsClient sms.Service cache interfaces.CacheService config config.SMSConfig logger *zap.Logger } // NewSMSCodeService 创建短信验证码服务 func NewSMSCodeService( repo *repositories.SMSCodeRepository, smsClient sms.Service, cache interfaces.CacheService, config config.SMSConfig, logger *zap.Logger, ) *SMSCodeService { return &SMSCodeService{ repo: repo, smsClient: smsClient, cache: cache, config: config, logger: logger, } } // SendCode 发送验证码 func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent string) error { // 检查频率限制 if err := s.checkRateLimit(ctx, phone); err != nil { return err } // 生成验证码 code := s.smsClient.GenerateCode(s.config.CodeLength) // 创建SMS验证码记录 smsCode := &entities.SMSCode{ ID: uuid.New().String(), Phone: phone, Code: code, Scene: scene, IP: clientIP, UserAgent: userAgent, Used: false, ExpiresAt: time.Now().Add(s.config.ExpireTime), } // 保存验证码 if err := s.repo.Create(ctx, smsCode); err != nil { s.logger.Error("保存短信验证码失败", zap.String("phone", phone), zap.String("scene", string(scene)), zap.Error(err)) return fmt.Errorf("保存验证码失败: %w", err) } // 发送短信 if err := s.smsClient.SendVerificationCode(ctx, phone, code); err != nil { // 记录发送失败但不删除验证码记录,让其自然过期 s.logger.Error("发送短信验证码失败", zap.String("phone", phone), zap.String("code", code), zap.Error(err)) return fmt.Errorf("短信发送失败: %w", err) } // 更新发送记录缓存 s.updateSendRecord(ctx, phone) s.logger.Info("短信验证码发送成功", zap.String("phone", phone), zap.String("scene", string(scene))) return nil } // VerifyCode 验证验证码 func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, code string, scene entities.SMSScene) error { // 根据手机号和场景获取有效的验证码记录 smsCode, err := s.repo.GetValidCode(ctx, phone, scene) if err != nil { return fmt.Errorf("验证码无效或已过期") } // 验证验证码是否匹配 if smsCode.Code != code { return fmt.Errorf("验证码无效或已过期") } // 标记验证码为已使用 if err := s.repo.MarkAsUsed(ctx, smsCode.ID); err != nil { s.logger.Error("标记验证码为已使用失败", zap.String("code_id", smsCode.ID), zap.Error(err)) return fmt.Errorf("验证码状态更新失败") } s.logger.Info("短信验证码验证成功", zap.String("phone", phone), zap.String("scene", string(scene))) return nil } // checkRateLimit 检查发送频率限制 func (s *SMSCodeService) checkRateLimit(ctx context.Context, phone string) error { now := time.Now() // 检查最小发送间隔 lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone) var lastSent time.Time if err := s.cache.Get(ctx, lastSentKey, &lastSent); err == nil { if now.Sub(lastSent) < s.config.RateLimit.MinInterval { return fmt.Errorf("请等待 %v 后再试", s.config.RateLimit.MinInterval) } } // 检查每小时发送限制 hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215")) var hourlyCount int if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil { if hourlyCount >= s.config.RateLimit.HourlyLimit { return fmt.Errorf("每小时最多发送 %d 条短信", s.config.RateLimit.HourlyLimit) } } // 检查每日发送限制 dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102")) var dailyCount int if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil { if dailyCount >= s.config.RateLimit.DailyLimit { return fmt.Errorf("每日最多发送 %d 条短信", s.config.RateLimit.DailyLimit) } } return nil } // updateSendRecord 更新发送记录 func (s *SMSCodeService) updateSendRecord(ctx context.Context, phone string) { now := time.Now() // 更新最后发送时间 lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone) s.cache.Set(ctx, lastSentKey, now, s.config.RateLimit.MinInterval) // 更新每小时计数 hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215")) var hourlyCount int if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil { s.cache.Set(ctx, hourlyKey, hourlyCount+1, time.Hour) } else { s.cache.Set(ctx, hourlyKey, 1, time.Hour) } // 更新每日计数 dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102")) var dailyCount int if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil { s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour) } else { s.cache.Set(ctx, dailyKey, 1, 24*time.Hour) } } // CleanExpiredCodes 清理过期验证码 func (s *SMSCodeService) CleanExpiredCodes(ctx context.Context) error { return s.repo.CleanupExpired(ctx) }