237 lines
7.1 KiB
Go
237 lines
7.1 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
|
|
"tyapi-server/internal/config"
|
|
"tyapi-server/internal/domains/user/entities"
|
|
"tyapi-server/internal/domains/user/repositories"
|
|
"tyapi-server/internal/shared/interfaces"
|
|
"tyapi-server/internal/shared/sms"
|
|
)
|
|
|
|
// SMSCodeService 短信验证码服务
|
|
type SMSCodeService struct {
|
|
repo *repositories.SMSCodeRepository
|
|
smsClient sms.Service
|
|
cache interfaces.CacheService
|
|
config config.SMSConfig
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewSMSCodeService 创建短信验证码服务
|
|
func NewSMSCodeService(
|
|
repo *repositories.SMSCodeRepository,
|
|
smsClient sms.Service,
|
|
cache interfaces.CacheService,
|
|
config config.SMSConfig,
|
|
logger *zap.Logger,
|
|
) *SMSCodeService {
|
|
return &SMSCodeService{
|
|
repo: repo,
|
|
smsClient: smsClient,
|
|
cache: cache,
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// SendCode 发送验证码
|
|
func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent string) error {
|
|
// 1. 检查频率限制
|
|
if err := s.checkRateLimit(ctx, phone); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 2. 生成验证码
|
|
code := s.smsClient.GenerateCode(s.config.CodeLength)
|
|
|
|
// 3. 使用工厂方法创建SMS验证码记录
|
|
smsCode, err := entities.NewSMSCode(phone, code, scene, s.config.ExpireTime, clientIP, userAgent)
|
|
if err != nil {
|
|
return fmt.Errorf("创建验证码记录失败: %w", err)
|
|
}
|
|
|
|
// 4. 设置ID
|
|
smsCode.ID = uuid.New().String()
|
|
|
|
// 5. 保存验证码
|
|
if err := s.repo.Create(ctx, smsCode); err != nil {
|
|
s.logger.Error("保存短信验证码失败",
|
|
zap.String("phone", smsCode.GetMaskedPhone()),
|
|
zap.String("scene", smsCode.GetSceneName()),
|
|
zap.Error(err))
|
|
return fmt.Errorf("保存验证码失败: %w", err)
|
|
}
|
|
|
|
// 6. 发送短信
|
|
if err := s.smsClient.SendVerificationCode(ctx, phone, code); err != nil {
|
|
// 记录发送失败但不删除验证码记录,让其自然过期
|
|
s.logger.Error("发送短信验证码失败",
|
|
zap.String("phone", smsCode.GetMaskedPhone()),
|
|
zap.String("code", smsCode.GetMaskedCode()),
|
|
zap.Error(err))
|
|
return fmt.Errorf("短信发送失败: %w", err)
|
|
}
|
|
|
|
// 7. 更新发送记录缓存
|
|
s.updateSendRecord(ctx, phone)
|
|
|
|
s.logger.Info("短信验证码发送成功",
|
|
zap.String("phone", smsCode.GetMaskedPhone()),
|
|
zap.String("scene", smsCode.GetSceneName()),
|
|
zap.String("remaining_time", smsCode.GetRemainingTime().String()))
|
|
|
|
return nil
|
|
}
|
|
|
|
// VerifyCode 验证验证码
|
|
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, code string, scene entities.SMSScene) error {
|
|
// 1. 根据手机号和场景获取有效的验证码记录
|
|
smsCode, err := s.repo.GetValidCode(ctx, phone, scene)
|
|
if err != nil {
|
|
return fmt.Errorf("验证码无效或已过期")
|
|
}
|
|
|
|
// 2. 使用实体的验证方法
|
|
if err := smsCode.VerifyCode(code); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 3. 保存更新后的验证码状态
|
|
if err := s.repo.Update(ctx, smsCode); err != nil {
|
|
s.logger.Error("更新验证码状态失败",
|
|
zap.String("code_id", smsCode.ID),
|
|
zap.Error(err))
|
|
return fmt.Errorf("验证码状态更新失败")
|
|
}
|
|
|
|
s.logger.Info("短信验证码验证成功",
|
|
zap.String("phone", smsCode.GetMaskedPhone()),
|
|
zap.String("scene", smsCode.GetSceneName()))
|
|
|
|
return nil
|
|
}
|
|
|
|
// CanResendCode 检查是否可以重新发送验证码
|
|
func (s *SMSCodeService) CanResendCode(ctx context.Context, phone string, scene entities.SMSScene) (bool, error) {
|
|
// 1. 获取最近的验证码记录
|
|
recentCode, err := s.repo.GetRecentCode(ctx, phone, scene)
|
|
if err != nil {
|
|
// 如果没有记录,可以发送
|
|
return true, nil
|
|
}
|
|
|
|
// 2. 使用实体的方法检查是否可以重新发送
|
|
canResend := recentCode.CanResend(s.config.RateLimit.MinInterval)
|
|
|
|
// 3. 记录检查结果
|
|
if !canResend {
|
|
remainingTime := s.config.RateLimit.MinInterval - time.Since(recentCode.CreatedAt)
|
|
s.logger.Info("验证码发送频率限制",
|
|
zap.String("phone", recentCode.GetMaskedPhone()),
|
|
zap.String("scene", recentCode.GetSceneName()),
|
|
zap.Duration("remaining_wait_time", remainingTime))
|
|
}
|
|
|
|
return canResend, nil
|
|
}
|
|
|
|
// GetCodeStatus 获取验证码状态信息
|
|
func (s *SMSCodeService) GetCodeStatus(ctx context.Context, phone string, scene entities.SMSScene) (map[string]interface{}, error) {
|
|
// 1. 获取最近的验证码记录
|
|
recentCode, err := s.repo.GetRecentCode(ctx, phone, scene)
|
|
if err != nil {
|
|
return map[string]interface{}{
|
|
"has_code": false,
|
|
"message": "没有找到验证码记录",
|
|
}, nil
|
|
}
|
|
|
|
// 2. 构建状态信息
|
|
status := map[string]interface{}{
|
|
"has_code": true,
|
|
"is_valid": recentCode.IsValid(),
|
|
"is_expired": recentCode.IsExpired(),
|
|
"is_used": recentCode.Used,
|
|
"remaining_time": recentCode.GetRemainingTime().String(),
|
|
"scene": recentCode.GetSceneName(),
|
|
"can_resend": recentCode.CanResend(s.config.RateLimit.MinInterval),
|
|
"created_at": recentCode.CreatedAt,
|
|
"security_info": recentCode.GetSecurityInfo(),
|
|
}
|
|
|
|
return status, nil
|
|
}
|
|
|
|
// checkRateLimit 检查发送频率限制
|
|
func (s *SMSCodeService) checkRateLimit(ctx context.Context, phone string) error {
|
|
now := time.Now()
|
|
|
|
// 检查最小发送间隔
|
|
lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone)
|
|
var lastSent time.Time
|
|
if err := s.cache.Get(ctx, lastSentKey, &lastSent); err == nil {
|
|
if now.Sub(lastSent) < s.config.RateLimit.MinInterval {
|
|
return fmt.Errorf("请等待 %v 后再试", s.config.RateLimit.MinInterval)
|
|
}
|
|
}
|
|
|
|
// 检查每小时发送限制
|
|
hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215"))
|
|
var hourlyCount int
|
|
if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil {
|
|
if hourlyCount >= s.config.RateLimit.HourlyLimit {
|
|
return fmt.Errorf("每小时最多发送 %d 条短信", s.config.RateLimit.HourlyLimit)
|
|
}
|
|
}
|
|
|
|
// 检查每日发送限制
|
|
dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102"))
|
|
var dailyCount int
|
|
if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil {
|
|
if dailyCount >= s.config.RateLimit.DailyLimit {
|
|
return fmt.Errorf("每日最多发送 %d 条短信", s.config.RateLimit.DailyLimit)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// updateSendRecord 更新发送记录
|
|
func (s *SMSCodeService) updateSendRecord(ctx context.Context, phone string) {
|
|
now := time.Now()
|
|
|
|
// 更新最后发送时间
|
|
lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone)
|
|
s.cache.Set(ctx, lastSentKey, now, s.config.RateLimit.MinInterval)
|
|
|
|
// 更新每小时计数
|
|
hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215"))
|
|
var hourlyCount int
|
|
if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil {
|
|
s.cache.Set(ctx, hourlyKey, hourlyCount+1, time.Hour)
|
|
} else {
|
|
s.cache.Set(ctx, hourlyKey, 1, time.Hour)
|
|
}
|
|
|
|
// 更新每日计数
|
|
dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102"))
|
|
var dailyCount int
|
|
if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil {
|
|
s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour)
|
|
} else {
|
|
s.cache.Set(ctx, dailyKey, 1, 24*time.Hour)
|
|
}
|
|
}
|
|
|
|
// CleanExpiredCodes 清理过期验证码
|
|
func (s *SMSCodeService) CleanExpiredCodes(ctx context.Context) error {
|
|
return s.repo.CleanupExpired(ctx)
|
|
}
|