382 lines
10 KiB
Go
382 lines
10 KiB
Go
package auth
|
||
|
||
import (
|
||
"context"
|
||
"testing"
|
||
|
||
"github.com/stretchr/testify/assert"
|
||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||
|
||
"tyc-server/app/main/api/internal/config"
|
||
"tyc-server/app/main/api/internal/svc"
|
||
)
|
||
|
||
// 集成测试配置
|
||
var integrationTestConfig = config.Config{
|
||
Encrypt: config.Encrypt{
|
||
SecretKey: "test-secret-key",
|
||
},
|
||
VerifyCode: config.VerifyCode{
|
||
ValidTime: 300,
|
||
},
|
||
}
|
||
|
||
// 创建集成测试用的ServiceContext
|
||
func createIntegrationTestServiceContext() *svc.ServiceContext {
|
||
redisConf := redis.RedisConf{
|
||
Host: "127.0.0.1:6379",
|
||
Type: "node",
|
||
Pass: "",
|
||
}
|
||
|
||
return &svc.ServiceContext{
|
||
Config: integrationTestConfig,
|
||
Redis: redis.MustNewRedis(redisConf),
|
||
}
|
||
}
|
||
|
||
// 清理测试数据
|
||
func cleanupTestData(logic *SendSmsLogic, mobile, clientIP string) {
|
||
// 清理手机号相关数据
|
||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||
logic.svcCtx.Redis.Del(mobileKey+":minute", mobileKey+":hour", mobileKey+":day")
|
||
|
||
// 清理IP相关数据
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
logic.svcCtx.Redis.Del(ipKey+":minute", ipKey+":hour", ipKey+":banned")
|
||
}
|
||
|
||
// 创建集成测试用的SendSmsLogic
|
||
func createIntegrationTestLogic() *SendSmsLogic {
|
||
svcCtx := createIntegrationTestServiceContext()
|
||
|
||
return &SendSmsLogic{
|
||
ctx: context.Background(),
|
||
svcCtx: svcCtx,
|
||
}
|
||
}
|
||
|
||
// 跳过集成测试的标志
|
||
var skipIntegrationTests = true
|
||
|
||
func TestIntegrationCheckMobileRateLimit(t *testing.T) {
|
||
if skipIntegrationTests {
|
||
t.Skip("跳过集成测试,需要Redis服务")
|
||
}
|
||
|
||
logic := createIntegrationTestLogic()
|
||
|
||
mobile := "13800138000"
|
||
|
||
// 清理测试数据
|
||
defer cleanupTestData(logic, mobile, "192.168.1.100")
|
||
|
||
t.Run("正常请求 - 无限制", func(t *testing.T) {
|
||
err := logic.checkMobileRateLimit(mobile, "login")
|
||
assert.NoError(t, err)
|
||
})
|
||
|
||
t.Run("1分钟内重复请求", func(t *testing.T) {
|
||
// 设置1分钟限制标记
|
||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||
err := logic.svcCtx.Redis.Setex(mobileKey+":minute", "1", 60)
|
||
assert.NoError(t, err)
|
||
|
||
// 再次请求应该被拒绝
|
||
err = logic.checkMobileRateLimit(mobile, "login")
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
|
||
|
||
// 清理限制标记
|
||
logic.svcCtx.Redis.Del(mobileKey + ":minute")
|
||
})
|
||
|
||
t.Run("1小时内超过限制", func(t *testing.T) {
|
||
// 设置1小时计数为5(达到限制)
|
||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||
err := logic.svcCtx.Redis.Setex(mobileKey+":hour", "5", 3600)
|
||
assert.NoError(t, err)
|
||
|
||
// 请求应该被拒绝
|
||
err = logic.checkMobileRateLimit(mobile, "login")
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "1小时内获取验证码次数过多")
|
||
|
||
// 清理计数
|
||
logic.svcCtx.Redis.Del(mobileKey + ":hour")
|
||
})
|
||
|
||
t.Run("24小时内超过限制", func(t *testing.T) {
|
||
// 设置24小时计数为20(达到限制)
|
||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||
err := logic.svcCtx.Redis.Setex(mobileKey+":day", "20", 86400)
|
||
assert.NoError(t, err)
|
||
|
||
// 请求应该被拒绝
|
||
err = logic.checkMobileRateLimit(mobile, "login")
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "24小时内获取验证码次数过多")
|
||
|
||
// 清理计数
|
||
logic.svcCtx.Redis.Del(mobileKey + ":day")
|
||
})
|
||
}
|
||
|
||
func TestIntegrationCheckIPRateLimit(t *testing.T) {
|
||
if skipIntegrationTests {
|
||
t.Skip("跳过集成测试,需要Redis服务")
|
||
}
|
||
|
||
logic := createIntegrationTestLogic()
|
||
|
||
mobile := "13800138000"
|
||
clientIP := "192.168.1.100"
|
||
|
||
// 清理测试数据
|
||
defer cleanupTestData(logic, mobile, clientIP)
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||
|
||
t.Run("正常请求 - 无限制", func(t *testing.T) {
|
||
err := logic.checkIPRateLimit()
|
||
assert.NoError(t, err)
|
||
})
|
||
|
||
t.Run("IP被封禁且未过期", func(t *testing.T) {
|
||
// 设置IP封禁状态
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
err := logic.svcCtx.Redis.Setex(ipKey+":banned", "1", 300)
|
||
assert.NoError(t, err)
|
||
|
||
// 请求应该被拒绝
|
||
err = logic.checkIPRateLimit()
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "IP被临时封禁")
|
||
|
||
// 清理封禁状态
|
||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||
})
|
||
|
||
t.Run("1分钟内超过限制 - 触发短期封禁", func(t *testing.T) {
|
||
// 设置1分钟计数为10(达到限制)
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
err := logic.svcCtx.Redis.Setex(ipKey+":minute", "10", 60)
|
||
assert.NoError(t, err)
|
||
|
||
// 请求应该被拒绝并触发封禁
|
||
err = logic.checkIPRateLimit()
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁5分钟")
|
||
|
||
// 验证IP被封禁
|
||
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
|
||
assert.NoError(t, err)
|
||
assert.True(t, exists)
|
||
|
||
// 清理数据
|
||
logic.svcCtx.Redis.Del(ipKey + ":minute")
|
||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||
})
|
||
|
||
t.Run("1小时内超过限制 - 触发长期封禁", func(t *testing.T) {
|
||
// 设置1小时计数为50(达到限制)
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
err := logic.svcCtx.Redis.Setex(ipKey+":hour", "50", 3600)
|
||
assert.NoError(t, err)
|
||
|
||
// 请求应该被拒绝并触发封禁
|
||
err = logic.checkIPRateLimit()
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁1小时")
|
||
|
||
// 验证IP被封禁
|
||
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
|
||
assert.NoError(t, err)
|
||
assert.True(t, exists)
|
||
|
||
// 清理数据
|
||
logic.svcCtx.Redis.Del(ipKey + ":hour")
|
||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||
})
|
||
}
|
||
|
||
func TestIntegrationRecordCaptchaRequest(t *testing.T) {
|
||
if skipIntegrationTests {
|
||
t.Skip("跳过集成测试,需要Redis服务")
|
||
}
|
||
|
||
logic := createIntegrationTestLogic()
|
||
|
||
mobile := "13800138000"
|
||
clientIP := "192.168.1.100"
|
||
|
||
// 清理测试数据
|
||
defer cleanupTestData(logic, mobile, clientIP)
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||
|
||
t.Run("记录手机号请求次数", func(t *testing.T) {
|
||
err := logic.recordCaptchaRequest(mobile, "login")
|
||
assert.NoError(t, err)
|
||
|
||
// 验证1分钟限制标记
|
||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||
exists, err := logic.svcCtx.Redis.Exists(mobileKey + ":minute")
|
||
assert.NoError(t, err)
|
||
assert.True(t, exists)
|
||
|
||
// 验证1小时计数
|
||
count, err := logic.svcCtx.Redis.Get(mobileKey + ":hour")
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, "1", count)
|
||
|
||
// 验证24小时计数
|
||
count, err = logic.svcCtx.Redis.Get(mobileKey + ":day")
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, "1", count)
|
||
})
|
||
|
||
t.Run("记录IP请求次数", func(t *testing.T) {
|
||
err := logic.recordCaptchaRequest(mobile, "register")
|
||
assert.NoError(t, err)
|
||
|
||
// 验证IP 1分钟计数
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
count, err := logic.svcCtx.Redis.Get(ipKey + ":minute")
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, "2", count) // 第二次调用
|
||
|
||
// 验证IP 1小时计数
|
||
count, err = logic.svcCtx.Redis.Get(ipKey + ":hour")
|
||
assert.NoError(t, err)
|
||
assert.Equal(t, "2", count) // 第二次调用
|
||
})
|
||
}
|
||
|
||
func TestIntegrationGetCaptchaProtectionStatus(t *testing.T) {
|
||
if skipIntegrationTests {
|
||
t.Skip("跳过集成测试,需要Redis服务")
|
||
}
|
||
|
||
logic := createIntegrationTestLogic()
|
||
|
||
mobile := "13800138000"
|
||
clientIP := "192.168.1.100"
|
||
|
||
// 清理测试数据
|
||
defer cleanupTestData(logic, mobile, clientIP)
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||
|
||
t.Run("获取防护状态", func(t *testing.T) {
|
||
status, err := logic.GetCaptchaProtectionStatus(mobile, "login")
|
||
assert.NoError(t, err)
|
||
assert.NotNil(t, status)
|
||
|
||
// 验证基本信息
|
||
assert.Equal(t, mobile, status["mobile"])
|
||
assert.Equal(t, "login", status["actionType"])
|
||
assert.Equal(t, clientIP, status["clientIP"])
|
||
|
||
// 验证手机号状态
|
||
assert.False(t, status["mobileMinuteLimited"].(bool))
|
||
assert.Equal(t, int64(0), status["mobileHourCount"])
|
||
assert.Equal(t, int64(5), status["mobileHourRemaining"])
|
||
assert.Equal(t, int64(0), status["mobileDayCount"])
|
||
assert.Equal(t, int64(20), status["mobileDayRemaining"])
|
||
|
||
// 验证IP状态
|
||
assert.False(t, status["ipBanned"].(bool))
|
||
assert.Equal(t, int64(0), status["ipMinuteCount"])
|
||
assert.Equal(t, int64(10), status["ipMinuteRemaining"])
|
||
assert.Equal(t, int64(0), status["ipHourCount"])
|
||
assert.Equal(t, int64(50), status["ipHourRemaining"])
|
||
})
|
||
}
|
||
|
||
func TestIntegrationEndToEnd(t *testing.T) {
|
||
if skipIntegrationTests {
|
||
t.Skip("跳过集成测试,需要Redis服务")
|
||
}
|
||
|
||
logic := createIntegrationTestLogic()
|
||
|
||
mobile := "13800138000"
|
||
clientIP := "192.168.1.100"
|
||
|
||
// 清理测试数据
|
||
defer cleanupTestData(logic, mobile, clientIP)
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||
|
||
t.Run("完整流程测试", func(t *testing.T) {
|
||
// 第一次请求 - 应该成功
|
||
err := logic.checkCaptchaProtection(mobile, "login")
|
||
assert.NoError(t, err)
|
||
|
||
// 记录请求
|
||
err = logic.recordCaptchaRequest(mobile, "login")
|
||
assert.NoError(t, err)
|
||
|
||
// 第二次请求 - 应该被拒绝(1分钟限制)
|
||
err = logic.checkCaptchaProtection(mobile, "login")
|
||
assert.Error(t, err)
|
||
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
|
||
|
||
// 不同短信类型 - 应该成功
|
||
err = logic.checkCaptchaProtection(mobile, "register")
|
||
assert.NoError(t, err)
|
||
})
|
||
}
|
||
|
||
// 性能基准测试
|
||
func BenchmarkCheckCaptchaProtection(b *testing.B) {
|
||
if skipIntegrationTests {
|
||
b.Skip("跳过集成测试,需要Redis服务")
|
||
}
|
||
|
||
logic := createIntegrationTestLogic()
|
||
mobile := "13800138000"
|
||
clientIP := "192.168.1.100"
|
||
|
||
// 清理测试数据
|
||
defer cleanupTestData(logic, mobile, clientIP)
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||
|
||
b.ResetTimer()
|
||
for i := 0; i < b.N; i++ {
|
||
// 使用不同的手机号避免限制
|
||
testMobile := mobile + "_" + string(rune(i%10))
|
||
logic.checkCaptchaProtection(testMobile, "login")
|
||
}
|
||
}
|
||
|
||
func BenchmarkRecordCaptchaRequest(b *testing.B) {
|
||
if skipIntegrationTests {
|
||
b.Skip("跳过集成测试,需要Redis服务")
|
||
}
|
||
|
||
logic := createIntegrationTestLogic()
|
||
mobile := "13800138000"
|
||
clientIP := "192.168.1.100"
|
||
|
||
// 清理测试数据
|
||
defer cleanupTestData(logic, mobile, clientIP)
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||
|
||
b.ResetTimer()
|
||
for i := 0; i < b.N; i++ {
|
||
// 使用不同的手机号避免限制
|
||
testMobile := mobile + "_" + string(rune(i%10))
|
||
logic.recordCaptchaRequest(testMobile, "login")
|
||
}
|
||
}
|