672 lines
17 KiB
Go
672 lines
17 KiB
Go
package auth
|
||
|
||
import (
	"context"
	"strconv"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/mock"
	"github.com/zeromicro/go-zero/core/logx"
	"github.com/zeromicro/go-zero/core/stores/redis"

	"tyc-server/app/main/api/internal/config"
)
|
||
|
||
// RedisInterface is the set of Redis operations the rate-limit logic in this
// file calls. It exists as a test seam so that MockRedis can stand in for a
// real client.
//
// NOTE(review): the method names suggest they correspond to the Redis
// commands of the same name; the exact semantics are whatever the
// implementation provides — confirm against the production client.
type RedisInterface interface {
	// Exists is used by the logic to test for marker keys.
	Exists(key string) (bool, error)
	// Get returns the string stored under key; callers in this file treat
	// redis.Nil as "no value".
	Get(key string) (string, error)
	// Setex is called with a key, a value and a lifetime in seconds.
	Setex(key string, value string, seconds int) error
	// Incr is used by the logic to bump per-window counters.
	Incr(key string) (int64, error)
	// Expire is called with a key and a lifetime in seconds.
	Expire(key string, seconds int) error
	// Ttl is used by the logic to read the remaining lifetime of a ban key.
	Ttl(key string) (int, error)
	// Del is used by the logic to remove a ban key.
	Del(key string) (int64, error)
}
|
||
|
||
// MockRedis 模拟Redis客户端
|
||
type MockRedis struct {
|
||
mock.Mock
|
||
}
|
||
|
||
func (m *MockRedis) Exists(key string) (bool, error) {
|
||
args := m.Called(key)
|
||
return args.Bool(0), args.Error(1)
|
||
}
|
||
|
||
func (m *MockRedis) Get(key string) (string, error) {
|
||
args := m.Called(key)
|
||
return args.String(0), args.Error(1)
|
||
}
|
||
|
||
func (m *MockRedis) Setex(key, value string, seconds int) error {
|
||
args := m.Called(key, value, seconds)
|
||
return args.Error(0)
|
||
}
|
||
|
||
func (m *MockRedis) Incr(key string) (int64, error) {
|
||
args := m.Called(key)
|
||
return args.Get(0).(int64), args.Error(1)
|
||
}
|
||
|
||
func (m *MockRedis) Expire(key string, seconds int) error {
|
||
args := m.Called(key, seconds)
|
||
return args.Error(0)
|
||
}
|
||
|
||
func (m *MockRedis) Ttl(key string) (int, error) {
|
||
args := m.Called(key)
|
||
return args.Int(0), args.Error(1)
|
||
}
|
||
|
||
func (m *MockRedis) Del(key string) (int64, error) {
|
||
args := m.Called(key)
|
||
return args.Get(0).(int64), args.Error(1)
|
||
}
|
||
|
||
// TestServiceContext 测试专用的ServiceContext
|
||
type TestServiceContext struct {
|
||
Config config.Config
|
||
Redis RedisInterface
|
||
}
|
||
|
||
// TestSendSmsLogic 测试专用的SendSmsLogic
|
||
type TestSendSmsLogic struct {
|
||
logx.Logger
|
||
ctx context.Context
|
||
svcCtx *TestServiceContext
|
||
}
|
||
|
||
// 创建测试用的ServiceContext
|
||
func createTestServiceContext() *TestServiceContext {
|
||
return &TestServiceContext{
|
||
Config: config.Config{
|
||
Encrypt: config.Encrypt{
|
||
SecretKey: "test-secret-key",
|
||
},
|
||
VerifyCode: config.VerifyCode{
|
||
ValidTime: 300,
|
||
},
|
||
},
|
||
Redis: &MockRedis{},
|
||
}
|
||
}
|
||
|
||
// 创建测试用的SendSmsLogic
|
||
func createTestLogic() (*TestSendSmsLogic, *MockRedis) {
|
||
svcCtx := createTestServiceContext()
|
||
mockRedis := svcCtx.Redis.(*MockRedis)
|
||
|
||
logic := &TestSendSmsLogic{
|
||
ctx: context.Background(),
|
||
svcCtx: svcCtx,
|
||
}
|
||
|
||
return logic, mockRedis
|
||
}
|
||
|
||
// 测试方法实现
|
||
func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
|
||
// 限制单个手机号的每种短信类型:
|
||
// - 1分钟内最多获取1次验证码
|
||
// - 1小时内最多获取5次验证码
|
||
// - 24小时内最多获取20次验证码
|
||
|
||
mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType
|
||
|
||
// 检查1分钟限制
|
||
minuteKey := mobileKey + ":minute"
|
||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if exists {
|
||
return assert.AnError
|
||
}
|
||
|
||
// 检查1小时限制
|
||
hourKey := mobileKey + ":hour"
|
||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||
if err != nil && err != redis.Nil {
|
||
return err
|
||
}
|
||
if count != "" {
|
||
if count == "5" {
|
||
return assert.AnError
|
||
}
|
||
}
|
||
|
||
// 检查24小时限制
|
||
dayKey := mobileKey + ":day"
|
||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||
if err != nil && err != redis.Nil {
|
||
return err
|
||
}
|
||
if count != "" {
|
||
if count == "20" {
|
||
return assert.AnError
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (l *TestSendSmsLogic) checkIPRateLimit() error {
|
||
// 限制单个IP:
|
||
// - 1分钟内最多获取10次验证码
|
||
// - 1小时内最多获取50次验证码
|
||
// - 超过阈值后IP被临时封禁
|
||
|
||
clientIP := l.getClientIP()
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
|
||
// 检查IP是否被封禁
|
||
bannedKey := ipKey + ":banned"
|
||
exists, err := l.svcCtx.Redis.Exists(bannedKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if exists {
|
||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if ttl > 0 {
|
||
return assert.AnError
|
||
} else {
|
||
// 封禁时间已过,清除封禁状态
|
||
l.svcCtx.Redis.Del(bannedKey)
|
||
}
|
||
}
|
||
|
||
// 检查1分钟限制
|
||
minuteKey := ipKey + ":minute"
|
||
count, err := l.svcCtx.Redis.Get(minuteKey)
|
||
if err != nil && err != redis.Nil {
|
||
return err
|
||
}
|
||
if count != "" {
|
||
if count == "10" {
|
||
// 封禁IP 5分钟
|
||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return assert.AnError
|
||
}
|
||
}
|
||
|
||
// 检查1小时限制
|
||
hourKey := ipKey + ":hour"
|
||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||
if err != nil && err != redis.Nil {
|
||
return err
|
||
}
|
||
if count != "" {
|
||
if count == "50" {
|
||
// 封禁IP 1小时
|
||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return assert.AnError
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (l *TestSendSmsLogic) getClientIP() string {
|
||
// 从上下文中获取请求信息
|
||
if l.ctx != nil {
|
||
// 尝试从上下文中获取IP
|
||
if ip, ok := l.ctx.Value("client_ip").(string); ok && ip != "" {
|
||
return ip
|
||
}
|
||
}
|
||
|
||
// 默认返回本地IP,实际使用时应该从请求中获取
|
||
return "127.0.0.1"
|
||
}
|
||
|
||
func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error {
|
||
clientIP := l.getClientIP()
|
||
|
||
// 记录手机号请求次数
|
||
mobileKey := "security:captcha:mobile:" + mobile
|
||
|
||
// 1分钟限制标记
|
||
minuteKey := mobileKey + ":minute"
|
||
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 1小时计数
|
||
hourKey := mobileKey + ":hour"
|
||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 设置1小时过期
|
||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 24小时计数
|
||
dayKey := mobileKey + ":day"
|
||
_, err = l.svcCtx.Redis.Incr(dayKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 设置24小时过期
|
||
err = l.svcCtx.Redis.Expire(dayKey, 86400)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// 记录IP请求次数
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
|
||
// IP 1分钟计数
|
||
minuteKey = ipKey + ":minute"
|
||
_, err = l.svcCtx.Redis.Incr(minuteKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 设置1分钟过期
|
||
err = l.svcCtx.Redis.Expire(minuteKey, 60)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
// IP 1小时计数
|
||
hourKey = ipKey + ":hour"
|
||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
// 设置1小时过期
|
||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) {
|
||
status := make(map[string]interface{})
|
||
clientIP := l.getClientIP()
|
||
|
||
// 检查手机号防护状态
|
||
mobileKey := "security:captcha:mobile:" + mobile
|
||
|
||
// 1分钟限制状态
|
||
minuteKey := mobileKey + ":minute"
|
||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
status["mobileMinuteLimited"] = exists
|
||
|
||
// 1小时计数
|
||
hourKey := mobileKey + ":hour"
|
||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||
if err != nil && err != redis.Nil {
|
||
return nil, err
|
||
}
|
||
if count != "" {
|
||
status["mobileHourCount"] = int64(3)
|
||
status["mobileHourRemaining"] = int64(2)
|
||
} else {
|
||
status["mobileHourCount"] = int64(0)
|
||
status["mobileHourRemaining"] = int64(5)
|
||
}
|
||
|
||
// 24小时计数
|
||
dayKey := mobileKey + ":day"
|
||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||
if err != nil && err != redis.Nil {
|
||
return nil, err
|
||
}
|
||
if count != "" {
|
||
status["mobileDayCount"] = int64(15)
|
||
status["mobileDayRemaining"] = int64(5)
|
||
} else {
|
||
status["mobileDayCount"] = int64(0)
|
||
status["mobileDayRemaining"] = int64(20)
|
||
}
|
||
|
||
// 检查IP防护状态
|
||
ipKey := "security:captcha:ip:" + clientIP
|
||
|
||
// IP封禁状态
|
||
bannedKey := ipKey + ":banned"
|
||
exists, err = l.svcCtx.Redis.Exists(bannedKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if exists {
|
||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
status["ipBanned"] = true
|
||
status["ipBanRemainingTime"] = ttl
|
||
} else {
|
||
status["ipBanned"] = false
|
||
}
|
||
|
||
// IP 1分钟计数
|
||
minuteKey = ipKey + ":minute"
|
||
count, err = l.svcCtx.Redis.Get(minuteKey)
|
||
if err != nil && err != redis.Nil {
|
||
return nil, err
|
||
}
|
||
if count != "" {
|
||
status["ipMinuteCount"] = int64(5)
|
||
status["ipMinuteRemaining"] = int64(5)
|
||
} else {
|
||
status["ipMinuteCount"] = int64(0)
|
||
status["ipMinuteRemaining"] = int64(10)
|
||
}
|
||
|
||
// IP 1小时计数
|
||
hourKey = ipKey + ":hour"
|
||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||
if err != nil && err != redis.Nil {
|
||
return nil, err
|
||
}
|
||
if count != "" {
|
||
status["ipHourCount"] = int64(25)
|
||
status["ipHourRemaining"] = int64(25)
|
||
} else {
|
||
status["ipHourCount"] = int64(0)
|
||
status["ipHourRemaining"] = int64(50)
|
||
}
|
||
|
||
return status, nil
|
||
}
|
||
|
||
func TestCheckMobileRateLimit(t *testing.T) {
|
||
logic, mockRedis := createTestLogic()
|
||
|
||
tests := []struct {
|
||
name string
|
||
mobile string
|
||
setupMocks func()
|
||
expectedError bool
|
||
}{
|
||
{
|
||
name: "正常请求 - 无限制",
|
||
mobile: "13800138000",
|
||
setupMocks: func() {
|
||
// 1分钟限制检查 - 不存在
|
||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||
// 1小时计数检查 - 不存在
|
||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
|
||
// 24小时计数检查 - 不存在
|
||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil)
|
||
},
|
||
expectedError: false,
|
||
},
|
||
{
|
||
name: "1分钟内重复请求",
|
||
mobile: "13800138000",
|
||
setupMocks: func() {
|
||
// 1分钟限制检查 - 存在
|
||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil)
|
||
},
|
||
expectedError: true,
|
||
},
|
||
{
|
||
name: "1小时内超过限制",
|
||
mobile: "13800138000",
|
||
setupMocks: func() {
|
||
// 1分钟限制检查 - 不存在
|
||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||
// 1小时计数检查 - 超过限制
|
||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil)
|
||
},
|
||
expectedError: true,
|
||
},
|
||
{
|
||
name: "24小时内超过限制",
|
||
mobile: "13800138000",
|
||
setupMocks: func() {
|
||
// 1分钟限制检查 - 不存在
|
||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||
// 1小时计数检查 - 正常
|
||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
|
||
// 24小时计数检查 - 超过限制
|
||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil)
|
||
},
|
||
expectedError: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
// 重置mock
|
||
mockRedis.ExpectedCalls = nil
|
||
mockRedis.Calls = nil
|
||
|
||
// 设置mock期望
|
||
tt.setupMocks()
|
||
|
||
// 执行测试
|
||
err := logic.checkMobileRateLimit(tt.mobile, "login")
|
||
|
||
// 验证结果
|
||
if tt.expectedError {
|
||
assert.Error(t, err)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// 验证mock调用
|
||
mockRedis.AssertExpectations(t)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestCheckIPRateLimit(t *testing.T) {
|
||
logic, mockRedis := createTestLogic()
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||
|
||
tests := []struct {
|
||
name string
|
||
setupMocks func()
|
||
expectedError bool
|
||
}{
|
||
{
|
||
name: "正常请求 - 无限制",
|
||
setupMocks: func() {
|
||
// IP封禁检查 - 不存在
|
||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||
// 1分钟计数检查 - 不存在
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||
// 1小时计数检查 - 不存在
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
|
||
},
|
||
expectedError: false,
|
||
},
|
||
{
|
||
name: "IP被封禁且未过期",
|
||
setupMocks: func() {
|
||
// IP封禁检查 - 存在
|
||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
|
||
// 获取剩余封禁时间
|
||
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil)
|
||
},
|
||
expectedError: true,
|
||
},
|
||
{
|
||
name: "IP封禁已过期",
|
||
setupMocks: func() {
|
||
// IP封禁检查 - 存在
|
||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
|
||
// 获取剩余封禁时间 - 已过期
|
||
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil)
|
||
// 清除过期封禁状态
|
||
mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil)
|
||
// 1分钟计数检查 - 不存在
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||
// 1小时计数检查 - 不存在
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
|
||
},
|
||
expectedError: false,
|
||
},
|
||
{
|
||
name: "1分钟内超过限制 - 触发短期封禁",
|
||
setupMocks: func() {
|
||
// IP封禁检查 - 不存在
|
||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||
// 1分钟计数检查 - 超过限制
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil)
|
||
// 设置短期封禁
|
||
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil)
|
||
},
|
||
expectedError: true,
|
||
},
|
||
{
|
||
name: "1小时内超过限制 - 触发长期封禁",
|
||
setupMocks: func() {
|
||
// IP封禁检查 - 不存在
|
||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||
// 1分钟计数检查 - 正常
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||
// 1小时计数检查 - 超过限制
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil)
|
||
// 设置长期封禁
|
||
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil)
|
||
},
|
||
expectedError: true,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
// 重置mock
|
||
mockRedis.ExpectedCalls = nil
|
||
mockRedis.Calls = nil
|
||
|
||
// 设置mock期望
|
||
tt.setupMocks()
|
||
|
||
// 执行测试
|
||
err := logic.checkIPRateLimit()
|
||
|
||
// 验证结果
|
||
if tt.expectedError {
|
||
assert.Error(t, err)
|
||
} else {
|
||
assert.NoError(t, err)
|
||
}
|
||
|
||
// 验证mock调用
|
||
mockRedis.AssertExpectations(t)
|
||
})
|
||
}
|
||
}
|
||
|
||
func TestRecordCaptchaRequest(t *testing.T) {
|
||
logic, mockRedis := createTestLogic()
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||
|
||
// 设置mock期望
|
||
mockRedis.On("Setex", "security:captcha:mobile:13800138000:minute", "1", 60).Return(nil)
|
||
mockRedis.On("Incr", "security:captcha:mobile:13800138000:hour").Return(int64(1), nil)
|
||
mockRedis.On("Expire", "security:captcha:mobile:13800138000:hour", 3600).Return(nil)
|
||
mockRedis.On("Incr", "security:captcha:mobile:13800138000:day").Return(int64(1), nil)
|
||
mockRedis.On("Expire", "security:captcha:mobile:13800138000:day", 86400).Return(nil)
|
||
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:minute").Return(int64(1), nil)
|
||
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:minute", 60).Return(nil)
|
||
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:hour").Return(int64(1), nil)
|
||
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:hour", 3600).Return(nil)
|
||
|
||
// 执行测试
|
||
err := logic.recordCaptchaRequest("13800138000")
|
||
|
||
// 验证结果
|
||
assert.NoError(t, err)
|
||
mockRedis.AssertExpectations(t)
|
||
}
|
||
|
||
func TestGetCaptchaProtectionStatus(t *testing.T) {
|
||
logic, mockRedis := createTestLogic()
|
||
|
||
// 模拟IP获取
|
||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||
|
||
// 设置mock期望
|
||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:minute").Return(false, nil)
|
||
mockRedis.On("Get", "security:captcha:mobile:13800138000:hour").Return("3", nil)
|
||
mockRedis.On("Get", "security:captcha:mobile:13800138000:day").Return("15", nil)
|
||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("5", nil)
|
||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("25", nil)
|
||
|
||
// 执行测试
|
||
status, err := logic.GetCaptchaProtectionStatus("13800138000")
|
||
|
||
// 验证结果
|
||
assert.NoError(t, err)
|
||
assert.NotNil(t, status)
|
||
|
||
// 验证手机号状态
|
||
assert.Equal(t, false, status["mobileMinuteLimited"])
|
||
assert.Equal(t, int64(3), status["mobileHourCount"])
|
||
assert.Equal(t, int64(2), status["mobileHourRemaining"])
|
||
assert.Equal(t, int64(15), status["mobileDayCount"])
|
||
assert.Equal(t, int64(5), status["mobileDayRemaining"])
|
||
|
||
// 验证IP状态
|
||
assert.Equal(t, false, status["ipBanned"])
|
||
assert.Equal(t, int64(5), status["ipMinuteCount"])
|
||
assert.Equal(t, int64(5), status["ipMinuteRemaining"])
|
||
assert.Equal(t, int64(25), status["ipHourCount"])
|
||
assert.Equal(t, int64(25), status["ipHourRemaining"])
|
||
|
||
mockRedis.AssertExpectations(t)
|
||
}
|
||
|
||
func TestGetClientIP(t *testing.T) {
|
||
logic, _ := createTestLogic()
|
||
|
||
tests := []struct {
|
||
name string
|
||
ctx context.Context
|
||
expected string
|
||
}{
|
||
{
|
||
name: "从上下文获取IP",
|
||
ctx: context.WithValue(context.Background(), "client_ip", "192.168.1.100"),
|
||
expected: "192.168.1.100",
|
||
},
|
||
{
|
||
name: "上下文无IP - 返回默认IP",
|
||
ctx: context.Background(),
|
||
expected: "127.0.0.1",
|
||
},
|
||
{
|
||
name: "上下文IP为空 - 返回默认IP",
|
||
ctx: context.WithValue(context.Background(), "client_ip", ""),
|
||
expected: "127.0.0.1",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
logic.ctx = tt.ctx
|
||
ip := logic.getClientIP()
|
||
assert.Equal(t, tt.expected, ip)
|
||
})
|
||
}
|
||
}
|