121 lines
3.3 KiB
Go
121 lines
3.3 KiB
Go
package repositories
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"time"
|
|
|
|
"go.uber.org/zap"
|
|
"gorm.io/gorm"
|
|
|
|
"tyapi-server/internal/domains/user/entities"
|
|
"tyapi-server/internal/shared/interfaces"
|
|
)
|
|
|
|
// SMSCodeRepository 短信验证码仓储
|
|
type SMSCodeRepository struct {
|
|
db *gorm.DB
|
|
cache interfaces.CacheService
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewSMSCodeRepository 创建短信验证码仓储
|
|
func NewSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *SMSCodeRepository {
|
|
return &SMSCodeRepository{
|
|
db: db,
|
|
cache: cache,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Create 创建短信验证码记录
|
|
func (r *SMSCodeRepository) Create(ctx context.Context, smsCode *entities.SMSCode) error {
|
|
if err := r.db.WithContext(ctx).Create(smsCode).Error; err != nil {
|
|
r.logger.Error("创建短信验证码失败", zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
// 缓存验证码
|
|
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
|
|
r.cache.Set(ctx, cacheKey, smsCode, 5*time.Minute)
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetValidCode 获取有效的验证码
|
|
func (r *SMSCodeRepository) GetValidCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
|
// 先从缓存查找
|
|
cacheKey := r.buildCacheKey(phone, scene)
|
|
var smsCode entities.SMSCode
|
|
if err := r.cache.Get(ctx, cacheKey, &smsCode); err == nil {
|
|
return &smsCode, nil
|
|
}
|
|
|
|
// 从数据库查找最新的有效验证码
|
|
if err := r.db.WithContext(ctx).
|
|
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
|
|
phone, scene, time.Now()).
|
|
Order("created_at DESC").
|
|
First(&smsCode).Error; err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 缓存结果
|
|
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
|
|
|
|
return &smsCode, nil
|
|
}
|
|
|
|
// MarkAsUsed 标记验证码为已使用
|
|
func (r *SMSCodeRepository) MarkAsUsed(ctx context.Context, id string) error {
|
|
now := time.Now()
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&entities.SMSCode{}).
|
|
Where("id = ?", id).
|
|
Update("used_at", now).Error; err != nil {
|
|
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
|
|
return err
|
|
}
|
|
|
|
r.logger.Info("验证码已标记为使用", zap.String("code_id", id))
|
|
return nil
|
|
}
|
|
|
|
// CleanupExpired 清理过期的验证码
|
|
func (r *SMSCodeRepository) CleanupExpired(ctx context.Context) error {
|
|
result := r.db.WithContext(ctx).
|
|
Where("expires_at < ?", time.Now()).
|
|
Delete(&entities.SMSCode{})
|
|
|
|
if result.Error != nil {
|
|
r.logger.Error("清理过期验证码失败", zap.Error(result.Error))
|
|
return result.Error
|
|
}
|
|
|
|
if result.RowsAffected > 0 {
|
|
r.logger.Info("清理过期验证码完成", zap.Int64("count", result.RowsAffected))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CountRecentCodes 统计最近发送的验证码数量
|
|
func (r *SMSCodeRepository) CountRecentCodes(ctx context.Context, phone string, scene entities.SMSScene, duration time.Duration) (int64, error) {
|
|
var count int64
|
|
if err := r.db.WithContext(ctx).
|
|
Model(&entities.SMSCode{}).
|
|
Where("phone = ? AND scene = ? AND created_at > ?",
|
|
phone, scene, time.Now().Add(-duration)).
|
|
Count(&count).Error; err != nil {
|
|
r.logger.Error("统计最近验证码数量失败", zap.Error(err))
|
|
return 0, err
|
|
}
|
|
|
|
return count, nil
|
|
}
|
|
|
|
// buildCacheKey 构建缓存键
|
|
func (r *SMSCodeRepository) buildCacheKey(phone string, scene entities.SMSScene) string {
|
|
return fmt.Sprintf("sms_code:%s:%s", phone, string(scene))
|
|
}
|