Files
tyc-server/app/main/api/internal/logic/auth/sendsmslogic.go
2025-08-31 14:18:31 +08:00

452 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package auth
// 验证码防护说明:
// 1. checkCaptchaProtection: 在发送短信前检查防护状态
// 2. recordCaptchaRequest: 在短信发送成功后记录请求次数
// 3. GetCaptchaProtectionStatus: 获取防护状态用于调试和监控
//
// 防护规则:
// - 单个手机号1分钟内最多1次1小时内最多5次24小时内最多20次
// - 单个IP1分钟内最多10次1小时内最多50次超过阈值后IP被临时封禁
// - 防止验证码爆破攻击,控制短信发送成本
import (
"context"
"fmt"
"math/rand"
"strconv"
"time"
"tyc-server/common/xerr"
"tyc-server/pkg/lzkit/crypto"
"github.com/pkg/errors"
"tyc-server/app/main/api/internal/svc"
"tyc-server/app/main/api/internal/types"
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
dysmsapi "github.com/alibabacloud-go/dysmsapi-20170525/v3/client"
"github.com/alibabacloud-go/tea-utils/v2/service"
"github.com/alibabacloud-go/tea/tea"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
)
type SendSmsLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewSendSmsLogic(ctx context.Context, svcCtx *svc.ServiceContext) *SendSmsLogic {
return &SendSmsLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
secretKey := l.svcCtx.Config.Encrypt.SecretKey
encryptedMobile, err := crypto.EncryptMobile(req.Mobile, secretKey)
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %+v", err)
}
// 验证码防护检查
if err := l.checkCaptchaProtection(req.Mobile, req.ActionType); err != nil {
return err
}
// 检查手机号是否在一分钟内已发送过验证码
limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile)
exists, err := l.svcCtx.Redis.Exists(limitCodeKey)
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 读取redis缓存失败: %s", encryptedMobile)
}
if exists {
// 如果 Redis 中已经存在标记,说明在 1 分钟内请求过,返回错误
return errors.Wrapf(xerr.NewErrMsg("一分钟内不能重复发送验证码"), "短信发送, 手机号1分钟内重复请求发送验证码: %s", encryptedMobile)
}
code := fmt.Sprintf("%06d", rand.New(rand.NewSource(time.Now().UnixNano())).Intn(1000000))
// 发送短信
smsResp, err := l.sendSmsRequest(req.Mobile, code)
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 调用阿里客户端失败: %+v", err)
}
if *smsResp.Body.Code != "OK" {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message)
}
// 短信发送成功,记录请求次数
if err := l.recordCaptchaRequest(req.Mobile, req.ActionType); err != nil {
logx.Errorf("记录验证码请求失败: %v", err)
// 不影响主流程,只记录日志
}
codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile)
// 将验证码保存到 Redis设置过期时间
err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 验证码设置过期时间失败: %+v", err)
}
// 在 Redis 中设置 1 分钟的标记,限制重复请求
err = l.svcCtx.Redis.Setex(limitCodeKey, code, 60) // 标记 1 分钟内不能重复请求
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 验证码设置限制重复请求失败: %+v", err)
}
return nil
}
// CreateClient 创建阿里云短信客户端
func (l *SendSmsLogic) CreateClient() (*dysmsapi.Client, error) {
config := &openapi.Config{
AccessKeyId: &l.svcCtx.Config.VerifyCode.AccessKeyID,
AccessKeySecret: &l.svcCtx.Config.VerifyCode.AccessKeySecret,
}
config.Endpoint = tea.String(l.svcCtx.Config.VerifyCode.EndpointURL)
return dysmsapi.NewClient(config)
}
// sendSmsRequest 发送短信请求
func (l *SendSmsLogic) sendSmsRequest(mobile, code string) (*dysmsapi.SendSmsResponse, error) {
// 初始化阿里云短信客户端
cli, err := l.CreateClient()
if err != nil {
return nil, err
}
request := &dysmsapi.SendSmsRequest{
SignName: tea.String(l.svcCtx.Config.VerifyCode.SignName),
TemplateCode: tea.String(l.svcCtx.Config.VerifyCode.TemplateCode),
PhoneNumbers: tea.String(mobile),
TemplateParam: tea.String(fmt.Sprintf("{\"code\":\"%s\"}", code)),
}
runtime := &service.RuntimeOptions{}
return cli.SendSmsWithOptions(request, runtime)
}
// checkCaptchaProtection 检查验证码获取防护
func (l *SendSmsLogic) checkCaptchaProtection(mobile string, actionType string) error {
// 1. 检查手机号获取验证码频率
if err := l.checkMobileRateLimit(mobile, actionType); err != nil {
return err
}
// 2. 检查IP获取验证码频率
if err := l.checkIPRateLimit(); err != nil {
return err
}
return nil
}
// checkMobileRateLimit 检查手机号频率限制
func (l *SendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
// 限制单个手机号的每种短信类型:
// - 1分钟内最多获取1次验证码
// - 1小时内最多获取5次验证码
// - 24小时内最多获取20次验证码
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 检查1分钟限制
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
logx.Errorf("检查手机号1分钟限制失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if exists {
return errors.Wrapf(xerr.NewErrMsg("1分钟内已获取过验证码请稍后再试"), "验证码防护 - 手机号1分钟内重复请求: %s", mobile)
}
// 检查1小时限制
hourKey := fmt.Sprintf("%s:hour", mobileKey)
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取手机号1小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 5 {
return errors.Wrapf(xerr.NewErrMsg("1小时内获取验证码次数过多请稍后再试"), "验证码防护 - 手机号1小时内超过限制: %s", mobile)
}
}
// 检查24小时限制
dayKey := fmt.Sprintf("%s:day", mobileKey)
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取手机号24小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if dayCount, _ := strconv.ParseInt(count, 10, 64); dayCount >= 20 {
return errors.Wrapf(xerr.NewErrMsg("24小时内获取验证码次数过多请稍后再试"), "验证码防护 - 手机号24小时内超过限制: %s", mobile)
}
}
return nil
}
// checkIPRateLimit 检查IP频率限制
func (l *SendSmsLogic) checkIPRateLimit() error {
// 限制单个IP
// - 1分钟内最多获取10次验证码
// - 1小时内最多获取50次验证码
// - 超过阈值后IP被临时封禁
clientIP := l.getClientIP()
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// 检查IP是否被封禁
bannedKey := fmt.Sprintf("%s:banned", ipKey)
exists, err := l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
logx.Errorf("检查IP封禁状态失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err != nil {
logx.Errorf("获取IP封禁剩余时间失败: %v", err)
}
if ttl > 0 {
return errors.Wrapf(xerr.NewErrMsg(fmt.Sprintf("IP被临时封禁请%d秒后再试", ttl)), "验证码防护 - IP被临时封禁: %s", clientIP)
} else {
// 封禁时间已过,清除封禁状态
l.svcCtx.Redis.Del(bannedKey)
}
}
// 检查1分钟限制
minuteKey := fmt.Sprintf("%s:minute", ipKey)
count, err := l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取IP1分钟计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if minuteCount, _ := strconv.ParseInt(count, 10, 64); minuteCount >= 10 {
// 封禁IP 5分钟
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
if err != nil {
logx.Errorf("封禁IP失败: %v", err)
}
logx.Errorf("验证码防护 - IP被临时封禁: %s, 封禁时间: 300秒", clientIP)
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁已被临时封禁5分钟"), "验证码防护 - IP被临时封禁: %s", clientIP)
}
}
// 检查1小时限制
hourKey := fmt.Sprintf("%s:hour", ipKey)
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取IP1小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 50 {
// 封禁IP 1小时
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
if err != nil {
logx.Errorf("封禁IP失败: %v", err)
}
logx.Errorf("验证码防护 - IP被长期封禁: %s, 封禁时间: 3600秒", clientIP)
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁已被临时封禁1小时"), "验证码防护 - IP被长期封禁: %s", clientIP)
}
}
return nil
}
// getClientIP 获取客户端真实IP
func (l *SendSmsLogic) getClientIP() string {
if l.ctx != nil {
// 尝试从上下文中获取IP
if ip, ok := l.ctx.Value("client_ip").(string); ok {
return ip
}
}
// 默认返回本地IP实际使用时应该从请求中获取
return "127.0.0.1"
}
// recordCaptchaRequest 记录验证码请求次数
func (l *SendSmsLogic) recordCaptchaRequest(mobile string, actionType string) error {
clientIP := l.getClientIP()
// 记录手机号请求次数
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 1分钟限制标记
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
if err != nil {
logx.Errorf("设置手机号1分钟限制标记失败: %v", err)
}
// 1小时计数
hourKey := fmt.Sprintf("%s:hour", mobileKey)
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
logx.Errorf("增加手机号1小时计数失败: %v", err)
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
logx.Errorf("设置手机号1小时计数过期时间失败: %v", err)
}
// 24小时计数
dayKey := fmt.Sprintf("%s:day", mobileKey)
_, err = l.svcCtx.Redis.Incr(dayKey)
if err != nil {
logx.Errorf("增加手机号24小时计数失败: %v", err)
}
// 设置24小时过期
err = l.svcCtx.Redis.Expire(dayKey, 86400)
if err != nil {
logx.Errorf("设置手机号24小时计数过期时间失败: %v", err)
}
// 记录IP请求次数
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// IP 1分钟计数
minuteKey = fmt.Sprintf("%s:minute", ipKey)
_, err = l.svcCtx.Redis.Incr(minuteKey)
if err != nil {
logx.Errorf("增加IP1分钟计数失败: %v", err)
}
// 设置1分钟过期
err = l.svcCtx.Redis.Expire(minuteKey, 60)
if err != nil {
logx.Errorf("设置IP1分钟计数过期时间失败: %v", err)
}
// IP 1小时计数
hourKey = fmt.Sprintf("%s:hour", ipKey)
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
logx.Errorf("增加IP1小时计数失败: %v", err)
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
logx.Errorf("设置IP1小时计数过期时间失败: %v", err)
}
return nil
}
// GetCaptchaProtectionStatus 获取验证码防护状态(用于调试和监控)
func (l *SendSmsLogic) GetCaptchaProtectionStatus(mobile string, actionType string) (map[string]interface{}, error) {
status := make(map[string]interface{})
clientIP := l.getClientIP()
// 检查手机号防护状态
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 1分钟限制状态
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
return nil, err
}
status["mobileMinuteLimited"] = exists
// 1小时计数
hourKey := fmt.Sprintf("%s:hour", mobileKey)
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["mobileHourCount"] = hourCount
status["mobileHourRemaining"] = 5 - hourCount
}
} else {
status["mobileHourCount"] = 0
status["mobileHourRemaining"] = 5
}
// 24小时计数
dayKey := fmt.Sprintf("%s:day", mobileKey)
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if dayCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["mobileDayCount"] = dayCount
status["mobileDayRemaining"] = 20 - dayCount
}
} else {
status["mobileDayCount"] = 0
status["mobileDayRemaining"] = 20
}
// 检查IP防护状态
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// IP封禁状态
bannedKey := fmt.Sprintf("%s:banned", ipKey)
exists, err = l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
return nil, err
}
status["ipBanned"] = exists
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err == nil {
status["ipBanRemaining"] = ttl
}
}
// IP 1分钟计数
minuteKey = fmt.Sprintf("%s:minute", ipKey)
count, err = l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if minuteCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["ipMinuteCount"] = minuteCount
status["ipMinuteRemaining"] = 10 - minuteCount
}
} else {
status["ipMinuteCount"] = 0
status["ipMinuteRemaining"] = 10
}
// IP 1小时计数
hourKey = fmt.Sprintf("%s:hour", ipKey)
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["ipHourCount"] = hourCount
status["ipHourRemaining"] = 50 - hourCount
}
} else {
status["ipHourCount"] = 0
status["ipHourRemaining"] = 50
}
// 添加基本信息
status["mobile"] = mobile
status["actionType"] = actionType
status["clientIP"] = clientIP
return status, nil
}