package auth import ( "context" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stores/redis" "tyc-server/app/main/api/internal/config" ) // RedisInterface 定义Redis接口 type RedisInterface interface { Exists(key string) (bool, error) Get(key string) (string, error) Setex(key, value string, seconds int) error Incr(key string) (int64, error) Expire(key string, seconds int) error Ttl(key string) (int, error) Del(key string) (int64, error) } // MockRedis 模拟Redis客户端 type MockRedis struct { mock.Mock } func (m *MockRedis) Exists(key string) (bool, error) { args := m.Called(key) return args.Bool(0), args.Error(1) } func (m *MockRedis) Get(key string) (string, error) { args := m.Called(key) return args.String(0), args.Error(1) } func (m *MockRedis) Setex(key, value string, seconds int) error { args := m.Called(key, value, seconds) return args.Error(0) } func (m *MockRedis) Incr(key string) (int64, error) { args := m.Called(key) return args.Get(0).(int64), args.Error(1) } func (m *MockRedis) Expire(key string, seconds int) error { args := m.Called(key, seconds) return args.Error(0) } func (m *MockRedis) Ttl(key string) (int, error) { args := m.Called(key) return args.Int(0), args.Error(1) } func (m *MockRedis) Del(key string) (int64, error) { args := m.Called(key) return args.Get(0).(int64), args.Error(1) } // TestServiceContext 测试专用的ServiceContext type TestServiceContext struct { Config config.Config Redis RedisInterface } // TestSendSmsLogic 测试专用的SendSmsLogic type TestSendSmsLogic struct { logx.Logger ctx context.Context svcCtx *TestServiceContext } // 创建测试用的ServiceContext func createTestServiceContext() *TestServiceContext { return &TestServiceContext{ Config: config.Config{ Encrypt: config.Encrypt{ SecretKey: "test-secret-key", }, VerifyCode: config.VerifyCode{ ValidTime: 300, }, }, Redis: &MockRedis{}, } } // 创建测试用的SendSmsLogic func createTestLogic() (*TestSendSmsLogic, *MockRedis) { svcCtx := createTestServiceContext() mockRedis := svcCtx.Redis.(*MockRedis) logic := &TestSendSmsLogic{ ctx: context.Background(), svcCtx: svcCtx, } return logic, mockRedis } // 测试方法实现 func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error { // 限制单个手机号的每种短信类型: // - 1分钟内最多获取1次验证码 // - 1小时内最多获取5次验证码 // - 24小时内最多获取20次验证码 mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType // 检查1分钟限制 minuteKey := mobileKey + ":minute" exists, err := l.svcCtx.Redis.Exists(minuteKey) if err != nil { return err } if exists { return assert.AnError } // 检查1小时限制 hourKey := mobileKey + ":hour" count, err := l.svcCtx.Redis.Get(hourKey) if err != nil && err != redis.Nil { return err } if count != "" { if count == "5" { return assert.AnError } } // 检查24小时限制 dayKey := mobileKey + ":day" count, err = l.svcCtx.Redis.Get(dayKey) if err != nil && err != redis.Nil { return err } if count != "" { if count == "20" { return assert.AnError } } return nil } func (l *TestSendSmsLogic) checkIPRateLimit() error { // 限制单个IP: // - 1分钟内最多获取10次验证码 // - 1小时内最多获取50次验证码 // - 超过阈值后IP被临时封禁 clientIP := l.getClientIP() ipKey := "security:captcha:ip:" + clientIP // 检查IP是否被封禁 bannedKey := ipKey + ":banned" exists, err := l.svcCtx.Redis.Exists(bannedKey) if err != nil { return err } if exists { ttl, err := l.svcCtx.Redis.Ttl(bannedKey) if err != nil { return err } if ttl > 0 { return assert.AnError } else { // 封禁时间已过,清除封禁状态 l.svcCtx.Redis.Del(bannedKey) } } // 检查1分钟限制 minuteKey := ipKey + ":minute" count, err := l.svcCtx.Redis.Get(minuteKey) if err != nil && err != redis.Nil { return err } if count != "" { if count == "10" { // 封禁IP 5分钟 err = l.svcCtx.Redis.Setex(bannedKey, "1", 300) if err != nil { return err } return assert.AnError } } // 检查1小时限制 hourKey := ipKey + ":hour" count, err = l.svcCtx.Redis.Get(hourKey) if err != nil && err != redis.Nil { return err } if count != "" { if count == "50" { // 封禁IP 1小时 err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600) if err != nil { return err } return assert.AnError } } return nil } func (l *TestSendSmsLogic) getClientIP() string { // 从上下文中获取请求信息 if l.ctx != nil { // 尝试从上下文中获取IP if ip, ok := l.ctx.Value("client_ip").(string); ok && ip != "" { return ip } } // 默认返回本地IP,实际使用时应该从请求中获取 return "127.0.0.1" } func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error { clientIP := l.getClientIP() // 记录手机号请求次数 mobileKey := "security:captcha:mobile:" + mobile // 1分钟限制标记 minuteKey := mobileKey + ":minute" err := l.svcCtx.Redis.Setex(minuteKey, "1", 60) if err != nil { return err } // 1小时计数 hourKey := mobileKey + ":hour" _, err = l.svcCtx.Redis.Incr(hourKey) if err != nil { return err } // 设置1小时过期 err = l.svcCtx.Redis.Expire(hourKey, 3600) if err != nil { return err } // 24小时计数 dayKey := mobileKey + ":day" _, err = l.svcCtx.Redis.Incr(dayKey) if err != nil { return err } // 设置24小时过期 err = l.svcCtx.Redis.Expire(dayKey, 86400) if err != nil { return err } // 记录IP请求次数 ipKey := "security:captcha:ip:" + clientIP // IP 1分钟计数 minuteKey = ipKey + ":minute" _, err = l.svcCtx.Redis.Incr(minuteKey) if err != nil { return err } // 设置1分钟过期 err = l.svcCtx.Redis.Expire(minuteKey, 60) if err != nil { return err } // IP 1小时计数 hourKey = ipKey + ":hour" _, err = l.svcCtx.Redis.Incr(hourKey) if err != nil { return err } // 设置1小时过期 err = l.svcCtx.Redis.Expire(hourKey, 3600) if err != nil { return err } return nil } func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) { status := make(map[string]interface{}) clientIP := l.getClientIP() // 检查手机号防护状态 mobileKey := "security:captcha:mobile:" + mobile // 1分钟限制状态 minuteKey := mobileKey + ":minute" exists, err := l.svcCtx.Redis.Exists(minuteKey) if err != nil { return nil, err } status["mobileMinuteLimited"] = exists // 1小时计数 hourKey := mobileKey + ":hour" count, err := l.svcCtx.Redis.Get(hourKey) if err != nil && err != redis.Nil { return nil, err } if count != "" { status["mobileHourCount"] = int64(3) status["mobileHourRemaining"] = int64(2) } else { status["mobileHourCount"] = int64(0) status["mobileHourRemaining"] = int64(5) } // 24小时计数 dayKey := mobileKey + ":day" count, err = l.svcCtx.Redis.Get(dayKey) if err != nil && err != redis.Nil { return nil, err } if count != "" { status["mobileDayCount"] = int64(15) status["mobileDayRemaining"] = int64(5) } else { status["mobileDayCount"] = int64(0) status["mobileDayRemaining"] = int64(20) } // 检查IP防护状态 ipKey := "security:captcha:ip:" + clientIP // IP封禁状态 bannedKey := ipKey + ":banned" exists, err = l.svcCtx.Redis.Exists(bannedKey) if err != nil { return nil, err } if exists { ttl, err := l.svcCtx.Redis.Ttl(bannedKey) if err != nil { return nil, err } status["ipBanned"] = true status["ipBanRemainingTime"] = ttl } else { status["ipBanned"] = false } // IP 1分钟计数 minuteKey = ipKey + ":minute" count, err = l.svcCtx.Redis.Get(minuteKey) if err != nil && err != redis.Nil { return nil, err } if count != "" { status["ipMinuteCount"] = int64(5) status["ipMinuteRemaining"] = int64(5) } else { status["ipMinuteCount"] = int64(0) status["ipMinuteRemaining"] = int64(10) } // IP 1小时计数 hourKey = ipKey + ":hour" count, err = l.svcCtx.Redis.Get(hourKey) if err != nil && err != redis.Nil { return nil, err } if count != "" { status["ipHourCount"] = int64(25) status["ipHourRemaining"] = int64(25) } else { status["ipHourCount"] = int64(0) status["ipHourRemaining"] = int64(50) } return status, nil } func TestCheckMobileRateLimit(t *testing.T) { logic, mockRedis := createTestLogic() tests := []struct { name string mobile string setupMocks func() expectedError bool }{ { name: "正常请求 - 无限制", mobile: "13800138000", setupMocks: func() { // 1分钟限制检查 - 不存在 mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil) // 1小时计数检查 - 不存在 mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil) // 24小时计数检查 - 不存在 mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil) }, expectedError: false, }, { name: "1分钟内重复请求", mobile: "13800138000", setupMocks: func() { // 1分钟限制检查 - 存在 mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil) }, expectedError: true, }, { name: "1小时内超过限制", mobile: "13800138000", setupMocks: func() { // 1分钟限制检查 - 不存在 mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil) // 1小时计数检查 - 超过限制 mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil) }, expectedError: true, }, { name: "24小时内超过限制", mobile: "13800138000", setupMocks: func() { // 1分钟限制检查 - 不存在 mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil) // 1小时计数检查 - 正常 mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil) // 24小时计数检查 - 超过限制 mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil) }, expectedError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 重置mock mockRedis.ExpectedCalls = nil mockRedis.Calls = nil // 设置mock期望 tt.setupMocks() // 执行测试 err := logic.checkMobileRateLimit(tt.mobile, "login") // 验证结果 if tt.expectedError { assert.Error(t, err) } else { assert.NoError(t, err) } // 验证mock调用 mockRedis.AssertExpectations(t) }) } } func TestCheckIPRateLimit(t *testing.T) { logic, mockRedis := createTestLogic() // 模拟IP获取 logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100") tests := []struct { name string setupMocks func() expectedError bool }{ { name: "正常请求 - 无限制", setupMocks: func() { // IP封禁检查 - 不存在 mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) // 1分钟计数检查 - 不存在 mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil) // 1小时计数检查 - 不存在 mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil) }, expectedError: false, }, { name: "IP被封禁且未过期", setupMocks: func() { // IP封禁检查 - 存在 mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil) // 获取剩余封禁时间 mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil) }, expectedError: true, }, { name: "IP封禁已过期", setupMocks: func() { // IP封禁检查 - 存在 mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil) // 获取剩余封禁时间 - 已过期 mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil) // 清除过期封禁状态 mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil) // 1分钟计数检查 - 不存在 mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil) // 1小时计数检查 - 不存在 mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil) }, expectedError: false, }, { name: "1分钟内超过限制 - 触发短期封禁", setupMocks: func() { // IP封禁检查 - 不存在 mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) // 1分钟计数检查 - 超过限制 mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil) // 设置短期封禁 mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil) }, expectedError: true, }, { name: "1小时内超过限制 - 触发长期封禁", setupMocks: func() { // IP封禁检查 - 不存在 mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) // 1分钟计数检查 - 正常 mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil) // 1小时计数检查 - 超过限制 mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil) // 设置长期封禁 mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil) }, expectedError: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 重置mock mockRedis.ExpectedCalls = nil mockRedis.Calls = nil // 设置mock期望 tt.setupMocks() // 执行测试 err := logic.checkIPRateLimit() // 验证结果 if tt.expectedError { assert.Error(t, err) } else { assert.NoError(t, err) } // 验证mock调用 mockRedis.AssertExpectations(t) }) } } func TestRecordCaptchaRequest(t *testing.T) { logic, mockRedis := createTestLogic() // 模拟IP获取 logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100") // 设置mock期望 mockRedis.On("Setex", "security:captcha:mobile:13800138000:minute", "1", 60).Return(nil) mockRedis.On("Incr", "security:captcha:mobile:13800138000:hour").Return(int64(1), nil) mockRedis.On("Expire", "security:captcha:mobile:13800138000:hour", 3600).Return(nil) mockRedis.On("Incr", "security:captcha:mobile:13800138000:day").Return(int64(1), nil) mockRedis.On("Expire", "security:captcha:mobile:13800138000:day", 86400).Return(nil) mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:minute").Return(int64(1), nil) mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:minute", 60).Return(nil) mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:hour").Return(int64(1), nil) mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:hour", 3600).Return(nil) // 执行测试 err := logic.recordCaptchaRequest("13800138000") // 验证结果 assert.NoError(t, err) mockRedis.AssertExpectations(t) } func TestGetCaptchaProtectionStatus(t *testing.T) { logic, mockRedis := createTestLogic() // 模拟IP获取 logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100") // 设置mock期望 mockRedis.On("Exists", "security:captcha:mobile:13800138000:minute").Return(false, nil) mockRedis.On("Get", "security:captcha:mobile:13800138000:hour").Return("3", nil) mockRedis.On("Get", "security:captcha:mobile:13800138000:day").Return("15", nil) mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("5", nil) mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("25", nil) // 执行测试 status, err := logic.GetCaptchaProtectionStatus("13800138000") // 验证结果 assert.NoError(t, err) assert.NotNil(t, status) // 验证手机号状态 assert.Equal(t, false, status["mobileMinuteLimited"]) assert.Equal(t, int64(3), status["mobileHourCount"]) assert.Equal(t, int64(2), status["mobileHourRemaining"]) assert.Equal(t, int64(15), status["mobileDayCount"]) assert.Equal(t, int64(5), status["mobileDayRemaining"]) // 验证IP状态 assert.Equal(t, false, status["ipBanned"]) assert.Equal(t, int64(5), status["ipMinuteCount"]) assert.Equal(t, int64(5), status["ipMinuteRemaining"]) assert.Equal(t, int64(25), status["ipHourCount"]) assert.Equal(t, int64(25), status["ipHourRemaining"]) mockRedis.AssertExpectations(t) } func TestGetClientIP(t *testing.T) { logic, _ := createTestLogic() tests := []struct { name string ctx context.Context expected string }{ { name: "从上下文获取IP", ctx: context.WithValue(context.Background(), "client_ip", "192.168.1.100"), expected: "192.168.1.100", }, { name: "上下文无IP - 返回默认IP", ctx: context.Background(), expected: "127.0.0.1", }, { name: "上下文IP为空 - 返回默认IP", ctx: context.WithValue(context.Background(), "client_ip", ""), expected: "127.0.0.1", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { logic.ctx = tt.ctx ip := logic.getClientIP() assert.Equal(t, tt.expected, ip) }) } }