188 lines
5.2 KiB
Go
188 lines
5.2 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
|
|
"tyapi-server/internal/config"
|
|
"tyapi-server/internal/domains/user/entities"
|
|
"tyapi-server/internal/domains/user/repositories"
|
|
"tyapi-server/internal/shared/interfaces"
|
|
"tyapi-server/internal/shared/sms"
|
|
)
|
|
|
|
// SMSCodeService 短信验证码服务
|
|
type SMSCodeService struct {
|
|
repo *repositories.SMSCodeRepository
|
|
smsClient sms.Service
|
|
cache interfaces.CacheService
|
|
config config.SMSConfig
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewSMSCodeService 创建短信验证码服务
|
|
func NewSMSCodeService(
|
|
repo *repositories.SMSCodeRepository,
|
|
smsClient sms.Service,
|
|
cache interfaces.CacheService,
|
|
config config.SMSConfig,
|
|
logger *zap.Logger,
|
|
) *SMSCodeService {
|
|
return &SMSCodeService{
|
|
repo: repo,
|
|
smsClient: smsClient,
|
|
cache: cache,
|
|
config: config,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// SendCode 发送验证码
|
|
func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent string) error {
|
|
// 检查频率限制
|
|
if err := s.checkRateLimit(ctx, phone); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 生成验证码
|
|
code := s.smsClient.GenerateCode(s.config.CodeLength)
|
|
|
|
// 创建SMS验证码记录
|
|
smsCode := &entities.SMSCode{
|
|
ID: uuid.New().String(),
|
|
Phone: phone,
|
|
Code: code,
|
|
Scene: scene,
|
|
IP: clientIP,
|
|
UserAgent: userAgent,
|
|
Used: false,
|
|
ExpiresAt: time.Now().Add(s.config.ExpireTime),
|
|
}
|
|
|
|
// 保存验证码
|
|
if err := s.repo.Create(ctx, smsCode); err != nil {
|
|
s.logger.Error("保存短信验证码失败",
|
|
zap.String("phone", phone),
|
|
zap.String("scene", string(scene)),
|
|
zap.Error(err))
|
|
return fmt.Errorf("保存验证码失败: %w", err)
|
|
}
|
|
|
|
// 发送短信
|
|
if err := s.smsClient.SendVerificationCode(ctx, phone, code); err != nil {
|
|
// 记录发送失败但不删除验证码记录,让其自然过期
|
|
s.logger.Error("发送短信验证码失败",
|
|
zap.String("phone", phone),
|
|
zap.String("code", code),
|
|
zap.Error(err))
|
|
return fmt.Errorf("短信发送失败: %w", err)
|
|
}
|
|
|
|
// 更新发送记录缓存
|
|
s.updateSendRecord(ctx, phone)
|
|
|
|
s.logger.Info("短信验证码发送成功",
|
|
zap.String("phone", phone),
|
|
zap.String("scene", string(scene)))
|
|
|
|
return nil
|
|
}
|
|
|
|
// VerifyCode 验证验证码
|
|
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, code string, scene entities.SMSScene) error {
|
|
// 根据手机号和场景获取有效的验证码记录
|
|
smsCode, err := s.repo.GetValidCode(ctx, phone, scene)
|
|
if err != nil {
|
|
return fmt.Errorf("验证码无效或已过期")
|
|
}
|
|
|
|
// 验证验证码是否匹配
|
|
if smsCode.Code != code {
|
|
return fmt.Errorf("验证码无效或已过期")
|
|
}
|
|
|
|
// 标记验证码为已使用
|
|
if err := s.repo.MarkAsUsed(ctx, smsCode.ID); err != nil {
|
|
s.logger.Error("标记验证码为已使用失败",
|
|
zap.String("code_id", smsCode.ID),
|
|
zap.Error(err))
|
|
return fmt.Errorf("验证码状态更新失败")
|
|
}
|
|
|
|
s.logger.Info("短信验证码验证成功",
|
|
zap.String("phone", phone),
|
|
zap.String("scene", string(scene)))
|
|
|
|
return nil
|
|
}
|
|
|
|
// checkRateLimit 检查发送频率限制
|
|
func (s *SMSCodeService) checkRateLimit(ctx context.Context, phone string) error {
|
|
now := time.Now()
|
|
|
|
// 检查最小发送间隔
|
|
lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone)
|
|
var lastSent time.Time
|
|
if err := s.cache.Get(ctx, lastSentKey, &lastSent); err == nil {
|
|
if now.Sub(lastSent) < s.config.RateLimit.MinInterval {
|
|
return fmt.Errorf("请等待 %v 后再试", s.config.RateLimit.MinInterval)
|
|
}
|
|
}
|
|
|
|
// 检查每小时发送限制
|
|
hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215"))
|
|
var hourlyCount int
|
|
if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil {
|
|
if hourlyCount >= s.config.RateLimit.HourlyLimit {
|
|
return fmt.Errorf("每小时最多发送 %d 条短信", s.config.RateLimit.HourlyLimit)
|
|
}
|
|
}
|
|
|
|
// 检查每日发送限制
|
|
dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102"))
|
|
var dailyCount int
|
|
if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil {
|
|
if dailyCount >= s.config.RateLimit.DailyLimit {
|
|
return fmt.Errorf("每日最多发送 %d 条短信", s.config.RateLimit.DailyLimit)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// updateSendRecord 更新发送记录
|
|
func (s *SMSCodeService) updateSendRecord(ctx context.Context, phone string) {
|
|
now := time.Now()
|
|
|
|
// 更新最后发送时间
|
|
lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone)
|
|
s.cache.Set(ctx, lastSentKey, now, s.config.RateLimit.MinInterval)
|
|
|
|
// 更新每小时计数
|
|
hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215"))
|
|
var hourlyCount int
|
|
if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil {
|
|
s.cache.Set(ctx, hourlyKey, hourlyCount+1, time.Hour)
|
|
} else {
|
|
s.cache.Set(ctx, hourlyKey, 1, time.Hour)
|
|
}
|
|
|
|
// 更新每日计数
|
|
dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102"))
|
|
var dailyCount int
|
|
if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil {
|
|
s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour)
|
|
} else {
|
|
s.cache.Set(ctx, dailyKey, 1, 24*time.Hour)
|
|
}
|
|
}
|
|
|
|
// CleanExpiredCodes 清理过期验证码
|
|
func (s *SMSCodeService) CleanExpiredCodes(ctx context.Context) error {
|
|
return s.repo.CleanupExpired(ctx)
|
|
}
|