Files
tyc-server/app/main/api/internal/logic/auth/sendsmslogic_test.go
2025-08-31 14:18:31 +08:00

672 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package auth
import (
	"context"
	"strconv"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/stores/redis"

	"tyc-server/app/main/api/internal/config"
)
// RedisInterface lists the Redis operations that the rate-limit helpers in
// this file call, so that a mock can stand in for a real client.
type RedisInterface interface {
// Exists is used below to test whether a marker key is present.
Exists(key string) (bool, error)
// Get is used below to read a counter value as a string.
Get(key string) (string, error)
// Setex is called below with a value and a lifetime in seconds.
Setex(key, value string, seconds int) error
// Incr is used below to bump a request counter.
Incr(key string) (int64, error)
// Expire is called below with a lifetime in seconds after each Incr.
Expire(key string, seconds int) error
// Ttl is used below to read the remaining lifetime of a ban key.
Ttl(key string) (int, error)
// Del is used below to remove a ban key.
Del(key string) (int64, error)
}
// MockRedis is a testify-based stand-in for a Redis client. Its methods
// forward their arguments to mock.Mock and return whatever the test
// configured through On(...).Return(...).
type MockRedis struct {
mock.Mock
}
// Exists records the call on the mock and returns the (bool, error) pair
// configured for key.
func (m *MockRedis) Exists(key string) (bool, error) {
	ret := m.Called(key)
	found, err := ret.Bool(0), ret.Error(1)
	return found, err
}
// Get records the call on the mock and returns the (string, error) pair
// configured for key.
func (m *MockRedis) Get(key string) (string, error) {
	ret := m.Called(key)
	value, err := ret.String(0), ret.Error(1)
	return value, err
}
// Setex records the call (key, value, seconds) on the mock and returns the
// configured error.
func (m *MockRedis) Setex(key, value string, seconds int) error {
	return m.Called(key, value, seconds).Error(0)
}
// Incr records the call on the mock and returns the configured
// (int64, error) pair. The first configured return value must be an int64;
// any other type makes the type assertion panic.
func (m *MockRedis) Incr(key string) (int64, error) {
	ret := m.Called(key)
	n := ret.Get(0).(int64)
	return n, ret.Error(1)
}
// Expire records the call (key, seconds) on the mock and returns the
// configured error.
func (m *MockRedis) Expire(key string, seconds int) error {
	return m.Called(key, seconds).Error(0)
}
// Ttl records the call on the mock and returns the (int, error) pair
// configured for key.
func (m *MockRedis) Ttl(key string) (int, error) {
	ret := m.Called(key)
	seconds, err := ret.Int(0), ret.Error(1)
	return seconds, err
}
// Del records the call on the mock and returns the configured
// (int64, error) pair. The first configured return value must be an int64;
// any other type makes the type assertion panic.
func (m *MockRedis) Del(key string) (int64, error) {
	ret := m.Called(key)
	n := ret.Get(0).(int64)
	return n, ret.Error(1)
}
// TestServiceContext is a test-only service context. It carries the
// configuration and a Redis dependency typed as RedisInterface so a mock
// can be injected.
type TestServiceContext struct {
// Config is the application configuration used by the tests.
Config config.Config
// Redis is the Redis dependency; createTestServiceContext sets it to a *MockRedis.
Redis RedisInterface
}
// TestSendSmsLogic is a test-only logic type whose methods in this file
// implement the captcha rate-limit checks against TestServiceContext.
type TestSendSmsLogic struct {
// Embedded logger; createTestLogic leaves it unset.
logx.Logger
// ctx is consulted by getClientIP for the "client_ip" value.
ctx context.Context
// svcCtx provides the Redis dependency used by every rate-limit method.
svcCtx *TestServiceContext
}
// createTestServiceContext builds a TestServiceContext with a fixed test
// configuration (secret key "test-secret-key", verify-code valid time 300)
// and a fresh MockRedis as its Redis dependency.
func createTestServiceContext() *TestServiceContext {
	var cfg config.Config
	cfg.Encrypt = config.Encrypt{SecretKey: "test-secret-key"}
	cfg.VerifyCode = config.VerifyCode{ValidTime: 300}

	return &TestServiceContext{Config: cfg, Redis: &MockRedis{}}
}
// createTestLogic returns a TestSendSmsLogic wired to a new test service
// context with a background context, together with the MockRedis behind
// that service context so callers can set expectations on it.
func createTestLogic() (*TestSendSmsLogic, *MockRedis) {
	svc := createTestServiceContext()
	sut := &TestSendSmsLogic{svcCtx: svc, ctx: context.Background()}
	return sut, svc.Redis.(*MockRedis)
}
// Test-side implementations of the rate-limit methods.

// checkMobileRateLimit enforces the per-mobile, per-action limits on
// verification-code requests:
//   - at most 1 request per minute (presence of the ":minute" marker key),
//   - at most 5 requests per hour (":hour" counter),
//   - at most 20 requests per 24 hours (":day" counter).
//
// It returns nil when the request is allowed, assert.AnError when a limit
// has been reached, and the underlying error when Redis fails or a stored
// counter is not a valid integer. A redis.Nil result from Get is treated as
// "no counter yet".
func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
	mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType

	// The minute window is a plain marker: if the key exists, reject.
	exists, err := l.svcCtx.Redis.Exists(mobileKey + ":minute")
	if err != nil {
		return err
	}
	if exists {
		return assert.AnError
	}

	// Counter windows, checked in order: hour first, then day.
	windows := []struct {
		suffix string
		limit  int
	}{
		{":hour", 5},
		{":day", 20},
	}
	for _, w := range windows {
		raw, err := l.svcCtx.Redis.Get(mobileKey + w.suffix)
		if err != nil && err != redis.Nil {
			return err
		}
		if raw == "" {
			continue
		}
		n, err := strconv.Atoi(raw)
		if err != nil {
			// Fail closed on a corrupted counter instead of letting the request through.
			return err
		}
		// Compare numerically with >= so a counter that has already passed
		// the threshold is still rejected.
		if n >= w.limit {
			return assert.AnError
		}
	}
	return nil
}
// checkIPRateLimit enforces the per-IP limits on verification-code requests
// for the client IP returned by getClientIP:
//   - an existing ban key with a positive TTL rejects the request,
//   - 10 or more requests in the ":minute" counter bans the IP for 300 seconds,
//   - 50 or more requests in the ":hour" counter bans the IP for 3600 seconds.
//
// It returns nil when the request is allowed, assert.AnError when the IP is
// banned or has just been banned, and the underlying error when Redis fails
// or a stored counter is not a valid integer. A redis.Nil result from Get
// is treated as "no counter yet".
func (l *TestSendSmsLogic) checkIPRateLimit() error {
	ipKey := "security:captcha:ip:" + l.getClientIP()
	bannedKey := ipKey + ":banned"

	exists, err := l.svcCtx.Redis.Exists(bannedKey)
	if err != nil {
		return err
	}
	if exists {
		ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
		if err != nil {
			return err
		}
		if ttl > 0 {
			return assert.AnError
		}
		// A non-positive TTL is treated as a lapsed ban: clear the key and
		// continue with the counter checks. A failed delete is reported
		// like every other Redis failure in this method.
		if _, err := l.svcCtx.Redis.Del(bannedKey); err != nil {
			return err
		}
	}

	// Counter windows, checked in order: minute first, then hour.
	windows := []struct {
		suffix     string
		limit      int
		banSeconds int
	}{
		{":minute", 10, 300},
		{":hour", 50, 3600},
	}
	for _, w := range windows {
		raw, err := l.svcCtx.Redis.Get(ipKey + w.suffix)
		if err != nil && err != redis.Nil {
			return err
		}
		if raw == "" {
			continue
		}
		n, err := strconv.Atoi(raw)
		if err != nil {
			// Fail closed on a corrupted counter instead of letting the request through.
			return err
		}
		// Compare numerically with >= so a counter that has already passed
		// the threshold still triggers the ban.
		if n < w.limit {
			continue
		}
		if err := l.svcCtx.Redis.Setex(bannedKey, "1", w.banSeconds); err != nil {
			return err
		}
		return assert.AnError
	}
	return nil
}
// getClientIP returns the non-empty string stored under the "client_ip" key
// of l.ctx. When the context is nil, the value is absent, is not a string,
// or is empty, it falls back to "127.0.0.1".
func (l *TestSendSmsLogic) getClientIP() string {
	const fallbackIP = "127.0.0.1"

	if l.ctx == nil {
		return fallbackIP
	}
	ip, ok := l.ctx.Value("client_ip").(string)
	if !ok || ip == "" {
		return fallbackIP
	}
	return ip
}
// recordCaptchaRequest records one verification-code request for mobile and
// for the client IP returned by getClientIP. It writes, in this order:
//   - the mobile ":minute" marker with a 60 second lifetime,
//   - the mobile ":hour" and ":day" counters (3600 and 86400 seconds),
//   - the IP ":minute" and ":hour" counters (60 and 3600 seconds).
//
// Each counter is incremented and then has its expiry set. The first Redis
// error aborts the sequence and is returned.
//
// NOTE(review): the mobile keys built here do not contain an action type,
// while checkMobileRateLimit reads keys of the form
// "...:mobile:<mobile>:<actionType>:<window>" — confirm whether these two
// are meant to share keys.
func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error {
	ipPrefix := "security:captcha:ip:" + l.getClientIP()
	mobilePrefix := "security:captcha:mobile:" + mobile
	rds := l.svcCtx.Redis

	if err := rds.Setex(mobilePrefix+":minute", "1", 60); err != nil {
		return err
	}

	counters := []struct {
		key string
		ttl int
	}{
		{mobilePrefix + ":hour", 3600},
		{mobilePrefix + ":day", 86400},
		{ipPrefix + ":minute", 60},
		{ipPrefix + ":hour", 3600},
	}
	for _, c := range counters {
		if _, err := rds.Incr(c.key); err != nil {
			return err
		}
		if err := rds.Expire(c.key, c.ttl); err != nil {
			return err
		}
	}
	return nil
}
// GetCaptchaProtectionStatus reports the current rate-limit state for mobile
// and for the client IP returned by getClientIP.
//
// The returned map contains:
//   - "mobileMinuteLimited" (bool): whether the mobile ":minute" marker exists,
//   - "mobileHourCount"/"mobileHourRemaining" (int64): hour counter, limit 5,
//   - "mobileDayCount"/"mobileDayRemaining" (int64): day counter, limit 20,
//   - "ipBanned" (bool) and, only when banned, "ipBanRemainingTime" (int),
//   - "ipMinuteCount"/"ipMinuteRemaining" (int64): minute counter, limit 10,
//   - "ipHourCount"/"ipHourRemaining" (int64): hour counter, limit 50.
//
// Counts are parsed from the values stored in Redis; a missing counter
// (empty value or redis.Nil) counts as 0 and remaining values never go
// below 0. On any Redis failure or unparsable counter it returns a nil map
// and the error.
func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) {
	rds := l.svcCtx.Redis
	status := make(map[string]interface{})
	ipKey := "security:captcha:ip:" + l.getClientIP()
	mobileKey := "security:captcha:mobile:" + mobile

	// fillCounter reads the counter at key and stores <name>Count and
	// <name>Remaining (limit minus count, clamped at 0) in status.
	fillCounter := func(key, name string, limit int64) error {
		raw, err := rds.Get(key)
		if err != nil && err != redis.Nil {
			return err
		}
		var used int64
		if raw != "" {
			used, err = strconv.ParseInt(raw, 10, 64)
			if err != nil {
				return err
			}
		}
		remaining := limit - used
		if remaining < 0 {
			remaining = 0
		}
		status[name+"Count"] = used
		status[name+"Remaining"] = remaining
		return nil
	}

	// Mobile: minute marker, then hour and day counters.
	limited, err := rds.Exists(mobileKey + ":minute")
	if err != nil {
		return nil, err
	}
	status["mobileMinuteLimited"] = limited

	if err := fillCounter(mobileKey+":hour", "mobileHour", 5); err != nil {
		return nil, err
	}
	if err := fillCounter(mobileKey+":day", "mobileDay", 20); err != nil {
		return nil, err
	}

	// IP: ban state, then minute and hour counters.
	bannedKey := ipKey + ":banned"
	banned, err := rds.Exists(bannedKey)
	if err != nil {
		return nil, err
	}
	status["ipBanned"] = banned
	if banned {
		ttl, err := rds.Ttl(bannedKey)
		if err != nil {
			return nil, err
		}
		status["ipBanRemainingTime"] = ttl
	}

	if err := fillCounter(ipKey+":minute", "ipMinute", 10); err != nil {
		return nil, err
	}
	if err := fillCounter(ipKey+":hour", "ipHour", 50); err != nil {
		return nil, err
	}
	return status, nil
}
func TestCheckMobileRateLimit(t *testing.T) {
logic, mockRedis := createTestLogic()
tests := []struct {
name string
mobile string
setupMocks func()
expectedError bool
}{
{
name: "正常请求 - 无限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
// 24小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "1分钟内重复请求",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil)
},
expectedError: true,
},
{
name: "1小时内超过限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil)
},
expectedError: true,
},
{
name: "24小时内超过限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 正常
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
// 24小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil)
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置mock
mockRedis.ExpectedCalls = nil
mockRedis.Calls = nil
// 设置mock期望
tt.setupMocks()
// 执行测试
err := logic.checkMobileRateLimit(tt.mobile, "login")
// 验证结果
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// 验证mock调用
mockRedis.AssertExpectations(t)
})
}
}
func TestCheckIPRateLimit(t *testing.T) {
logic, mockRedis := createTestLogic()
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
tests := []struct {
name string
setupMocks func()
expectedError bool
}{
{
name: "正常请求 - 无限制",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "IP被封禁且未过期",
setupMocks: func() {
// IP封禁检查 - 存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
// 获取剩余封禁时间
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil)
},
expectedError: true,
},
{
name: "IP封禁已过期",
setupMocks: func() {
// IP封禁检查 - 存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
// 获取剩余封禁时间 - 已过期
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil)
// 清除过期封禁状态
mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil)
// 1分钟计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "1分钟内超过限制 - 触发短期封禁",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil)
// 设置短期封禁
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil)
},
expectedError: true,
},
{
name: "1小时内超过限制 - 触发长期封禁",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 正常
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil)
// 设置长期封禁
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil)
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置mock
mockRedis.ExpectedCalls = nil
mockRedis.Calls = nil
// 设置mock期望
tt.setupMocks()
// 执行测试
err := logic.checkIPRateLimit()
// 验证结果
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// 验证mock调用
mockRedis.AssertExpectations(t)
})
}
}
// TestRecordCaptchaRequest checks that recordCaptchaRequest writes the
// mobile minute marker and increments/expires the mobile hour and day
// counters and the IP minute and hour counters with the expected keys and
// lifetimes.
func TestRecordCaptchaRequest(t *testing.T) {
	const (
		mobilePrefix = "security:captcha:mobile:13800138000"
		ipPrefix     = "security:captcha:ip:192.168.1.100"
	)

	logic, mockRedis := createTestLogic()
	logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")

	// The marker is written once; every counter gets an Incr followed by an Expire.
	mockRedis.On("Setex", mobilePrefix+":minute", "1", 60).Return(nil)
	counters := []struct {
		key string
		ttl int
	}{
		{mobilePrefix + ":hour", 3600},
		{mobilePrefix + ":day", 86400},
		{ipPrefix + ":minute", 60},
		{ipPrefix + ":hour", 3600},
	}
	for _, c := range counters {
		mockRedis.On("Incr", c.key).Return(int64(1), nil)
		mockRedis.On("Expire", c.key, c.ttl).Return(nil)
	}

	err := logic.recordCaptchaRequest("13800138000")

	assert.NoError(t, err)
	mockRedis.AssertExpectations(t)
}
// TestGetCaptchaProtectionStatus checks the status map returned by
// GetCaptchaProtectionStatus for a mobile with hour/day counters of 3 and
// 15 and an unbanned IP with minute/hour counters of 5 and 25.
func TestGetCaptchaProtectionStatus(t *testing.T) {
	const (
		mobilePrefix = "security:captcha:mobile:13800138000"
		ipPrefix     = "security:captcha:ip:192.168.1.100"
	)

	logic, mockRedis := createTestLogic()
	logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")

	// Marker/ban lookups report "absent"; counters return stored values.
	mockRedis.On("Exists", mobilePrefix+":minute").Return(false, nil)
	mockRedis.On("Get", mobilePrefix+":hour").Return("3", nil)
	mockRedis.On("Get", mobilePrefix+":day").Return("15", nil)
	mockRedis.On("Exists", ipPrefix+":banned").Return(false, nil)
	mockRedis.On("Get", ipPrefix+":minute").Return("5", nil)
	mockRedis.On("Get", ipPrefix+":hour").Return("25", nil)

	status, err := logic.GetCaptchaProtectionStatus("13800138000")

	assert.NoError(t, err)
	assert.NotNil(t, status)

	// Expected entries, mobile state first and IP state second.
	expected := []struct {
		key  string
		want interface{}
	}{
		{"mobileMinuteLimited", false},
		{"mobileHourCount", int64(3)},
		{"mobileHourRemaining", int64(2)},
		{"mobileDayCount", int64(15)},
		{"mobileDayRemaining", int64(5)},
		{"ipBanned", false},
		{"ipMinuteCount", int64(5)},
		{"ipMinuteRemaining", int64(5)},
		{"ipHourCount", int64(25)},
		{"ipHourRemaining", int64(25)},
	}
	for _, e := range expected {
		assert.Equal(t, e.want, status[e.key])
	}

	mockRedis.AssertExpectations(t)
}
// TestGetClientIP checks that getClientIP returns the "client_ip" context
// value when it is a non-empty string and falls back to "127.0.0.1" when
// the value is missing or empty.
func TestGetClientIP(t *testing.T) {
	base := context.Background()

	cases := []struct {
		name string
		ctx  context.Context
		want string
	}{
		{"从上下文获取IP", context.WithValue(base, "client_ip", "192.168.1.100"), "192.168.1.100"},
		{"上下文无IP - 返回默认IP", base, "127.0.0.1"},
		{"上下文IP为空 - 返回默认IP", context.WithValue(base, "client_ip", ""), "127.0.0.1"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			// Each subtest gets its own logic value, so contexts never leak between cases.
			logic, _ := createTestLogic()
			logic.ctx = tc.ctx
			assert.Equal(t, tc.want, logic.getClientIP())
		})
	}
}