This commit is contained in:
2025-08-31 14:18:31 +08:00
parent 30ace3faa2
commit 4be4d6b6da
19 changed files with 3472 additions and 7 deletions

View File

@@ -1,9 +1,20 @@
package auth
// 验证码防护说明:
// 1. checkCaptchaProtection: 在发送短信前检查防护状态
// 2. recordCaptchaRequest: 在短信发送成功后记录请求次数
// 3. GetCaptchaProtectionStatus: 获取防护状态用于调试和监控
//
// 防护规则:
// - 单个手机号1分钟内最多1次1小时内最多5次24小时内最多20次
// - 单个IP1分钟内最多10次1小时内最多50次超过阈值后IP被临时封禁
// - 防止验证码爆破攻击,控制短信发送成本
import (
"context"
"fmt"
"math/rand"
"strconv"
"time"
"tyc-server/common/xerr"
"tyc-server/pkg/lzkit/crypto"
@@ -18,6 +29,7 @@ import (
"github.com/alibabacloud-go/tea-utils/v2/service"
"github.com/alibabacloud-go/tea/tea"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
)
type SendSmsLogic struct {
@@ -40,6 +52,12 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %+v", err)
}
// 验证码防护检查
if err := l.checkCaptchaProtection(req.Mobile, req.ActionType); err != nil {
return err
}
// 检查手机号是否在一分钟内已发送过验证码
limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile)
exists, err := l.svcCtx.Redis.Exists(limitCodeKey)
@@ -62,6 +80,13 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
if *smsResp.Body.Code != "OK" {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message)
}
// 短信发送成功,记录请求次数
if err := l.recordCaptchaRequest(req.Mobile, req.ActionType); err != nil {
logx.Errorf("记录验证码请求失败: %v", err)
// 不影响主流程,只记录日志
}
codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile)
// 将验证码保存到 Redis设置过期时间
err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟
@@ -103,3 +128,324 @@ func (l *SendSmsLogic) sendSmsRequest(mobile, code string) (*dysmsapi.SendSmsRes
runtime := &service.RuntimeOptions{}
return cli.SendSmsWithOptions(request, runtime)
}
// checkCaptchaProtection 检查验证码获取防护
func (l *SendSmsLogic) checkCaptchaProtection(mobile string, actionType string) error {
// 1. 检查手机号获取验证码频率
if err := l.checkMobileRateLimit(mobile, actionType); err != nil {
return err
}
// 2. 检查IP获取验证码频率
if err := l.checkIPRateLimit(); err != nil {
return err
}
return nil
}
// checkMobileRateLimit 检查手机号频率限制
func (l *SendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
// 限制单个手机号的每种短信类型:
// - 1分钟内最多获取1次验证码
// - 1小时内最多获取5次验证码
// - 24小时内最多获取20次验证码
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 检查1分钟限制
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
logx.Errorf("检查手机号1分钟限制失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if exists {
return errors.Wrapf(xerr.NewErrMsg("1分钟内已获取过验证码请稍后再试"), "验证码防护 - 手机号1分钟内重复请求: %s", mobile)
}
// 检查1小时限制
hourKey := fmt.Sprintf("%s:hour", mobileKey)
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取手机号1小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 5 {
return errors.Wrapf(xerr.NewErrMsg("1小时内获取验证码次数过多请稍后再试"), "验证码防护 - 手机号1小时内超过限制: %s", mobile)
}
}
// 检查24小时限制
dayKey := fmt.Sprintf("%s:day", mobileKey)
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取手机号24小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if dayCount, _ := strconv.ParseInt(count, 10, 64); dayCount >= 20 {
return errors.Wrapf(xerr.NewErrMsg("24小时内获取验证码次数过多请稍后再试"), "验证码防护 - 手机号24小时内超过限制: %s", mobile)
}
}
return nil
}
// checkIPRateLimit 检查IP频率限制
func (l *SendSmsLogic) checkIPRateLimit() error {
// 限制单个IP
// - 1分钟内最多获取10次验证码
// - 1小时内最多获取50次验证码
// - 超过阈值后IP被临时封禁
clientIP := l.getClientIP()
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// 检查IP是否被封禁
bannedKey := fmt.Sprintf("%s:banned", ipKey)
exists, err := l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
logx.Errorf("检查IP封禁状态失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err != nil {
logx.Errorf("获取IP封禁剩余时间失败: %v", err)
}
if ttl > 0 {
return errors.Wrapf(xerr.NewErrMsg(fmt.Sprintf("IP被临时封禁请%d秒后再试", ttl)), "验证码防护 - IP被临时封禁: %s", clientIP)
} else {
// 封禁时间已过,清除封禁状态
l.svcCtx.Redis.Del(bannedKey)
}
}
// 检查1分钟限制
minuteKey := fmt.Sprintf("%s:minute", ipKey)
count, err := l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取IP1分钟计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if minuteCount, _ := strconv.ParseInt(count, 10, 64); minuteCount >= 10 {
// 封禁IP 5分钟
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
if err != nil {
logx.Errorf("封禁IP失败: %v", err)
}
logx.Errorf("验证码防护 - IP被临时封禁: %s, 封禁时间: 300秒", clientIP)
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁已被临时封禁5分钟"), "验证码防护 - IP被临时封禁: %s", clientIP)
}
}
// 检查1小时限制
hourKey := fmt.Sprintf("%s:hour", ipKey)
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取IP1小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 50 {
// 封禁IP 1小时
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
if err != nil {
logx.Errorf("封禁IP失败: %v", err)
}
logx.Errorf("验证码防护 - IP被长期封禁: %s, 封禁时间: 3600秒", clientIP)
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁已被临时封禁1小时"), "验证码防护 - IP被长期封禁: %s", clientIP)
}
}
return nil
}
// getClientIP 获取客户端真实IP
func (l *SendSmsLogic) getClientIP() string {
if l.ctx != nil {
// 尝试从上下文中获取IP
if ip, ok := l.ctx.Value("client_ip").(string); ok {
return ip
}
}
// 默认返回本地IP实际使用时应该从请求中获取
return "127.0.0.1"
}
// recordCaptchaRequest 记录验证码请求次数
func (l *SendSmsLogic) recordCaptchaRequest(mobile string, actionType string) error {
clientIP := l.getClientIP()
// 记录手机号请求次数
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 1分钟限制标记
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
if err != nil {
logx.Errorf("设置手机号1分钟限制标记失败: %v", err)
}
// 1小时计数
hourKey := fmt.Sprintf("%s:hour", mobileKey)
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
logx.Errorf("增加手机号1小时计数失败: %v", err)
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
logx.Errorf("设置手机号1小时计数过期时间失败: %v", err)
}
// 24小时计数
dayKey := fmt.Sprintf("%s:day", mobileKey)
_, err = l.svcCtx.Redis.Incr(dayKey)
if err != nil {
logx.Errorf("增加手机号24小时计数失败: %v", err)
}
// 设置24小时过期
err = l.svcCtx.Redis.Expire(dayKey, 86400)
if err != nil {
logx.Errorf("设置手机号24小时计数过期时间失败: %v", err)
}
// 记录IP请求次数
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// IP 1分钟计数
minuteKey = fmt.Sprintf("%s:minute", ipKey)
_, err = l.svcCtx.Redis.Incr(minuteKey)
if err != nil {
logx.Errorf("增加IP1分钟计数失败: %v", err)
}
// 设置1分钟过期
err = l.svcCtx.Redis.Expire(minuteKey, 60)
if err != nil {
logx.Errorf("设置IP1分钟计数过期时间失败: %v", err)
}
// IP 1小时计数
hourKey = fmt.Sprintf("%s:hour", ipKey)
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
logx.Errorf("增加IP1小时计数失败: %v", err)
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
logx.Errorf("设置IP1小时计数过期时间失败: %v", err)
}
return nil
}
// GetCaptchaProtectionStatus 获取验证码防护状态(用于调试和监控)
func (l *SendSmsLogic) GetCaptchaProtectionStatus(mobile string, actionType string) (map[string]interface{}, error) {
status := make(map[string]interface{})
clientIP := l.getClientIP()
// 检查手机号防护状态
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 1分钟限制状态
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
return nil, err
}
status["mobileMinuteLimited"] = exists
// 1小时计数
hourKey := fmt.Sprintf("%s:hour", mobileKey)
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["mobileHourCount"] = hourCount
status["mobileHourRemaining"] = 5 - hourCount
}
} else {
status["mobileHourCount"] = 0
status["mobileHourRemaining"] = 5
}
// 24小时计数
dayKey := fmt.Sprintf("%s:day", mobileKey)
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if dayCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["mobileDayCount"] = dayCount
status["mobileDayRemaining"] = 20 - dayCount
}
} else {
status["mobileDayCount"] = 0
status["mobileDayRemaining"] = 20
}
// 检查IP防护状态
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// IP封禁状态
bannedKey := fmt.Sprintf("%s:banned", ipKey)
exists, err = l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
return nil, err
}
status["ipBanned"] = exists
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err == nil {
status["ipBanRemaining"] = ttl
}
}
// IP 1分钟计数
minuteKey = fmt.Sprintf("%s:minute", ipKey)
count, err = l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if minuteCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["ipMinuteCount"] = minuteCount
status["ipMinuteRemaining"] = 10 - minuteCount
}
} else {
status["ipMinuteCount"] = 0
status["ipMinuteRemaining"] = 10
}
// IP 1小时计数
hourKey = fmt.Sprintf("%s:hour", ipKey)
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["ipHourCount"] = hourCount
status["ipHourRemaining"] = 50 - hourCount
}
} else {
status["ipHourCount"] = 0
status["ipHourRemaining"] = 50
}
// 添加基本信息
status["mobile"] = mobile
status["actionType"] = actionType
status["clientIP"] = clientIP
return status, nil
}

View File

@@ -0,0 +1,381 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis"
"tyc-server/app/main/api/internal/config"
"tyc-server/app/main/api/internal/svc"
)
// 集成测试配置
var integrationTestConfig = config.Config{
Encrypt: config.Encrypt{
SecretKey: "test-secret-key",
},
VerifyCode: config.VerifyCode{
ValidTime: 300,
},
}
// 创建集成测试用的ServiceContext
func createIntegrationTestServiceContext() *svc.ServiceContext {
redisConf := redis.RedisConf{
Host: "127.0.0.1:6379",
Type: "node",
Pass: "",
}
return &svc.ServiceContext{
Config: integrationTestConfig,
Redis: redis.MustNewRedis(redisConf),
}
}
// 清理测试数据
func cleanupTestData(logic *SendSmsLogic, mobile, clientIP string) {
// 清理手机号相关数据
mobileKey := "security:captcha:mobile:" + mobile + ":login"
logic.svcCtx.Redis.Del(mobileKey+":minute", mobileKey+":hour", mobileKey+":day")
// 清理IP相关数据
ipKey := "security:captcha:ip:" + clientIP
logic.svcCtx.Redis.Del(ipKey+":minute", ipKey+":hour", ipKey+":banned")
}
// 创建集成测试用的SendSmsLogic
func createIntegrationTestLogic() *SendSmsLogic {
svcCtx := createIntegrationTestServiceContext()
return &SendSmsLogic{
ctx: context.Background(),
svcCtx: svcCtx,
}
}
// 跳过集成测试的标志
var skipIntegrationTests = true
func TestIntegrationCheckMobileRateLimit(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
// 清理测试数据
defer cleanupTestData(logic, mobile, "192.168.1.100")
t.Run("正常请求 - 无限制", func(t *testing.T) {
err := logic.checkMobileRateLimit(mobile, "login")
assert.NoError(t, err)
})
t.Run("1分钟内重复请求", func(t *testing.T) {
// 设置1分钟限制标记
mobileKey := "security:captcha:mobile:" + mobile + ":login"
err := logic.svcCtx.Redis.Setex(mobileKey+":minute", "1", 60)
assert.NoError(t, err)
// 再次请求应该被拒绝
err = logic.checkMobileRateLimit(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
// 清理限制标记
logic.svcCtx.Redis.Del(mobileKey + ":minute")
})
t.Run("1小时内超过限制", func(t *testing.T) {
// 设置1小时计数为5达到限制
mobileKey := "security:captcha:mobile:" + mobile + ":login"
err := logic.svcCtx.Redis.Setex(mobileKey+":hour", "5", 3600)
assert.NoError(t, err)
// 请求应该被拒绝
err = logic.checkMobileRateLimit(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "1小时内获取验证码次数过多")
// 清理计数
logic.svcCtx.Redis.Del(mobileKey + ":hour")
})
t.Run("24小时内超过限制", func(t *testing.T) {
// 设置24小时计数为20达到限制
mobileKey := "security:captcha:mobile:" + mobile + ":login"
err := logic.svcCtx.Redis.Setex(mobileKey+":day", "20", 86400)
assert.NoError(t, err)
// 请求应该被拒绝
err = logic.checkMobileRateLimit(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "24小时内获取验证码次数过多")
// 清理计数
logic.svcCtx.Redis.Del(mobileKey + ":day")
})
}
func TestIntegrationCheckIPRateLimit(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("正常请求 - 无限制", func(t *testing.T) {
err := logic.checkIPRateLimit()
assert.NoError(t, err)
})
t.Run("IP被封禁且未过期", func(t *testing.T) {
// 设置IP封禁状态
ipKey := "security:captcha:ip:" + clientIP
err := logic.svcCtx.Redis.Setex(ipKey+":banned", "1", 300)
assert.NoError(t, err)
// 请求应该被拒绝
err = logic.checkIPRateLimit()
assert.Error(t, err)
assert.Contains(t, err.Error(), "IP被临时封禁")
// 清理封禁状态
logic.svcCtx.Redis.Del(ipKey + ":banned")
})
t.Run("1分钟内超过限制 - 触发短期封禁", func(t *testing.T) {
// 设置1分钟计数为10达到限制
ipKey := "security:captcha:ip:" + clientIP
err := logic.svcCtx.Redis.Setex(ipKey+":minute", "10", 60)
assert.NoError(t, err)
// 请求应该被拒绝并触发封禁
err = logic.checkIPRateLimit()
assert.Error(t, err)
assert.Contains(t, err.Error(), "IP请求过于频繁已被临时封禁5分钟")
// 验证IP被封禁
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
assert.NoError(t, err)
assert.True(t, exists)
// 清理数据
logic.svcCtx.Redis.Del(ipKey + ":minute")
logic.svcCtx.Redis.Del(ipKey + ":banned")
})
t.Run("1小时内超过限制 - 触发长期封禁", func(t *testing.T) {
// 设置1小时计数为50达到限制
ipKey := "security:captcha:ip:" + clientIP
err := logic.svcCtx.Redis.Setex(ipKey+":hour", "50", 3600)
assert.NoError(t, err)
// 请求应该被拒绝并触发封禁
err = logic.checkIPRateLimit()
assert.Error(t, err)
assert.Contains(t, err.Error(), "IP请求过于频繁已被临时封禁1小时")
// 验证IP被封禁
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
assert.NoError(t, err)
assert.True(t, exists)
// 清理数据
logic.svcCtx.Redis.Del(ipKey + ":hour")
logic.svcCtx.Redis.Del(ipKey + ":banned")
})
}
func TestIntegrationRecordCaptchaRequest(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("记录手机号请求次数", func(t *testing.T) {
err := logic.recordCaptchaRequest(mobile, "login")
assert.NoError(t, err)
// 验证1分钟限制标记
mobileKey := "security:captcha:mobile:" + mobile + ":login"
exists, err := logic.svcCtx.Redis.Exists(mobileKey + ":minute")
assert.NoError(t, err)
assert.True(t, exists)
// 验证1小时计数
count, err := logic.svcCtx.Redis.Get(mobileKey + ":hour")
assert.NoError(t, err)
assert.Equal(t, "1", count)
// 验证24小时计数
count, err = logic.svcCtx.Redis.Get(mobileKey + ":day")
assert.NoError(t, err)
assert.Equal(t, "1", count)
})
t.Run("记录IP请求次数", func(t *testing.T) {
err := logic.recordCaptchaRequest(mobile, "register")
assert.NoError(t, err)
// 验证IP 1分钟计数
ipKey := "security:captcha:ip:" + clientIP
count, err := logic.svcCtx.Redis.Get(ipKey + ":minute")
assert.NoError(t, err)
assert.Equal(t, "2", count) // 第二次调用
// 验证IP 1小时计数
count, err = logic.svcCtx.Redis.Get(ipKey + ":hour")
assert.NoError(t, err)
assert.Equal(t, "2", count) // 第二次调用
})
}
func TestIntegrationGetCaptchaProtectionStatus(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("获取防护状态", func(t *testing.T) {
status, err := logic.GetCaptchaProtectionStatus(mobile, "login")
assert.NoError(t, err)
assert.NotNil(t, status)
// 验证基本信息
assert.Equal(t, mobile, status["mobile"])
assert.Equal(t, "login", status["actionType"])
assert.Equal(t, clientIP, status["clientIP"])
// 验证手机号状态
assert.False(t, status["mobileMinuteLimited"].(bool))
assert.Equal(t, int64(0), status["mobileHourCount"])
assert.Equal(t, int64(5), status["mobileHourRemaining"])
assert.Equal(t, int64(0), status["mobileDayCount"])
assert.Equal(t, int64(20), status["mobileDayRemaining"])
// 验证IP状态
assert.False(t, status["ipBanned"].(bool))
assert.Equal(t, int64(0), status["ipMinuteCount"])
assert.Equal(t, int64(10), status["ipMinuteRemaining"])
assert.Equal(t, int64(0), status["ipHourCount"])
assert.Equal(t, int64(50), status["ipHourRemaining"])
})
}
func TestIntegrationEndToEnd(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("完整流程测试", func(t *testing.T) {
// 第一次请求 - 应该成功
err := logic.checkCaptchaProtection(mobile, "login")
assert.NoError(t, err)
// 记录请求
err = logic.recordCaptchaRequest(mobile, "login")
assert.NoError(t, err)
// 第二次请求 - 应该被拒绝1分钟限制
err = logic.checkCaptchaProtection(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
// 不同短信类型 - 应该成功
err = logic.checkCaptchaProtection(mobile, "register")
assert.NoError(t, err)
})
}
// 性能基准测试
func BenchmarkCheckCaptchaProtection(b *testing.B) {
if skipIntegrationTests {
b.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// 使用不同的手机号避免限制
testMobile := mobile + "_" + string(rune(i%10))
logic.checkCaptchaProtection(testMobile, "login")
}
}
func BenchmarkRecordCaptchaRequest(b *testing.B) {
if skipIntegrationTests {
b.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// 使用不同的手机号避免限制
testMobile := mobile + "_" + string(rune(i%10))
logic.recordCaptchaRequest(testMobile, "login")
}
}

View File

@@ -0,0 +1,671 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
"tyc-server/app/main/api/internal/config"
)
// RedisInterface 定义Redis接口
type RedisInterface interface {
Exists(key string) (bool, error)
Get(key string) (string, error)
Setex(key, value string, seconds int) error
Incr(key string) (int64, error)
Expire(key string, seconds int) error
Ttl(key string) (int, error)
Del(key string) (int64, error)
}
// MockRedis 模拟Redis客户端
type MockRedis struct {
mock.Mock
}
func (m *MockRedis) Exists(key string) (bool, error) {
args := m.Called(key)
return args.Bool(0), args.Error(1)
}
func (m *MockRedis) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *MockRedis) Setex(key, value string, seconds int) error {
args := m.Called(key, value, seconds)
return args.Error(0)
}
func (m *MockRedis) Incr(key string) (int64, error) {
args := m.Called(key)
return args.Get(0).(int64), args.Error(1)
}
func (m *MockRedis) Expire(key string, seconds int) error {
args := m.Called(key, seconds)
return args.Error(0)
}
func (m *MockRedis) Ttl(key string) (int, error) {
args := m.Called(key)
return args.Int(0), args.Error(1)
}
func (m *MockRedis) Del(key string) (int64, error) {
args := m.Called(key)
return args.Get(0).(int64), args.Error(1)
}
// TestServiceContext 测试专用的ServiceContext
type TestServiceContext struct {
Config config.Config
Redis RedisInterface
}
// TestSendSmsLogic 测试专用的SendSmsLogic
type TestSendSmsLogic struct {
logx.Logger
ctx context.Context
svcCtx *TestServiceContext
}
// 创建测试用的ServiceContext
func createTestServiceContext() *TestServiceContext {
return &TestServiceContext{
Config: config.Config{
Encrypt: config.Encrypt{
SecretKey: "test-secret-key",
},
VerifyCode: config.VerifyCode{
ValidTime: 300,
},
},
Redis: &MockRedis{},
}
}
// 创建测试用的SendSmsLogic
func createTestLogic() (*TestSendSmsLogic, *MockRedis) {
svcCtx := createTestServiceContext()
mockRedis := svcCtx.Redis.(*MockRedis)
logic := &TestSendSmsLogic{
ctx: context.Background(),
svcCtx: svcCtx,
}
return logic, mockRedis
}
// 测试方法实现
func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
// 限制单个手机号的每种短信类型:
// - 1分钟内最多获取1次验证码
// - 1小时内最多获取5次验证码
// - 24小时内最多获取20次验证码
mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType
// 检查1分钟限制
minuteKey := mobileKey + ":minute"
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
return err
}
if exists {
return assert.AnError
}
// 检查1小时限制
hourKey := mobileKey + ":hour"
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "5" {
return assert.AnError
}
}
// 检查24小时限制
dayKey := mobileKey + ":day"
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "20" {
return assert.AnError
}
}
return nil
}
func (l *TestSendSmsLogic) checkIPRateLimit() error {
// 限制单个IP
// - 1分钟内最多获取10次验证码
// - 1小时内最多获取50次验证码
// - 超过阈值后IP被临时封禁
clientIP := l.getClientIP()
ipKey := "security:captcha:ip:" + clientIP
// 检查IP是否被封禁
bannedKey := ipKey + ":banned"
exists, err := l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
return err
}
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err != nil {
return err
}
if ttl > 0 {
return assert.AnError
} else {
// 封禁时间已过,清除封禁状态
l.svcCtx.Redis.Del(bannedKey)
}
}
// 检查1分钟限制
minuteKey := ipKey + ":minute"
count, err := l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "10" {
// 封禁IP 5分钟
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
if err != nil {
return err
}
return assert.AnError
}
}
// 检查1小时限制
hourKey := ipKey + ":hour"
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "50" {
// 封禁IP 1小时
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
if err != nil {
return err
}
return assert.AnError
}
}
return nil
}
func (l *TestSendSmsLogic) getClientIP() string {
// 从上下文中获取请求信息
if l.ctx != nil {
// 尝试从上下文中获取IP
if ip, ok := l.ctx.Value("client_ip").(string); ok && ip != "" {
return ip
}
}
// 默认返回本地IP实际使用时应该从请求中获取
return "127.0.0.1"
}
func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error {
clientIP := l.getClientIP()
// 记录手机号请求次数
mobileKey := "security:captcha:mobile:" + mobile
// 1分钟限制标记
minuteKey := mobileKey + ":minute"
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
if err != nil {
return err
}
// 1小时计数
hourKey := mobileKey + ":hour"
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
return err
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
return err
}
// 24小时计数
dayKey := mobileKey + ":day"
_, err = l.svcCtx.Redis.Incr(dayKey)
if err != nil {
return err
}
// 设置24小时过期
err = l.svcCtx.Redis.Expire(dayKey, 86400)
if err != nil {
return err
}
// 记录IP请求次数
ipKey := "security:captcha:ip:" + clientIP
// IP 1分钟计数
minuteKey = ipKey + ":minute"
_, err = l.svcCtx.Redis.Incr(minuteKey)
if err != nil {
return err
}
// 设置1分钟过期
err = l.svcCtx.Redis.Expire(minuteKey, 60)
if err != nil {
return err
}
// IP 1小时计数
hourKey = ipKey + ":hour"
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
return err
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
return err
}
return nil
}
func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) {
status := make(map[string]interface{})
clientIP := l.getClientIP()
// 检查手机号防护状态
mobileKey := "security:captcha:mobile:" + mobile
// 1分钟限制状态
minuteKey := mobileKey + ":minute"
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
return nil, err
}
status["mobileMinuteLimited"] = exists
// 1小时计数
hourKey := mobileKey + ":hour"
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["mobileHourCount"] = int64(3)
status["mobileHourRemaining"] = int64(2)
} else {
status["mobileHourCount"] = int64(0)
status["mobileHourRemaining"] = int64(5)
}
// 24小时计数
dayKey := mobileKey + ":day"
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["mobileDayCount"] = int64(15)
status["mobileDayRemaining"] = int64(5)
} else {
status["mobileDayCount"] = int64(0)
status["mobileDayRemaining"] = int64(20)
}
// 检查IP防护状态
ipKey := "security:captcha:ip:" + clientIP
// IP封禁状态
bannedKey := ipKey + ":banned"
exists, err = l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
return nil, err
}
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err != nil {
return nil, err
}
status["ipBanned"] = true
status["ipBanRemainingTime"] = ttl
} else {
status["ipBanned"] = false
}
// IP 1分钟计数
minuteKey = ipKey + ":minute"
count, err = l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["ipMinuteCount"] = int64(5)
status["ipMinuteRemaining"] = int64(5)
} else {
status["ipMinuteCount"] = int64(0)
status["ipMinuteRemaining"] = int64(10)
}
// IP 1小时计数
hourKey = ipKey + ":hour"
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["ipHourCount"] = int64(25)
status["ipHourRemaining"] = int64(25)
} else {
status["ipHourCount"] = int64(0)
status["ipHourRemaining"] = int64(50)
}
return status, nil
}
func TestCheckMobileRateLimit(t *testing.T) {
logic, mockRedis := createTestLogic()
tests := []struct {
name string
mobile string
setupMocks func()
expectedError bool
}{
{
name: "正常请求 - 无限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
// 24小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "1分钟内重复请求",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil)
},
expectedError: true,
},
{
name: "1小时内超过限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil)
},
expectedError: true,
},
{
name: "24小时内超过限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 正常
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
// 24小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil)
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置mock
mockRedis.ExpectedCalls = nil
mockRedis.Calls = nil
// 设置mock期望
tt.setupMocks()
// 执行测试
err := logic.checkMobileRateLimit(tt.mobile, "login")
// 验证结果
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// 验证mock调用
mockRedis.AssertExpectations(t)
})
}
}
func TestCheckIPRateLimit(t *testing.T) {
logic, mockRedis := createTestLogic()
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
tests := []struct {
name string
setupMocks func()
expectedError bool
}{
{
name: "正常请求 - 无限制",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "IP被封禁且未过期",
setupMocks: func() {
// IP封禁检查 - 存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
// 获取剩余封禁时间
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil)
},
expectedError: true,
},
{
name: "IP封禁已过期",
setupMocks: func() {
// IP封禁检查 - 存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
// 获取剩余封禁时间 - 已过期
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil)
// 清除过期封禁状态
mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil)
// 1分钟计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "1分钟内超过限制 - 触发短期封禁",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil)
// 设置短期封禁
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil)
},
expectedError: true,
},
{
name: "1小时内超过限制 - 触发长期封禁",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 正常
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil)
// 设置长期封禁
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil)
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置mock
mockRedis.ExpectedCalls = nil
mockRedis.Calls = nil
// 设置mock期望
tt.setupMocks()
// 执行测试
err := logic.checkIPRateLimit()
// 验证结果
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// 验证mock调用
mockRedis.AssertExpectations(t)
})
}
}
func TestRecordCaptchaRequest(t *testing.T) {
logic, mockRedis := createTestLogic()
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
// 设置mock期望
mockRedis.On("Setex", "security:captcha:mobile:13800138000:minute", "1", 60).Return(nil)
mockRedis.On("Incr", "security:captcha:mobile:13800138000:hour").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:mobile:13800138000:hour", 3600).Return(nil)
mockRedis.On("Incr", "security:captcha:mobile:13800138000:day").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:mobile:13800138000:day", 86400).Return(nil)
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:minute").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:minute", 60).Return(nil)
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:hour").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:hour", 3600).Return(nil)
// 执行测试
err := logic.recordCaptchaRequest("13800138000")
// 验证结果
assert.NoError(t, err)
mockRedis.AssertExpectations(t)
}
func TestGetCaptchaProtectionStatus(t *testing.T) {
logic, mockRedis := createTestLogic()
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
// 设置mock期望
mockRedis.On("Exists", "security:captcha:mobile:13800138000:minute").Return(false, nil)
mockRedis.On("Get", "security:captcha:mobile:13800138000:hour").Return("3", nil)
mockRedis.On("Get", "security:captcha:mobile:13800138000:day").Return("15", nil)
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("5", nil)
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("25", nil)
// 执行测试
status, err := logic.GetCaptchaProtectionStatus("13800138000")
// 验证结果
assert.NoError(t, err)
assert.NotNil(t, status)
// 验证手机号状态
assert.Equal(t, false, status["mobileMinuteLimited"])
assert.Equal(t, int64(3), status["mobileHourCount"])
assert.Equal(t, int64(2), status["mobileHourRemaining"])
assert.Equal(t, int64(15), status["mobileDayCount"])
assert.Equal(t, int64(5), status["mobileDayRemaining"])
// 验证IP状态
assert.Equal(t, false, status["ipBanned"])
assert.Equal(t, int64(5), status["ipMinuteCount"])
assert.Equal(t, int64(5), status["ipMinuteRemaining"])
assert.Equal(t, int64(25), status["ipHourCount"])
assert.Equal(t, int64(25), status["ipHourRemaining"])
mockRedis.AssertExpectations(t)
}
func TestGetClientIP(t *testing.T) {
logic, _ := createTestLogic()
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "从上下文获取IP",
ctx: context.WithValue(context.Background(), "client_ip", "192.168.1.100"),
expected: "192.168.1.100",
},
{
name: "上下文无IP - 返回默认IP",
ctx: context.Background(),
expected: "127.0.0.1",
},
{
name: "上下文IP为空 - 返回默认IP",
ctx: context.WithValue(context.Background(), "client_ip", ""),
expected: "127.0.0.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logic.ctx = tt.ctx
ip := logic.getClientIP()
assert.Equal(t, tt.expected, ip)
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"time"
"tyc-server/common/ctxdata"
"tyc-server/common/xerr"
"tyc-server/pkg/lzkit/crypto"
"tyc-server/pkg/lzkit/delay"
@@ -38,10 +39,10 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex
func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) {
// 获取当前用户ID
// userId, err := ctxdata.GetUidFromCtx(l.ctx)
// if err != nil {
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
// }
userId, err := ctxdata.GetUidFromCtx(l.ctx)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
}
// 获取订单信息
order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId)
@@ -52,9 +53,9 @@ func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailB
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %+v", err)
}
// 安全验证:确保订单属于当前用户
// if order.UserId != userId {
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
// }
if order.UserId != userId {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
}
// 创建渐进式延迟策略实例
progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5)
if err != nil {