Files
tyapi-server/internal/domains/user/services/sms_code_service.go

188 lines
5.2 KiB
Go

package services
import (
	"context"
	"crypto/subtle"
	"fmt"
	"time"

	"github.com/google/uuid"
	"go.uber.org/zap"

	"tyapi-server/internal/config"
	"tyapi-server/internal/domains/user/entities"
	"tyapi-server/internal/domains/user/repositories"
	"tyapi-server/internal/shared/interfaces"
	"tyapi-server/internal/shared/sms"
)
// SMSCodeService issues and verifies SMS verification codes.
//
// Codes are persisted through a repository, generated and delivered through
// an SMS client, and per-phone send limits are tracked in a cache.
type SMSCodeService struct {
	// repo stores issued codes and looks up / consumes / purges them.
	repo *repositories.SMSCodeRepository
	// smsClient generates code strings and delivers them to phones.
	smsClient sms.Service
	// cache holds the rate-limit bookkeeping (last send time, hourly and
	// daily counters).
	cache interfaces.CacheService
	// config supplies the code length, code lifetime and rate-limit settings.
	config config.SMSConfig
	logger *zap.Logger
}
// NewSMSCodeService builds an SMSCodeService from its dependencies.
// Arguments are stored exactly as given; no validation is performed.
func NewSMSCodeService(
	repo *repositories.SMSCodeRepository,
	smsClient sms.Service,
	cache interfaces.CacheService,
	config config.SMSConfig,
	logger *zap.Logger,
) *SMSCodeService {
	svc := new(SMSCodeService)
	svc.repo = repo
	svc.smsClient = smsClient
	svc.cache = cache
	svc.config = config
	svc.logger = logger
	return svc
}
// SendCode generates a verification code for phone and scene, stores it, and
// delivers it by SMS.
//
// The steps run in this order:
//  1. checkRateLimit; its error is returned as-is if the phone is throttled.
//  2. A code of s.config.CodeLength is generated and saved together with the
//     client IP and user agent. It expires s.config.ExpireTime from now.
//  3. The code is sent. On failure the saved record is kept (it simply
//     expires) and the rate-limit counters are NOT updated.
//  4. On success the rate-limit bookkeeping is updated.
//
// Save and send failures are returned wrapped with %w.
func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent string) error {
	if err := s.checkRateLimit(ctx, phone); err != nil {
		return err
	}

	code := s.smsClient.GenerateCode(s.config.CodeLength)

	smsCode := &entities.SMSCode{
		ID:        uuid.New().String(),
		Phone:     phone,
		Code:      code,
		Scene:     scene,
		IP:        clientIP,
		UserAgent: userAgent,
		Used:      false,
		ExpiresAt: time.Now().Add(s.config.ExpireTime),
	}

	if err := s.repo.Create(ctx, smsCode); err != nil {
		s.logger.Error("保存短信验证码失败",
			zap.String("phone", phone),
			zap.String("scene", string(scene)),
			zap.Error(err))
		return fmt.Errorf("保存验证码失败: %w", err)
	}

	if err := s.smsClient.SendVerificationCode(ctx, phone, code); err != nil {
		// Keep the stored record and let it expire on its own.
		//
		// The code itself is deliberately not logged. It is still valid, and
		// anyone with log access could use it to pass VerifyCode.
		//
		// NOTE(review): counters are not bumped on failure. If the provider
		// reports an error but actually delivered the SMS (e.g. a timeout),
		// those sends escape the rate limit.
		s.logger.Error("发送短信验证码失败",
			zap.String("phone", phone),
			zap.String("scene", string(scene)),
			zap.Error(err))
		return fmt.Errorf("短信发送失败: %w", err)
	}

	s.updateSendRecord(ctx, phone)

	s.logger.Info("短信验证码发送成功",
		zap.String("phone", phone),
		zap.String("scene", string(scene)))
	return nil
}
// VerifyCode checks code against the still-valid stored code for phone and
// scene, as returned by repo.GetValidCode. On a match it marks that record as
// used, so the same record cannot be accepted again.
//
// A failed lookup and a mismatch return the same generic error, so the caller
// cannot tell "no valid code" apart from "wrong code". If marking the record
// as used fails, the verification is reported as failed.
//
// NOTE(review): this method does not count failed attempts, so nothing here
// stops repeated guessing while a code is valid. Consider an attempt limit,
// e.g. a per-phone/scene counter in s.cache.
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, code string, scene entities.SMSScene) error {
	smsCode, err := s.repo.GetValidCode(ctx, phone, scene)
	if err != nil {
		// Callers only get the generic message. Log the cause so operators can
		// tell a storage failure from a user entering a wrong or expired code.
		s.logger.Warn("查询短信验证码失败",
			zap.String("phone", phone),
			zap.String("scene", string(scene)),
			zap.Error(err))
		return fmt.Errorf("验证码无效或已过期")
	}

	// Compare in constant time so response timing does not reveal how much
	// of a guess matched. This has the same equality result as ==.
	if subtle.ConstantTimeCompare([]byte(smsCode.Code), []byte(code)) != 1 {
		return fmt.Errorf("验证码无效或已过期")
	}

	// NOTE(review): whether two concurrent calls can both succeed with the
	// same code depends on MarkAsUsed being conditional on the record still
	// being unused. Confirm in the repository.
	if err := s.repo.MarkAsUsed(ctx, smsCode.ID); err != nil {
		s.logger.Error("标记验证码为已使用失败",
			zap.String("code_id", smsCode.ID),
			zap.Error(err))
		return fmt.Errorf("验证码状态更新失败")
	}

	s.logger.Info("短信验证码验证成功",
		zap.String("phone", phone),
		zap.String("scene", string(scene)))
	return nil
}
// checkRateLimit enforces the per-phone send limits in s.config.RateLimit.
// It consults the cache in this order:
//   - "sms:last_sent:<phone>": rejects if the previous send was less than
//     MinInterval ago;
//   - "sms:hourly:<phone>:<YYYYMMDDHH>": rejects if the stored count is
//     already at or above HourlyLimit;
//   - "sms:daily:<phone>:<YYYYMMDD>": rejects if the stored count is already
//     at or above DailyLimit.
//
// If a cache read returns any error (presumably including a cache miss),
// that check is skipped, so the limiter fails open.
func (s *SMSCodeService) checkRateLimit(ctx context.Context, phone string) error {
	limits := s.config.RateLimit
	current := time.Now()

	var previous time.Time
	lastKey := fmt.Sprintf("sms:last_sent:%s", phone)
	if s.cache.Get(ctx, lastKey, &previous) == nil && current.Sub(previous) < limits.MinInterval {
		return fmt.Errorf("请等待 %v 后再试", limits.MinInterval)
	}

	var perHour int
	hourKey := fmt.Sprintf("sms:hourly:%s:%s", phone, current.Format("2006010215"))
	if s.cache.Get(ctx, hourKey, &perHour) == nil && perHour >= limits.HourlyLimit {
		return fmt.Errorf("每小时最多发送 %d 条短信", limits.HourlyLimit)
	}

	var perDay int
	dayKey := fmt.Sprintf("sms:daily:%s:%s", phone, current.Format("20060102"))
	if s.cache.Get(ctx, dayKey, &perDay) == nil && perDay >= limits.DailyLimit {
		return fmt.Errorf("每日最多发送 %d 条短信", limits.DailyLimit)
	}

	return nil
}
// updateSendRecord records a successful send for phone so that checkRateLimit
// can throttle later requests. It updates:
//   - "sms:last_sent:<phone>": set to now, expiring after
//     RateLimit.MinInterval;
//   - "sms:hourly:<phone>:<YYYYMMDDHH>": incremented, with a one-hour TTL;
//   - "sms:daily:<phone>:<YYYYMMDD>": incremented, with a 24-hour TTL.
//
// This is best-effort. Cache write failures are logged and never returned,
// so a delivered SMS is not reported as failed because the bookkeeping broke.
//
// NOTE(review): each counter update is a read followed by a write, not an
// atomic increment. Concurrent sends for the same phone can lose an increment
// and under-count.
func (s *SMSCodeService) updateSendRecord(ctx context.Context, phone string) {
	now := time.Now()

	lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone)
	if err := s.cache.Set(ctx, lastSentKey, now, s.config.RateLimit.MinInterval); err != nil {
		s.logger.Warn("更新短信发送时间缓存失败",
			zap.String("key", lastSentKey),
			zap.Error(err))
	}

	s.bumpSendCounter(ctx, fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215")), time.Hour)
	s.bumpSendCounter(ctx, fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102")), 24*time.Hour)
}

// bumpSendCounter increments the integer counter at key and (re)sets its TTL.
// If the key cannot be read, it writes 1. Write failures are logged, not
// returned.
func (s *SMSCodeService) bumpSendCounter(ctx context.Context, key string, ttl time.Duration) {
	var count int
	if err := s.cache.Get(ctx, key, &count); err != nil {
		// Treat an unreadable/missing counter as zero so this send counts as 1.
		count = 0
	}
	if err := s.cache.Set(ctx, key, count+1, ttl); err != nil {
		s.logger.Warn("更新短信发送计数缓存失败",
			zap.String("key", key),
			zap.Error(err))
	}
}
// CleanExpiredCodes purges expired verification codes by delegating to the
// repository's CleanupExpired. Its error, if any, is returned unchanged.
func (s *SMSCodeService) CleanExpiredCodes(ctx context.Context) error {
	err := s.repo.CleanupExpired(ctx)
	return err
}