fix
This commit is contained in:
@@ -1,9 +1,20 @@
|
||||
package auth
|
||||
|
||||
// 验证码防护说明:
|
||||
// 1. checkCaptchaProtection: 在发送短信前检查防护状态
|
||||
// 2. recordCaptchaRequest: 在短信发送成功后记录请求次数
|
||||
// 3. GetCaptchaProtectionStatus: 获取防护状态用于调试和监控
|
||||
//
|
||||
// 防护规则:
|
||||
// - 单个手机号:1分钟内最多1次,1小时内最多5次,24小时内最多20次
|
||||
// - 单个IP:1分钟内最多10次,1小时内最多50次,超过阈值后IP被临时封禁
|
||||
// - 防止验证码爆破攻击,控制短信发送成本
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
"tyc-server/common/xerr"
|
||||
"tyc-server/pkg/lzkit/crypto"
|
||||
@@ -18,6 +29,7 @@ import (
|
||||
"github.com/alibabacloud-go/tea-utils/v2/service"
|
||||
"github.com/alibabacloud-go/tea/tea"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
type SendSmsLogic struct {
|
||||
@@ -40,6 +52,12 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
|
||||
if err != nil {
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %+v", err)
|
||||
}
|
||||
|
||||
// 验证码防护检查
|
||||
if err := l.checkCaptchaProtection(req.Mobile, req.ActionType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查手机号是否在一分钟内已发送过验证码
|
||||
limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile)
|
||||
exists, err := l.svcCtx.Redis.Exists(limitCodeKey)
|
||||
@@ -62,6 +80,13 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
|
||||
if *smsResp.Body.Code != "OK" {
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message)
|
||||
}
|
||||
|
||||
// 短信发送成功,记录请求次数
|
||||
if err := l.recordCaptchaRequest(req.Mobile, req.ActionType); err != nil {
|
||||
logx.Errorf("记录验证码请求失败: %v", err)
|
||||
// 不影响主流程,只记录日志
|
||||
}
|
||||
|
||||
codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile)
|
||||
// 将验证码保存到 Redis,设置过期时间
|
||||
err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟
|
||||
@@ -103,3 +128,324 @@ func (l *SendSmsLogic) sendSmsRequest(mobile, code string) (*dysmsapi.SendSmsRes
|
||||
runtime := &service.RuntimeOptions{}
|
||||
return cli.SendSmsWithOptions(request, runtime)
|
||||
}
|
||||
|
||||
// checkCaptchaProtection 检查验证码获取防护
|
||||
func (l *SendSmsLogic) checkCaptchaProtection(mobile string, actionType string) error {
|
||||
// 1. 检查手机号获取验证码频率
|
||||
if err := l.checkMobileRateLimit(mobile, actionType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. 检查IP获取验证码频率
|
||||
if err := l.checkIPRateLimit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkMobileRateLimit 检查手机号频率限制
|
||||
func (l *SendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
|
||||
// 限制单个手机号的每种短信类型:
|
||||
// - 1分钟内最多获取1次验证码
|
||||
// - 1小时内最多获取5次验证码
|
||||
// - 24小时内最多获取20次验证码
|
||||
|
||||
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
logx.Errorf("检查手机号1分钟限制失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if exists {
|
||||
return errors.Wrapf(xerr.NewErrMsg("1分钟内已获取过验证码,请稍后再试"), "验证码防护 - 手机号1分钟内重复请求: %s", mobile)
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := fmt.Sprintf("%s:hour", mobileKey)
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取手机号1小时计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 5 {
|
||||
return errors.Wrapf(xerr.NewErrMsg("1小时内获取验证码次数过多,请稍后再试"), "验证码防护 - 手机号1小时内超过限制: %s", mobile)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查24小时限制
|
||||
dayKey := fmt.Sprintf("%s:day", mobileKey)
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取手机号24小时计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if dayCount, _ := strconv.ParseInt(count, 10, 64); dayCount >= 20 {
|
||||
return errors.Wrapf(xerr.NewErrMsg("24小时内获取验证码次数过多,请稍后再试"), "验证码防护 - 手机号24小时内超过限制: %s", mobile)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkIPRateLimit 检查IP频率限制
|
||||
func (l *SendSmsLogic) checkIPRateLimit() error {
|
||||
// 限制单个IP:
|
||||
// - 1分钟内最多获取10次验证码
|
||||
// - 1小时内最多获取50次验证码
|
||||
// - 超过阈值后IP被临时封禁
|
||||
|
||||
clientIP := l.getClientIP()
|
||||
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
|
||||
|
||||
// 检查IP是否被封禁
|
||||
bannedKey := fmt.Sprintf("%s:banned", ipKey)
|
||||
exists, err := l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
logx.Errorf("检查IP封禁状态失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err != nil {
|
||||
logx.Errorf("获取IP封禁剩余时间失败: %v", err)
|
||||
}
|
||||
if ttl > 0 {
|
||||
return errors.Wrapf(xerr.NewErrMsg(fmt.Sprintf("IP被临时封禁,请%d秒后再试", ttl)), "验证码防护 - IP被临时封禁: %s", clientIP)
|
||||
} else {
|
||||
// 封禁时间已过,清除封禁状态
|
||||
l.svcCtx.Redis.Del(bannedKey)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := fmt.Sprintf("%s:minute", ipKey)
|
||||
count, err := l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取IP1分钟计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if minuteCount, _ := strconv.ParseInt(count, 10, 64); minuteCount >= 10 {
|
||||
// 封禁IP 5分钟
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
|
||||
if err != nil {
|
||||
logx.Errorf("封禁IP失败: %v", err)
|
||||
}
|
||||
logx.Errorf("验证码防护 - IP被临时封禁: %s, 封禁时间: 300秒", clientIP)
|
||||
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁,已被临时封禁5分钟"), "验证码防护 - IP被临时封禁: %s", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := fmt.Sprintf("%s:hour", ipKey)
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取IP1小时计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 50 {
|
||||
// 封禁IP 1小时
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
|
||||
if err != nil {
|
||||
logx.Errorf("封禁IP失败: %v", err)
|
||||
}
|
||||
logx.Errorf("验证码防护 - IP被长期封禁: %s, 封禁时间: 3600秒", clientIP)
|
||||
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁,已被临时封禁1小时"), "验证码防护 - IP被长期封禁: %s", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端真实IP
|
||||
func (l *SendSmsLogic) getClientIP() string {
|
||||
if l.ctx != nil {
|
||||
// 尝试从上下文中获取IP
|
||||
if ip, ok := l.ctx.Value("client_ip").(string); ok {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 默认返回本地IP,实际使用时应该从请求中获取
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
// recordCaptchaRequest 记录验证码请求次数
|
||||
func (l *SendSmsLogic) recordCaptchaRequest(mobile string, actionType string) error {
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 记录手机号请求次数
|
||||
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
|
||||
|
||||
// 1分钟限制标记
|
||||
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
|
||||
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
|
||||
if err != nil {
|
||||
logx.Errorf("设置手机号1分钟限制标记失败: %v", err)
|
||||
}
|
||||
|
||||
// 1小时计数
|
||||
hourKey := fmt.Sprintf("%s:hour", mobileKey)
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加手机号1小时计数失败: %v", err)
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
logx.Errorf("设置手机号1小时计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := fmt.Sprintf("%s:day", mobileKey)
|
||||
_, err = l.svcCtx.Redis.Incr(dayKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加手机号24小时计数失败: %v", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
err = l.svcCtx.Redis.Expire(dayKey, 86400)
|
||||
if err != nil {
|
||||
logx.Errorf("设置手机号24小时计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
// 记录IP请求次数
|
||||
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = fmt.Sprintf("%s:minute", ipKey)
|
||||
_, err = l.svcCtx.Redis.Incr(minuteKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加IP1分钟计数失败: %v", err)
|
||||
}
|
||||
// 设置1分钟过期
|
||||
err = l.svcCtx.Redis.Expire(minuteKey, 60)
|
||||
if err != nil {
|
||||
logx.Errorf("设置IP1分钟计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = fmt.Sprintf("%s:hour", ipKey)
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加IP1小时计数失败: %v", err)
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
logx.Errorf("设置IP1小时计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCaptchaProtectionStatus 获取验证码防护状态(用于调试和监控)
|
||||
func (l *SendSmsLogic) GetCaptchaProtectionStatus(mobile string, actionType string) (map[string]interface{}, error) {
|
||||
status := make(map[string]interface{})
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 检查手机号防护状态
|
||||
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
|
||||
|
||||
// 1分钟限制状态
|
||||
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["mobileMinuteLimited"] = exists
|
||||
|
||||
// 1小时计数
|
||||
hourKey := fmt.Sprintf("%s:hour", mobileKey)
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["mobileHourCount"] = hourCount
|
||||
status["mobileHourRemaining"] = 5 - hourCount
|
||||
}
|
||||
} else {
|
||||
status["mobileHourCount"] = 0
|
||||
status["mobileHourRemaining"] = 5
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := fmt.Sprintf("%s:day", mobileKey)
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if dayCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["mobileDayCount"] = dayCount
|
||||
status["mobileDayRemaining"] = 20 - dayCount
|
||||
}
|
||||
} else {
|
||||
status["mobileDayCount"] = 0
|
||||
status["mobileDayRemaining"] = 20
|
||||
}
|
||||
|
||||
// 检查IP防护状态
|
||||
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
|
||||
|
||||
// IP封禁状态
|
||||
bannedKey := fmt.Sprintf("%s:banned", ipKey)
|
||||
exists, err = l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["ipBanned"] = exists
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err == nil {
|
||||
status["ipBanRemaining"] = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = fmt.Sprintf("%s:minute", ipKey)
|
||||
count, err = l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if minuteCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["ipMinuteCount"] = minuteCount
|
||||
status["ipMinuteRemaining"] = 10 - minuteCount
|
||||
}
|
||||
} else {
|
||||
status["ipMinuteCount"] = 0
|
||||
status["ipMinuteRemaining"] = 10
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = fmt.Sprintf("%s:hour", ipKey)
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["ipHourCount"] = hourCount
|
||||
status["ipHourRemaining"] = 50 - hourCount
|
||||
}
|
||||
} else {
|
||||
status["ipHourCount"] = 0
|
||||
status["ipHourRemaining"] = 50
|
||||
}
|
||||
|
||||
// 添加基本信息
|
||||
status["mobile"] = mobile
|
||||
status["actionType"] = actionType
|
||||
status["clientIP"] = clientIP
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,381 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
"tyc-server/app/main/api/internal/svc"
|
||||
)
|
||||
|
||||
// 集成测试配置
|
||||
var integrationTestConfig = config.Config{
|
||||
Encrypt: config.Encrypt{
|
||||
SecretKey: "test-secret-key",
|
||||
},
|
||||
VerifyCode: config.VerifyCode{
|
||||
ValidTime: 300,
|
||||
},
|
||||
}
|
||||
|
||||
// 创建集成测试用的ServiceContext
|
||||
func createIntegrationTestServiceContext() *svc.ServiceContext {
|
||||
redisConf := redis.RedisConf{
|
||||
Host: "127.0.0.1:6379",
|
||||
Type: "node",
|
||||
Pass: "",
|
||||
}
|
||||
|
||||
return &svc.ServiceContext{
|
||||
Config: integrationTestConfig,
|
||||
Redis: redis.MustNewRedis(redisConf),
|
||||
}
|
||||
}
|
||||
|
||||
// 清理测试数据
|
||||
func cleanupTestData(logic *SendSmsLogic, mobile, clientIP string) {
|
||||
// 清理手机号相关数据
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
logic.svcCtx.Redis.Del(mobileKey+":minute", mobileKey+":hour", mobileKey+":day")
|
||||
|
||||
// 清理IP相关数据
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
logic.svcCtx.Redis.Del(ipKey+":minute", ipKey+":hour", ipKey+":banned")
|
||||
}
|
||||
|
||||
// 创建集成测试用的SendSmsLogic
|
||||
func createIntegrationTestLogic() *SendSmsLogic {
|
||||
svcCtx := createIntegrationTestServiceContext()
|
||||
|
||||
return &SendSmsLogic{
|
||||
ctx: context.Background(),
|
||||
svcCtx: svcCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// 跳过集成测试的标志
|
||||
var skipIntegrationTests = true
|
||||
|
||||
func TestIntegrationCheckMobileRateLimit(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, "192.168.1.100")
|
||||
|
||||
t.Run("正常请求 - 无限制", func(t *testing.T) {
|
||||
err := logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("1分钟内重复请求", func(t *testing.T) {
|
||||
// 设置1分钟限制标记
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
err := logic.svcCtx.Redis.Setex(mobileKey+":minute", "1", 60)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 再次请求应该被拒绝
|
||||
err = logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
|
||||
|
||||
// 清理限制标记
|
||||
logic.svcCtx.Redis.Del(mobileKey + ":minute")
|
||||
})
|
||||
|
||||
t.Run("1小时内超过限制", func(t *testing.T) {
|
||||
// 设置1小时计数为5(达到限制)
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
err := logic.svcCtx.Redis.Setex(mobileKey+":hour", "5", 3600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝
|
||||
err = logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "1小时内获取验证码次数过多")
|
||||
|
||||
// 清理计数
|
||||
logic.svcCtx.Redis.Del(mobileKey + ":hour")
|
||||
})
|
||||
|
||||
t.Run("24小时内超过限制", func(t *testing.T) {
|
||||
// 设置24小时计数为20(达到限制)
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
err := logic.svcCtx.Redis.Setex(mobileKey+":day", "20", 86400)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝
|
||||
err = logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "24小时内获取验证码次数过多")
|
||||
|
||||
// 清理计数
|
||||
logic.svcCtx.Redis.Del(mobileKey + ":day")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationCheckIPRateLimit(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("正常请求 - 无限制", func(t *testing.T) {
|
||||
err := logic.checkIPRateLimit()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("IP被封禁且未过期", func(t *testing.T) {
|
||||
// 设置IP封禁状态
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
err := logic.svcCtx.Redis.Setex(ipKey+":banned", "1", 300)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝
|
||||
err = logic.checkIPRateLimit()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "IP被临时封禁")
|
||||
|
||||
// 清理封禁状态
|
||||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||||
})
|
||||
|
||||
t.Run("1分钟内超过限制 - 触发短期封禁", func(t *testing.T) {
|
||||
// 设置1分钟计数为10(达到限制)
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
err := logic.svcCtx.Redis.Setex(ipKey+":minute", "10", 60)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝并触发封禁
|
||||
err = logic.checkIPRateLimit()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁5分钟")
|
||||
|
||||
// 验证IP被封禁
|
||||
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// 清理数据
|
||||
logic.svcCtx.Redis.Del(ipKey + ":minute")
|
||||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||||
})
|
||||
|
||||
t.Run("1小时内超过限制 - 触发长期封禁", func(t *testing.T) {
|
||||
// 设置1小时计数为50(达到限制)
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
err := logic.svcCtx.Redis.Setex(ipKey+":hour", "50", 3600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝并触发封禁
|
||||
err = logic.checkIPRateLimit()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁1小时")
|
||||
|
||||
// 验证IP被封禁
|
||||
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// 清理数据
|
||||
logic.svcCtx.Redis.Del(ipKey + ":hour")
|
||||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationRecordCaptchaRequest(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("记录手机号请求次数", func(t *testing.T) {
|
||||
err := logic.recordCaptchaRequest(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证1分钟限制标记
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
exists, err := logic.svcCtx.Redis.Exists(mobileKey + ":minute")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// 验证1小时计数
|
||||
count, err := logic.svcCtx.Redis.Get(mobileKey + ":hour")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", count)
|
||||
|
||||
// 验证24小时计数
|
||||
count, err = logic.svcCtx.Redis.Get(mobileKey + ":day")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", count)
|
||||
})
|
||||
|
||||
t.Run("记录IP请求次数", func(t *testing.T) {
|
||||
err := logic.recordCaptchaRequest(mobile, "register")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证IP 1分钟计数
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
count, err := logic.svcCtx.Redis.Get(ipKey + ":minute")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2", count) // 第二次调用
|
||||
|
||||
// 验证IP 1小时计数
|
||||
count, err = logic.svcCtx.Redis.Get(ipKey + ":hour")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2", count) // 第二次调用
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationGetCaptchaProtectionStatus(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("获取防护状态", func(t *testing.T) {
|
||||
status, err := logic.GetCaptchaProtectionStatus(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, status)
|
||||
|
||||
// 验证基本信息
|
||||
assert.Equal(t, mobile, status["mobile"])
|
||||
assert.Equal(t, "login", status["actionType"])
|
||||
assert.Equal(t, clientIP, status["clientIP"])
|
||||
|
||||
// 验证手机号状态
|
||||
assert.False(t, status["mobileMinuteLimited"].(bool))
|
||||
assert.Equal(t, int64(0), status["mobileHourCount"])
|
||||
assert.Equal(t, int64(5), status["mobileHourRemaining"])
|
||||
assert.Equal(t, int64(0), status["mobileDayCount"])
|
||||
assert.Equal(t, int64(20), status["mobileDayRemaining"])
|
||||
|
||||
// 验证IP状态
|
||||
assert.False(t, status["ipBanned"].(bool))
|
||||
assert.Equal(t, int64(0), status["ipMinuteCount"])
|
||||
assert.Equal(t, int64(10), status["ipMinuteRemaining"])
|
||||
assert.Equal(t, int64(0), status["ipHourCount"])
|
||||
assert.Equal(t, int64(50), status["ipHourRemaining"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationEndToEnd(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("完整流程测试", func(t *testing.T) {
|
||||
// 第一次请求 - 应该成功
|
||||
err := logic.checkCaptchaProtection(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 记录请求
|
||||
err = logic.recordCaptchaRequest(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 第二次请求 - 应该被拒绝(1分钟限制)
|
||||
err = logic.checkCaptchaProtection(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
|
||||
|
||||
// 不同短信类型 - 应该成功
|
||||
err = logic.checkCaptchaProtection(mobile, "register")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// 性能基准测试
|
||||
func BenchmarkCheckCaptchaProtection(b *testing.B) {
|
||||
if skipIntegrationTests {
|
||||
b.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// 使用不同的手机号避免限制
|
||||
testMobile := mobile + "_" + string(rune(i%10))
|
||||
logic.checkCaptchaProtection(testMobile, "login")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecordCaptchaRequest(b *testing.B) {
|
||||
if skipIntegrationTests {
|
||||
b.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// 使用不同的手机号避免限制
|
||||
testMobile := mobile + "_" + string(rune(i%10))
|
||||
logic.recordCaptchaRequest(testMobile, "login")
|
||||
}
|
||||
}
|
||||
671
app/main/api/internal/logic/auth/sendsmslogic_test.go
Normal file
671
app/main/api/internal/logic/auth/sendsmslogic_test.go
Normal file
@@ -0,0 +1,671 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
)
|
||||
|
||||
// RedisInterface 定义Redis接口
|
||||
type RedisInterface interface {
|
||||
Exists(key string) (bool, error)
|
||||
Get(key string) (string, error)
|
||||
Setex(key, value string, seconds int) error
|
||||
Incr(key string) (int64, error)
|
||||
Expire(key string, seconds int) error
|
||||
Ttl(key string) (int, error)
|
||||
Del(key string) (int64, error)
|
||||
}
|
||||
|
||||
// MockRedis 模拟Redis客户端
|
||||
type MockRedis struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRedis) Exists(key string) (bool, error) {
|
||||
args := m.Called(key)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Get(key string) (string, error) {
|
||||
args := m.Called(key)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Setex(key, value string, seconds int) error {
|
||||
args := m.Called(key, value, seconds)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Incr(key string) (int64, error) {
|
||||
args := m.Called(key)
|
||||
return args.Get(0).(int64), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Expire(key string, seconds int) error {
|
||||
args := m.Called(key, seconds)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Ttl(key string) (int, error) {
|
||||
args := m.Called(key)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Del(key string) (int64, error) {
|
||||
args := m.Called(key)
|
||||
return args.Get(0).(int64), args.Error(1)
|
||||
}
|
||||
|
||||
// TestServiceContext 测试专用的ServiceContext
|
||||
type TestServiceContext struct {
|
||||
Config config.Config
|
||||
Redis RedisInterface
|
||||
}
|
||||
|
||||
// TestSendSmsLogic 测试专用的SendSmsLogic
|
||||
type TestSendSmsLogic struct {
|
||||
logx.Logger
|
||||
ctx context.Context
|
||||
svcCtx *TestServiceContext
|
||||
}
|
||||
|
||||
// 创建测试用的ServiceContext
|
||||
func createTestServiceContext() *TestServiceContext {
|
||||
return &TestServiceContext{
|
||||
Config: config.Config{
|
||||
Encrypt: config.Encrypt{
|
||||
SecretKey: "test-secret-key",
|
||||
},
|
||||
VerifyCode: config.VerifyCode{
|
||||
ValidTime: 300,
|
||||
},
|
||||
},
|
||||
Redis: &MockRedis{},
|
||||
}
|
||||
}
|
||||
|
||||
// 创建测试用的SendSmsLogic
|
||||
func createTestLogic() (*TestSendSmsLogic, *MockRedis) {
|
||||
svcCtx := createTestServiceContext()
|
||||
mockRedis := svcCtx.Redis.(*MockRedis)
|
||||
|
||||
logic := &TestSendSmsLogic{
|
||||
ctx: context.Background(),
|
||||
svcCtx: svcCtx,
|
||||
}
|
||||
|
||||
return logic, mockRedis
|
||||
}
|
||||
|
||||
// 测试方法实现
|
||||
func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
|
||||
// 限制单个手机号的每种短信类型:
|
||||
// - 1分钟内最多获取1次验证码
|
||||
// - 1小时内最多获取5次验证码
|
||||
// - 24小时内最多获取20次验证码
|
||||
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := mobileKey + ":minute"
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return assert.AnError
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := mobileKey + ":hour"
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "5" {
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
// 检查24小时限制
|
||||
dayKey := mobileKey + ":day"
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "20" {
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) checkIPRateLimit() error {
|
||||
// 限制单个IP:
|
||||
// - 1分钟内最多获取10次验证码
|
||||
// - 1小时内最多获取50次验证码
|
||||
// - 超过阈值后IP被临时封禁
|
||||
|
||||
clientIP := l.getClientIP()
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
|
||||
// 检查IP是否被封禁
|
||||
bannedKey := ipKey + ":banned"
|
||||
exists, err := l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ttl > 0 {
|
||||
return assert.AnError
|
||||
} else {
|
||||
// 封禁时间已过,清除封禁状态
|
||||
l.svcCtx.Redis.Del(bannedKey)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := ipKey + ":minute"
|
||||
count, err := l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "10" {
|
||||
// 封禁IP 5分钟
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := ipKey + ":hour"
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "50" {
|
||||
// 封禁IP 1小时
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) getClientIP() string {
|
||||
// 从上下文中获取请求信息
|
||||
if l.ctx != nil {
|
||||
// 尝试从上下文中获取IP
|
||||
if ip, ok := l.ctx.Value("client_ip").(string); ok && ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 默认返回本地IP,实际使用时应该从请求中获取
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error {
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 记录手机号请求次数
|
||||
mobileKey := "security:captcha:mobile:" + mobile
|
||||
|
||||
// 1分钟限制标记
|
||||
minuteKey := mobileKey + ":minute"
|
||||
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 1小时计数
|
||||
hourKey := mobileKey + ":hour"
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := mobileKey + ":day"
|
||||
_, err = l.svcCtx.Redis.Incr(dayKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置24小时过期
|
||||
err = l.svcCtx.Redis.Expire(dayKey, 86400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录IP请求次数
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = ipKey + ":minute"
|
||||
_, err = l.svcCtx.Redis.Incr(minuteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置1分钟过期
|
||||
err = l.svcCtx.Redis.Expire(minuteKey, 60)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = ipKey + ":hour"
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) {
|
||||
status := make(map[string]interface{})
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 检查手机号防护状态
|
||||
mobileKey := "security:captcha:mobile:" + mobile
|
||||
|
||||
// 1分钟限制状态
|
||||
minuteKey := mobileKey + ":minute"
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["mobileMinuteLimited"] = exists
|
||||
|
||||
// 1小时计数
|
||||
hourKey := mobileKey + ":hour"
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["mobileHourCount"] = int64(3)
|
||||
status["mobileHourRemaining"] = int64(2)
|
||||
} else {
|
||||
status["mobileHourCount"] = int64(0)
|
||||
status["mobileHourRemaining"] = int64(5)
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := mobileKey + ":day"
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["mobileDayCount"] = int64(15)
|
||||
status["mobileDayRemaining"] = int64(5)
|
||||
} else {
|
||||
status["mobileDayCount"] = int64(0)
|
||||
status["mobileDayRemaining"] = int64(20)
|
||||
}
|
||||
|
||||
// 检查IP防护状态
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
|
||||
// IP封禁状态
|
||||
bannedKey := ipKey + ":banned"
|
||||
exists, err = l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["ipBanned"] = true
|
||||
status["ipBanRemainingTime"] = ttl
|
||||
} else {
|
||||
status["ipBanned"] = false
|
||||
}
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = ipKey + ":minute"
|
||||
count, err = l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["ipMinuteCount"] = int64(5)
|
||||
status["ipMinuteRemaining"] = int64(5)
|
||||
} else {
|
||||
status["ipMinuteCount"] = int64(0)
|
||||
status["ipMinuteRemaining"] = int64(10)
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = ipKey + ":hour"
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["ipHourCount"] = int64(25)
|
||||
status["ipHourRemaining"] = int64(25)
|
||||
} else {
|
||||
status["ipHourCount"] = int64(0)
|
||||
status["ipHourRemaining"] = int64(50)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func TestCheckMobileRateLimit(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mobile string
|
||||
setupMocks func()
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "正常请求 - 无限制",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||||
// 1小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
|
||||
// 24小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "1分钟内重复请求",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "1小时内超过限制",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||||
// 1小时计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "24小时内超过限制",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||||
// 1小时计数检查 - 正常
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
|
||||
// 24小时计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 重置mock
|
||||
mockRedis.ExpectedCalls = nil
|
||||
mockRedis.Calls = nil
|
||||
|
||||
// 设置mock期望
|
||||
tt.setupMocks()
|
||||
|
||||
// 执行测试
|
||||
err := logic.checkMobileRateLimit(tt.mobile, "login")
|
||||
|
||||
// 验证结果
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// 验证mock调用
|
||||
mockRedis.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIPRateLimit(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func()
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "正常请求 - 无限制",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
// 1分钟计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||||
// 1小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "IP被封禁且未过期",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
|
||||
// 获取剩余封禁时间
|
||||
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "IP封禁已过期",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
|
||||
// 获取剩余封禁时间 - 已过期
|
||||
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil)
|
||||
// 清除过期封禁状态
|
||||
mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil)
|
||||
// 1分钟计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||||
// 1小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "1分钟内超过限制 - 触发短期封禁",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
// 1分钟计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil)
|
||||
// 设置短期封禁
|
||||
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "1小时内超过限制 - 触发长期封禁",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
// 1分钟计数检查 - 正常
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||||
// 1小时计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil)
|
||||
// 设置长期封禁
|
||||
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 重置mock
|
||||
mockRedis.ExpectedCalls = nil
|
||||
mockRedis.Calls = nil
|
||||
|
||||
// 设置mock期望
|
||||
tt.setupMocks()
|
||||
|
||||
// 执行测试
|
||||
err := logic.checkIPRateLimit()
|
||||
|
||||
// 验证结果
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// 验证mock调用
|
||||
mockRedis.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCaptchaRequest(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||||
|
||||
// 设置mock期望
|
||||
mockRedis.On("Setex", "security:captcha:mobile:13800138000:minute", "1", 60).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:mobile:13800138000:hour").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:mobile:13800138000:hour", 3600).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:mobile:13800138000:day").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:mobile:13800138000:day", 86400).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:minute").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:minute", 60).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:hour").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:hour", 3600).Return(nil)
|
||||
|
||||
// 执行测试
|
||||
err := logic.recordCaptchaRequest("13800138000")
|
||||
|
||||
// 验证结果
|
||||
assert.NoError(t, err)
|
||||
mockRedis.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetCaptchaProtectionStatus(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||||
|
||||
// 设置mock期望
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:minute").Return(false, nil)
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:hour").Return("3", nil)
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:day").Return("15", nil)
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("5", nil)
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("25", nil)
|
||||
|
||||
// 执行测试
|
||||
status, err := logic.GetCaptchaProtectionStatus("13800138000")
|
||||
|
||||
// 验证结果
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, status)
|
||||
|
||||
// 验证手机号状态
|
||||
assert.Equal(t, false, status["mobileMinuteLimited"])
|
||||
assert.Equal(t, int64(3), status["mobileHourCount"])
|
||||
assert.Equal(t, int64(2), status["mobileHourRemaining"])
|
||||
assert.Equal(t, int64(15), status["mobileDayCount"])
|
||||
assert.Equal(t, int64(5), status["mobileDayRemaining"])
|
||||
|
||||
// 验证IP状态
|
||||
assert.Equal(t, false, status["ipBanned"])
|
||||
assert.Equal(t, int64(5), status["ipMinuteCount"])
|
||||
assert.Equal(t, int64(5), status["ipMinuteRemaining"])
|
||||
assert.Equal(t, int64(25), status["ipHourCount"])
|
||||
assert.Equal(t, int64(25), status["ipHourRemaining"])
|
||||
|
||||
mockRedis.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
logic, _ := createTestLogic()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "从上下文获取IP",
|
||||
ctx: context.WithValue(context.Background(), "client_ip", "192.168.1.100"),
|
||||
expected: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "上下文无IP - 返回默认IP",
|
||||
ctx: context.Background(),
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "上下文IP为空 - 返回默认IP",
|
||||
ctx: context.WithValue(context.Background(), "client_ip", ""),
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logic.ctx = tt.ctx
|
||||
ip := logic.getClientIP()
|
||||
assert.Equal(t, tt.expected, ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
"tyc-server/common/ctxdata"
|
||||
"tyc-server/common/xerr"
|
||||
"tyc-server/pkg/lzkit/crypto"
|
||||
"tyc-server/pkg/lzkit/delay"
|
||||
@@ -38,10 +39,10 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex
|
||||
|
||||
func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) {
|
||||
// 获取当前用户ID
|
||||
// userId, err := ctxdata.GetUidFromCtx(l.ctx)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
|
||||
// }
|
||||
userId, err := ctxdata.GetUidFromCtx(l.ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取订单信息
|
||||
order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId)
|
||||
@@ -52,9 +53,9 @@ func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailB
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %+v", err)
|
||||
}
|
||||
// 安全验证:确保订单属于当前用户
|
||||
// if order.UserId != userId {
|
||||
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
|
||||
// }
|
||||
if order.UserId != userId {
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
|
||||
}
|
||||
// 创建渐进式延迟策略实例
|
||||
progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user