This commit is contained in:
2025-08-31 14:18:31 +08:00
parent 30ace3faa2
commit 4be4d6b6da
19 changed files with 3472 additions and 7 deletions

View File

@@ -92,3 +92,26 @@ Tianyuanapi:
Timeout: 60
VerifyConfig:
TwoFactor: true
Security:
RateLimit:
Enabled: true
WindowSize: 30
MaxRequests: 50
TriggerThreshold: 3 # 触发5次频率限制后加入黑名单
TriggerWindow: 24 # 24小时内统计触发次数
IPBlacklist:
Enabled: true
UserBlacklist:
Enabled: true
AnomalyDetection:
Enabled: true
BurstAttack:
Enabled: true # 启用短时并发攻击检测
TimeWindow: 1 # 1秒内检测
MaxConcurrent: 15 # 最大20个并发请求
Logging:
UserOperationLogDir: "./logs/user_operations"
MaxFileSize: 104857600 # 100MB
LogLevel: "info"
EnableConsole: true
EnableFile: true

View File

@@ -79,3 +79,32 @@ Tianyuanapi:
Timeout: 60
VerifyConfig:
TwoFactor: true
Security:
RateLimit:
Enabled: true
WindowSize: 30
MaxRequests: 50
TriggerThreshold: 3 # 触发5次频率限制后加入黑名单
TriggerWindow: 24 # 24小时内统计触发次数
IPBlacklist:
Enabled: true
UserBlacklist:
Enabled: true
AnomalyDetection:
Enabled: true
BurstAttack:
Enabled: true # 启用短时并发攻击检测
TimeWindow: 1 # 1秒内检测
MaxConcurrent: 15 # 最大20个并发请求
Logging:
UserOperationLogDir: "./logs/user_operations"
MaxFileSize: 104857600 # 100MB
LogLevel: "info"
EnableConsole: true
EnableFile: true
Logging:
UserOperationLogDir: "./logs/user_operations"
MaxFileSize: 104857600 # 10MB
LogLevel: "info"
EnableConsole: true
EnableFile: true

View File

@@ -24,6 +24,8 @@ type Config struct {
CleanTask CleanTask
Tianyuanapi TianyuanapiConfig
VerifyConfig VerifyConfig
Security SecurityConfig // 安全配置
Logging LoggingConfig // 日志配置
}
// JwtAuth 用于 JWT 鉴权配置

View File

@@ -0,0 +1,10 @@
package config
// LoggingConfig 日志配置
type LoggingConfig struct {
UserOperationLogDir string `json:"userOperationLogDir" yaml:"userOperationLogDir"` // 用户操作日志目录
MaxFileSize int64 `json:"maxFileSize" yaml:"maxFileSize"` // 单个日志文件最大大小(字节)
LogLevel string `json:"logLevel" yaml:"logLevel"` // 日志级别
EnableConsole bool `json:"enableConsole" yaml:"enableConsole"` // 是否启用控制台输出
EnableFile bool `json:"enableFile" yaml:"enableFile"` // 是否启用文件输出
}

View File

@@ -0,0 +1,36 @@
package config
// SecurityConfig 安全配置
type SecurityConfig struct {
// 频率限制配置
RateLimit struct {
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用频率限制
WindowSize int64 `json:"windowSize" yaml:"windowSize"` // 时间窗口大小(秒)
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"` // 最大请求次数
// 频率限制触发后的黑名单升级配置
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"` // 触发多少次频率限制后加入黑名单
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"` // 触发次数统计时间窗口(小时)
} `json:"rateLimit" yaml:"rateLimit"`
// IP黑名单配置
IPBlacklist struct {
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用IP黑名单
} `json:"ipBlacklist" yaml:"ipBlacklist"`
// 用户黑名单配置
UserBlacklist struct {
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用用户黑名单
} `json:"userBlacklist" yaml:"userBlacklist"`
// 异常检测配置
AnomalyDetection struct {
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用异常检测
} `json:"anomalyDetection" yaml:"anomalyDetection"`
// 短时并发攻击检测配置
BurstAttack struct {
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用短时并发攻击检测
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"` // 检测时间窗口(秒)
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"` // 最大并发请求数
} `json:"burstAttack" yaml:"burstAttack"`
}

View File

@@ -1,9 +1,20 @@
package auth
// 验证码防护说明:
// 1. checkCaptchaProtection: 在发送短信前检查防护状态
// 2. recordCaptchaRequest: 在短信发送成功后记录请求次数
// 3. GetCaptchaProtectionStatus: 获取防护状态用于调试和监控
//
// 防护规则:
// - 单个手机号1分钟内最多1次1小时内最多5次24小时内最多20次
// - 单个IP1分钟内最多10次1小时内最多50次超过阈值后IP被临时封禁
// - 防止验证码爆破攻击,控制短信发送成本
import (
"context"
"fmt"
"math/rand"
"strconv"
"time"
"tyc-server/common/xerr"
"tyc-server/pkg/lzkit/crypto"
@@ -18,6 +29,7 @@ import (
"github.com/alibabacloud-go/tea-utils/v2/service"
"github.com/alibabacloud-go/tea/tea"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
)
type SendSmsLogic struct {
@@ -40,6 +52,12 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
if err != nil {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %+v", err)
}
// 验证码防护检查
if err := l.checkCaptchaProtection(req.Mobile, req.ActionType); err != nil {
return err
}
// 检查手机号是否在一分钟内已发送过验证码
limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile)
exists, err := l.svcCtx.Redis.Exists(limitCodeKey)
@@ -62,6 +80,13 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
if *smsResp.Body.Code != "OK" {
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message)
}
// 短信发送成功,记录请求次数
if err := l.recordCaptchaRequest(req.Mobile, req.ActionType); err != nil {
logx.Errorf("记录验证码请求失败: %v", err)
// 不影响主流程,只记录日志
}
codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile)
// 将验证码保存到 Redis设置过期时间
err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟
@@ -103,3 +128,324 @@ func (l *SendSmsLogic) sendSmsRequest(mobile, code string) (*dysmsapi.SendSmsRes
runtime := &service.RuntimeOptions{}
return cli.SendSmsWithOptions(request, runtime)
}
// checkCaptchaProtection 检查验证码获取防护
func (l *SendSmsLogic) checkCaptchaProtection(mobile string, actionType string) error {
// 1. 检查手机号获取验证码频率
if err := l.checkMobileRateLimit(mobile, actionType); err != nil {
return err
}
// 2. 检查IP获取验证码频率
if err := l.checkIPRateLimit(); err != nil {
return err
}
return nil
}
// checkMobileRateLimit 检查手机号频率限制
func (l *SendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
// 限制单个手机号的每种短信类型:
// - 1分钟内最多获取1次验证码
// - 1小时内最多获取5次验证码
// - 24小时内最多获取20次验证码
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 检查1分钟限制
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
logx.Errorf("检查手机号1分钟限制失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if exists {
return errors.Wrapf(xerr.NewErrMsg("1分钟内已获取过验证码请稍后再试"), "验证码防护 - 手机号1分钟内重复请求: %s", mobile)
}
// 检查1小时限制
hourKey := fmt.Sprintf("%s:hour", mobileKey)
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取手机号1小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 5 {
return errors.Wrapf(xerr.NewErrMsg("1小时内获取验证码次数过多请稍后再试"), "验证码防护 - 手机号1小时内超过限制: %s", mobile)
}
}
// 检查24小时限制
dayKey := fmt.Sprintf("%s:day", mobileKey)
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取手机号24小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if dayCount, _ := strconv.ParseInt(count, 10, 64); dayCount >= 20 {
return errors.Wrapf(xerr.NewErrMsg("24小时内获取验证码次数过多请稍后再试"), "验证码防护 - 手机号24小时内超过限制: %s", mobile)
}
}
return nil
}
// checkIPRateLimit 检查IP频率限制
func (l *SendSmsLogic) checkIPRateLimit() error {
// 限制单个IP
// - 1分钟内最多获取10次验证码
// - 1小时内最多获取50次验证码
// - 超过阈值后IP被临时封禁
clientIP := l.getClientIP()
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// 检查IP是否被封禁
bannedKey := fmt.Sprintf("%s:banned", ipKey)
exists, err := l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
logx.Errorf("检查IP封禁状态失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err != nil {
logx.Errorf("获取IP封禁剩余时间失败: %v", err)
}
if ttl > 0 {
return errors.Wrapf(xerr.NewErrMsg(fmt.Sprintf("IP被临时封禁请%d秒后再试", ttl)), "验证码防护 - IP被临时封禁: %s", clientIP)
} else {
// 封禁时间已过,清除封禁状态
l.svcCtx.Redis.Del(bannedKey)
}
}
// 检查1分钟限制
minuteKey := fmt.Sprintf("%s:minute", ipKey)
count, err := l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取IP1分钟计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if minuteCount, _ := strconv.ParseInt(count, 10, 64); minuteCount >= 10 {
// 封禁IP 5分钟
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
if err != nil {
logx.Errorf("封禁IP失败: %v", err)
}
logx.Errorf("验证码防护 - IP被临时封禁: %s, 封禁时间: 300秒", clientIP)
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁已被临时封禁5分钟"), "验证码防护 - IP被临时封禁: %s", clientIP)
}
}
// 检查1小时限制
hourKey := fmt.Sprintf("%s:hour", ipKey)
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取IP1小时计数失败: %v", err)
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
}
if count != "" {
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 50 {
// 封禁IP 1小时
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
if err != nil {
logx.Errorf("封禁IP失败: %v", err)
}
logx.Errorf("验证码防护 - IP被长期封禁: %s, 封禁时间: 3600秒", clientIP)
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁已被临时封禁1小时"), "验证码防护 - IP被长期封禁: %s", clientIP)
}
}
return nil
}
// getClientIP 获取客户端真实IP
func (l *SendSmsLogic) getClientIP() string {
if l.ctx != nil {
// 尝试从上下文中获取IP
if ip, ok := l.ctx.Value("client_ip").(string); ok {
return ip
}
}
// 默认返回本地IP实际使用时应该从请求中获取
return "127.0.0.1"
}
// recordCaptchaRequest 记录验证码请求次数
func (l *SendSmsLogic) recordCaptchaRequest(mobile string, actionType string) error {
clientIP := l.getClientIP()
// 记录手机号请求次数
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 1分钟限制标记
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
if err != nil {
logx.Errorf("设置手机号1分钟限制标记失败: %v", err)
}
// 1小时计数
hourKey := fmt.Sprintf("%s:hour", mobileKey)
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
logx.Errorf("增加手机号1小时计数失败: %v", err)
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
logx.Errorf("设置手机号1小时计数过期时间失败: %v", err)
}
// 24小时计数
dayKey := fmt.Sprintf("%s:day", mobileKey)
_, err = l.svcCtx.Redis.Incr(dayKey)
if err != nil {
logx.Errorf("增加手机号24小时计数失败: %v", err)
}
// 设置24小时过期
err = l.svcCtx.Redis.Expire(dayKey, 86400)
if err != nil {
logx.Errorf("设置手机号24小时计数过期时间失败: %v", err)
}
// 记录IP请求次数
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// IP 1分钟计数
minuteKey = fmt.Sprintf("%s:minute", ipKey)
_, err = l.svcCtx.Redis.Incr(minuteKey)
if err != nil {
logx.Errorf("增加IP1分钟计数失败: %v", err)
}
// 设置1分钟过期
err = l.svcCtx.Redis.Expire(minuteKey, 60)
if err != nil {
logx.Errorf("设置IP1分钟计数过期时间失败: %v", err)
}
// IP 1小时计数
hourKey = fmt.Sprintf("%s:hour", ipKey)
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
logx.Errorf("增加IP1小时计数失败: %v", err)
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
logx.Errorf("设置IP1小时计数过期时间失败: %v", err)
}
return nil
}
// GetCaptchaProtectionStatus 获取验证码防护状态(用于调试和监控)
func (l *SendSmsLogic) GetCaptchaProtectionStatus(mobile string, actionType string) (map[string]interface{}, error) {
status := make(map[string]interface{})
clientIP := l.getClientIP()
// 检查手机号防护状态
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
// 1分钟限制状态
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
return nil, err
}
status["mobileMinuteLimited"] = exists
// 1小时计数
hourKey := fmt.Sprintf("%s:hour", mobileKey)
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["mobileHourCount"] = hourCount
status["mobileHourRemaining"] = 5 - hourCount
}
} else {
status["mobileHourCount"] = 0
status["mobileHourRemaining"] = 5
}
// 24小时计数
dayKey := fmt.Sprintf("%s:day", mobileKey)
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if dayCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["mobileDayCount"] = dayCount
status["mobileDayRemaining"] = 20 - dayCount
}
} else {
status["mobileDayCount"] = 0
status["mobileDayRemaining"] = 20
}
// 检查IP防护状态
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
// IP封禁状态
bannedKey := fmt.Sprintf("%s:banned", ipKey)
exists, err = l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
return nil, err
}
status["ipBanned"] = exists
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err == nil {
status["ipBanRemaining"] = ttl
}
}
// IP 1分钟计数
minuteKey = fmt.Sprintf("%s:minute", ipKey)
count, err = l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if minuteCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["ipMinuteCount"] = minuteCount
status["ipMinuteRemaining"] = 10 - minuteCount
}
} else {
status["ipMinuteCount"] = 0
status["ipMinuteRemaining"] = 10
}
// IP 1小时计数
hourKey = fmt.Sprintf("%s:hour", ipKey)
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
status["ipHourCount"] = hourCount
status["ipHourRemaining"] = 50 - hourCount
}
} else {
status["ipHourCount"] = 0
status["ipHourRemaining"] = 50
}
// 添加基本信息
status["mobile"] = mobile
status["actionType"] = actionType
status["clientIP"] = clientIP
return status, nil
}

View File

@@ -0,0 +1,381 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stores/redis"
"tyc-server/app/main/api/internal/config"
"tyc-server/app/main/api/internal/svc"
)
// 集成测试配置
var integrationTestConfig = config.Config{
Encrypt: config.Encrypt{
SecretKey: "test-secret-key",
},
VerifyCode: config.VerifyCode{
ValidTime: 300,
},
}
// 创建集成测试用的ServiceContext
func createIntegrationTestServiceContext() *svc.ServiceContext {
redisConf := redis.RedisConf{
Host: "127.0.0.1:6379",
Type: "node",
Pass: "",
}
return &svc.ServiceContext{
Config: integrationTestConfig,
Redis: redis.MustNewRedis(redisConf),
}
}
// 清理测试数据
func cleanupTestData(logic *SendSmsLogic, mobile, clientIP string) {
// 清理手机号相关数据
mobileKey := "security:captcha:mobile:" + mobile + ":login"
logic.svcCtx.Redis.Del(mobileKey+":minute", mobileKey+":hour", mobileKey+":day")
// 清理IP相关数据
ipKey := "security:captcha:ip:" + clientIP
logic.svcCtx.Redis.Del(ipKey+":minute", ipKey+":hour", ipKey+":banned")
}
// 创建集成测试用的SendSmsLogic
func createIntegrationTestLogic() *SendSmsLogic {
svcCtx := createIntegrationTestServiceContext()
return &SendSmsLogic{
ctx: context.Background(),
svcCtx: svcCtx,
}
}
// 跳过集成测试的标志
var skipIntegrationTests = true
func TestIntegrationCheckMobileRateLimit(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
// 清理测试数据
defer cleanupTestData(logic, mobile, "192.168.1.100")
t.Run("正常请求 - 无限制", func(t *testing.T) {
err := logic.checkMobileRateLimit(mobile, "login")
assert.NoError(t, err)
})
t.Run("1分钟内重复请求", func(t *testing.T) {
// 设置1分钟限制标记
mobileKey := "security:captcha:mobile:" + mobile + ":login"
err := logic.svcCtx.Redis.Setex(mobileKey+":minute", "1", 60)
assert.NoError(t, err)
// 再次请求应该被拒绝
err = logic.checkMobileRateLimit(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
// 清理限制标记
logic.svcCtx.Redis.Del(mobileKey + ":minute")
})
t.Run("1小时内超过限制", func(t *testing.T) {
// 设置1小时计数为5达到限制
mobileKey := "security:captcha:mobile:" + mobile + ":login"
err := logic.svcCtx.Redis.Setex(mobileKey+":hour", "5", 3600)
assert.NoError(t, err)
// 请求应该被拒绝
err = logic.checkMobileRateLimit(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "1小时内获取验证码次数过多")
// 清理计数
logic.svcCtx.Redis.Del(mobileKey + ":hour")
})
t.Run("24小时内超过限制", func(t *testing.T) {
// 设置24小时计数为20达到限制
mobileKey := "security:captcha:mobile:" + mobile + ":login"
err := logic.svcCtx.Redis.Setex(mobileKey+":day", "20", 86400)
assert.NoError(t, err)
// 请求应该被拒绝
err = logic.checkMobileRateLimit(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "24小时内获取验证码次数过多")
// 清理计数
logic.svcCtx.Redis.Del(mobileKey + ":day")
})
}
func TestIntegrationCheckIPRateLimit(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("正常请求 - 无限制", func(t *testing.T) {
err := logic.checkIPRateLimit()
assert.NoError(t, err)
})
t.Run("IP被封禁且未过期", func(t *testing.T) {
// 设置IP封禁状态
ipKey := "security:captcha:ip:" + clientIP
err := logic.svcCtx.Redis.Setex(ipKey+":banned", "1", 300)
assert.NoError(t, err)
// 请求应该被拒绝
err = logic.checkIPRateLimit()
assert.Error(t, err)
assert.Contains(t, err.Error(), "IP被临时封禁")
// 清理封禁状态
logic.svcCtx.Redis.Del(ipKey + ":banned")
})
t.Run("1分钟内超过限制 - 触发短期封禁", func(t *testing.T) {
// 设置1分钟计数为10达到限制
ipKey := "security:captcha:ip:" + clientIP
err := logic.svcCtx.Redis.Setex(ipKey+":minute", "10", 60)
assert.NoError(t, err)
// 请求应该被拒绝并触发封禁
err = logic.checkIPRateLimit()
assert.Error(t, err)
assert.Contains(t, err.Error(), "IP请求过于频繁已被临时封禁5分钟")
// 验证IP被封禁
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
assert.NoError(t, err)
assert.True(t, exists)
// 清理数据
logic.svcCtx.Redis.Del(ipKey + ":minute")
logic.svcCtx.Redis.Del(ipKey + ":banned")
})
t.Run("1小时内超过限制 - 触发长期封禁", func(t *testing.T) {
// 设置1小时计数为50达到限制
ipKey := "security:captcha:ip:" + clientIP
err := logic.svcCtx.Redis.Setex(ipKey+":hour", "50", 3600)
assert.NoError(t, err)
// 请求应该被拒绝并触发封禁
err = logic.checkIPRateLimit()
assert.Error(t, err)
assert.Contains(t, err.Error(), "IP请求过于频繁已被临时封禁1小时")
// 验证IP被封禁
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
assert.NoError(t, err)
assert.True(t, exists)
// 清理数据
logic.svcCtx.Redis.Del(ipKey + ":hour")
logic.svcCtx.Redis.Del(ipKey + ":banned")
})
}
func TestIntegrationRecordCaptchaRequest(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("记录手机号请求次数", func(t *testing.T) {
err := logic.recordCaptchaRequest(mobile, "login")
assert.NoError(t, err)
// 验证1分钟限制标记
mobileKey := "security:captcha:mobile:" + mobile + ":login"
exists, err := logic.svcCtx.Redis.Exists(mobileKey + ":minute")
assert.NoError(t, err)
assert.True(t, exists)
// 验证1小时计数
count, err := logic.svcCtx.Redis.Get(mobileKey + ":hour")
assert.NoError(t, err)
assert.Equal(t, "1", count)
// 验证24小时计数
count, err = logic.svcCtx.Redis.Get(mobileKey + ":day")
assert.NoError(t, err)
assert.Equal(t, "1", count)
})
t.Run("记录IP请求次数", func(t *testing.T) {
err := logic.recordCaptchaRequest(mobile, "register")
assert.NoError(t, err)
// 验证IP 1分钟计数
ipKey := "security:captcha:ip:" + clientIP
count, err := logic.svcCtx.Redis.Get(ipKey + ":minute")
assert.NoError(t, err)
assert.Equal(t, "2", count) // 第二次调用
// 验证IP 1小时计数
count, err = logic.svcCtx.Redis.Get(ipKey + ":hour")
assert.NoError(t, err)
assert.Equal(t, "2", count) // 第二次调用
})
}
func TestIntegrationGetCaptchaProtectionStatus(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("获取防护状态", func(t *testing.T) {
status, err := logic.GetCaptchaProtectionStatus(mobile, "login")
assert.NoError(t, err)
assert.NotNil(t, status)
// 验证基本信息
assert.Equal(t, mobile, status["mobile"])
assert.Equal(t, "login", status["actionType"])
assert.Equal(t, clientIP, status["clientIP"])
// 验证手机号状态
assert.False(t, status["mobileMinuteLimited"].(bool))
assert.Equal(t, int64(0), status["mobileHourCount"])
assert.Equal(t, int64(5), status["mobileHourRemaining"])
assert.Equal(t, int64(0), status["mobileDayCount"])
assert.Equal(t, int64(20), status["mobileDayRemaining"])
// 验证IP状态
assert.False(t, status["ipBanned"].(bool))
assert.Equal(t, int64(0), status["ipMinuteCount"])
assert.Equal(t, int64(10), status["ipMinuteRemaining"])
assert.Equal(t, int64(0), status["ipHourCount"])
assert.Equal(t, int64(50), status["ipHourRemaining"])
})
}
func TestIntegrationEndToEnd(t *testing.T) {
if skipIntegrationTests {
t.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
t.Run("完整流程测试", func(t *testing.T) {
// 第一次请求 - 应该成功
err := logic.checkCaptchaProtection(mobile, "login")
assert.NoError(t, err)
// 记录请求
err = logic.recordCaptchaRequest(mobile, "login")
assert.NoError(t, err)
// 第二次请求 - 应该被拒绝1分钟限制
err = logic.checkCaptchaProtection(mobile, "login")
assert.Error(t, err)
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
// 不同短信类型 - 应该成功
err = logic.checkCaptchaProtection(mobile, "register")
assert.NoError(t, err)
})
}
// 性能基准测试
func BenchmarkCheckCaptchaProtection(b *testing.B) {
if skipIntegrationTests {
b.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// 使用不同的手机号避免限制
testMobile := mobile + "_" + string(rune(i%10))
logic.checkCaptchaProtection(testMobile, "login")
}
}
func BenchmarkRecordCaptchaRequest(b *testing.B) {
if skipIntegrationTests {
b.Skip("跳过集成测试需要Redis服务")
}
logic := createIntegrationTestLogic()
mobile := "13800138000"
clientIP := "192.168.1.100"
// 清理测试数据
defer cleanupTestData(logic, mobile, clientIP)
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
b.ResetTimer()
for i := 0; i < b.N; i++ {
// 使用不同的手机号避免限制
testMobile := mobile + "_" + string(rune(i%10))
logic.recordCaptchaRequest(testMobile, "login")
}
}

View File

@@ -0,0 +1,671 @@
package auth
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
"tyc-server/app/main/api/internal/config"
)
// RedisInterface 定义Redis接口
type RedisInterface interface {
Exists(key string) (bool, error)
Get(key string) (string, error)
Setex(key, value string, seconds int) error
Incr(key string) (int64, error)
Expire(key string, seconds int) error
Ttl(key string) (int, error)
Del(key string) (int64, error)
}
// MockRedis 模拟Redis客户端
type MockRedis struct {
mock.Mock
}
func (m *MockRedis) Exists(key string) (bool, error) {
args := m.Called(key)
return args.Bool(0), args.Error(1)
}
func (m *MockRedis) Get(key string) (string, error) {
args := m.Called(key)
return args.String(0), args.Error(1)
}
func (m *MockRedis) Setex(key, value string, seconds int) error {
args := m.Called(key, value, seconds)
return args.Error(0)
}
func (m *MockRedis) Incr(key string) (int64, error) {
args := m.Called(key)
return args.Get(0).(int64), args.Error(1)
}
func (m *MockRedis) Expire(key string, seconds int) error {
args := m.Called(key, seconds)
return args.Error(0)
}
func (m *MockRedis) Ttl(key string) (int, error) {
args := m.Called(key)
return args.Int(0), args.Error(1)
}
func (m *MockRedis) Del(key string) (int64, error) {
args := m.Called(key)
return args.Get(0).(int64), args.Error(1)
}
// TestServiceContext 测试专用的ServiceContext
type TestServiceContext struct {
Config config.Config
Redis RedisInterface
}
// TestSendSmsLogic 测试专用的SendSmsLogic
type TestSendSmsLogic struct {
logx.Logger
ctx context.Context
svcCtx *TestServiceContext
}
// 创建测试用的ServiceContext
func createTestServiceContext() *TestServiceContext {
return &TestServiceContext{
Config: config.Config{
Encrypt: config.Encrypt{
SecretKey: "test-secret-key",
},
VerifyCode: config.VerifyCode{
ValidTime: 300,
},
},
Redis: &MockRedis{},
}
}
// 创建测试用的SendSmsLogic
func createTestLogic() (*TestSendSmsLogic, *MockRedis) {
svcCtx := createTestServiceContext()
mockRedis := svcCtx.Redis.(*MockRedis)
logic := &TestSendSmsLogic{
ctx: context.Background(),
svcCtx: svcCtx,
}
return logic, mockRedis
}
// 测试方法实现
func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
// 限制单个手机号的每种短信类型:
// - 1分钟内最多获取1次验证码
// - 1小时内最多获取5次验证码
// - 24小时内最多获取20次验证码
mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType
// 检查1分钟限制
minuteKey := mobileKey + ":minute"
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
return err
}
if exists {
return assert.AnError
}
// 检查1小时限制
hourKey := mobileKey + ":hour"
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "5" {
return assert.AnError
}
}
// 检查24小时限制
dayKey := mobileKey + ":day"
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "20" {
return assert.AnError
}
}
return nil
}
func (l *TestSendSmsLogic) checkIPRateLimit() error {
// 限制单个IP
// - 1分钟内最多获取10次验证码
// - 1小时内最多获取50次验证码
// - 超过阈值后IP被临时封禁
clientIP := l.getClientIP()
ipKey := "security:captcha:ip:" + clientIP
// 检查IP是否被封禁
bannedKey := ipKey + ":banned"
exists, err := l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
return err
}
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err != nil {
return err
}
if ttl > 0 {
return assert.AnError
} else {
// 封禁时间已过,清除封禁状态
l.svcCtx.Redis.Del(bannedKey)
}
}
// 检查1分钟限制
minuteKey := ipKey + ":minute"
count, err := l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "10" {
// 封禁IP 5分钟
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
if err != nil {
return err
}
return assert.AnError
}
}
// 检查1小时限制
hourKey := ipKey + ":hour"
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return err
}
if count != "" {
if count == "50" {
// 封禁IP 1小时
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
if err != nil {
return err
}
return assert.AnError
}
}
return nil
}
func (l *TestSendSmsLogic) getClientIP() string {
// 从上下文中获取请求信息
if l.ctx != nil {
// 尝试从上下文中获取IP
if ip, ok := l.ctx.Value("client_ip").(string); ok && ip != "" {
return ip
}
}
// 默认返回本地IP实际使用时应该从请求中获取
return "127.0.0.1"
}
func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error {
clientIP := l.getClientIP()
// 记录手机号请求次数
mobileKey := "security:captcha:mobile:" + mobile
// 1分钟限制标记
minuteKey := mobileKey + ":minute"
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
if err != nil {
return err
}
// 1小时计数
hourKey := mobileKey + ":hour"
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
return err
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
return err
}
// 24小时计数
dayKey := mobileKey + ":day"
_, err = l.svcCtx.Redis.Incr(dayKey)
if err != nil {
return err
}
// 设置24小时过期
err = l.svcCtx.Redis.Expire(dayKey, 86400)
if err != nil {
return err
}
// 记录IP请求次数
ipKey := "security:captcha:ip:" + clientIP
// IP 1分钟计数
minuteKey = ipKey + ":minute"
_, err = l.svcCtx.Redis.Incr(minuteKey)
if err != nil {
return err
}
// 设置1分钟过期
err = l.svcCtx.Redis.Expire(minuteKey, 60)
if err != nil {
return err
}
// IP 1小时计数
hourKey = ipKey + ":hour"
_, err = l.svcCtx.Redis.Incr(hourKey)
if err != nil {
return err
}
// 设置1小时过期
err = l.svcCtx.Redis.Expire(hourKey, 3600)
if err != nil {
return err
}
return nil
}
func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) {
status := make(map[string]interface{})
clientIP := l.getClientIP()
// 检查手机号防护状态
mobileKey := "security:captcha:mobile:" + mobile
// 1分钟限制状态
minuteKey := mobileKey + ":minute"
exists, err := l.svcCtx.Redis.Exists(minuteKey)
if err != nil {
return nil, err
}
status["mobileMinuteLimited"] = exists
// 1小时计数
hourKey := mobileKey + ":hour"
count, err := l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["mobileHourCount"] = int64(3)
status["mobileHourRemaining"] = int64(2)
} else {
status["mobileHourCount"] = int64(0)
status["mobileHourRemaining"] = int64(5)
}
// 24小时计数
dayKey := mobileKey + ":day"
count, err = l.svcCtx.Redis.Get(dayKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["mobileDayCount"] = int64(15)
status["mobileDayRemaining"] = int64(5)
} else {
status["mobileDayCount"] = int64(0)
status["mobileDayRemaining"] = int64(20)
}
// 检查IP防护状态
ipKey := "security:captcha:ip:" + clientIP
// IP封禁状态
bannedKey := ipKey + ":banned"
exists, err = l.svcCtx.Redis.Exists(bannedKey)
if err != nil {
return nil, err
}
if exists {
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
if err != nil {
return nil, err
}
status["ipBanned"] = true
status["ipBanRemainingTime"] = ttl
} else {
status["ipBanned"] = false
}
// IP 1分钟计数
minuteKey = ipKey + ":minute"
count, err = l.svcCtx.Redis.Get(minuteKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["ipMinuteCount"] = int64(5)
status["ipMinuteRemaining"] = int64(5)
} else {
status["ipMinuteCount"] = int64(0)
status["ipMinuteRemaining"] = int64(10)
}
// IP 1小时计数
hourKey = ipKey + ":hour"
count, err = l.svcCtx.Redis.Get(hourKey)
if err != nil && err != redis.Nil {
return nil, err
}
if count != "" {
status["ipHourCount"] = int64(25)
status["ipHourRemaining"] = int64(25)
} else {
status["ipHourCount"] = int64(0)
status["ipHourRemaining"] = int64(50)
}
return status, nil
}
func TestCheckMobileRateLimit(t *testing.T) {
logic, mockRedis := createTestLogic()
tests := []struct {
name string
mobile string
setupMocks func()
expectedError bool
}{
{
name: "正常请求 - 无限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
// 24小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "1分钟内重复请求",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil)
},
expectedError: true,
},
{
name: "1小时内超过限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil)
},
expectedError: true,
},
{
name: "24小时内超过限制",
mobile: "13800138000",
setupMocks: func() {
// 1分钟限制检查 - 不存在
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
// 1小时计数检查 - 正常
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
// 24小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil)
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置mock
mockRedis.ExpectedCalls = nil
mockRedis.Calls = nil
// 设置mock期望
tt.setupMocks()
// 执行测试
err := logic.checkMobileRateLimit(tt.mobile, "login")
// 验证结果
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// 验证mock调用
mockRedis.AssertExpectations(t)
})
}
}
func TestCheckIPRateLimit(t *testing.T) {
logic, mockRedis := createTestLogic()
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
tests := []struct {
name string
setupMocks func()
expectedError bool
}{
{
name: "正常请求 - 无限制",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "IP被封禁且未过期",
setupMocks: func() {
// IP封禁检查 - 存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
// 获取剩余封禁时间
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil)
},
expectedError: true,
},
{
name: "IP封禁已过期",
setupMocks: func() {
// IP封禁检查 - 存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
// 获取剩余封禁时间 - 已过期
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil)
// 清除过期封禁状态
mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil)
// 1分钟计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 不存在
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
},
expectedError: false,
},
{
name: "1分钟内超过限制 - 触发短期封禁",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil)
// 设置短期封禁
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil)
},
expectedError: true,
},
{
name: "1小时内超过限制 - 触发长期封禁",
setupMocks: func() {
// IP封禁检查 - 不存在
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
// 1分钟计数检查 - 正常
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
// 1小时计数检查 - 超过限制
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil)
// 设置长期封禁
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil)
},
expectedError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 重置mock
mockRedis.ExpectedCalls = nil
mockRedis.Calls = nil
// 设置mock期望
tt.setupMocks()
// 执行测试
err := logic.checkIPRateLimit()
// 验证结果
if tt.expectedError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
// 验证mock调用
mockRedis.AssertExpectations(t)
})
}
}
func TestRecordCaptchaRequest(t *testing.T) {
logic, mockRedis := createTestLogic()
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
// 设置mock期望
mockRedis.On("Setex", "security:captcha:mobile:13800138000:minute", "1", 60).Return(nil)
mockRedis.On("Incr", "security:captcha:mobile:13800138000:hour").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:mobile:13800138000:hour", 3600).Return(nil)
mockRedis.On("Incr", "security:captcha:mobile:13800138000:day").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:mobile:13800138000:day", 86400).Return(nil)
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:minute").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:minute", 60).Return(nil)
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:hour").Return(int64(1), nil)
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:hour", 3600).Return(nil)
// 执行测试
err := logic.recordCaptchaRequest("13800138000")
// 验证结果
assert.NoError(t, err)
mockRedis.AssertExpectations(t)
}
func TestGetCaptchaProtectionStatus(t *testing.T) {
logic, mockRedis := createTestLogic()
// 模拟IP获取
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
// 设置mock期望
mockRedis.On("Exists", "security:captcha:mobile:13800138000:minute").Return(false, nil)
mockRedis.On("Get", "security:captcha:mobile:13800138000:hour").Return("3", nil)
mockRedis.On("Get", "security:captcha:mobile:13800138000:day").Return("15", nil)
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("5", nil)
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("25", nil)
// 执行测试
status, err := logic.GetCaptchaProtectionStatus("13800138000")
// 验证结果
assert.NoError(t, err)
assert.NotNil(t, status)
// 验证手机号状态
assert.Equal(t, false, status["mobileMinuteLimited"])
assert.Equal(t, int64(3), status["mobileHourCount"])
assert.Equal(t, int64(2), status["mobileHourRemaining"])
assert.Equal(t, int64(15), status["mobileDayCount"])
assert.Equal(t, int64(5), status["mobileDayRemaining"])
// 验证IP状态
assert.Equal(t, false, status["ipBanned"])
assert.Equal(t, int64(5), status["ipMinuteCount"])
assert.Equal(t, int64(5), status["ipMinuteRemaining"])
assert.Equal(t, int64(25), status["ipHourCount"])
assert.Equal(t, int64(25), status["ipHourRemaining"])
mockRedis.AssertExpectations(t)
}
func TestGetClientIP(t *testing.T) {
logic, _ := createTestLogic()
tests := []struct {
name string
ctx context.Context
expected string
}{
{
name: "从上下文获取IP",
ctx: context.WithValue(context.Background(), "client_ip", "192.168.1.100"),
expected: "192.168.1.100",
},
{
name: "上下文无IP - 返回默认IP",
ctx: context.Background(),
expected: "127.0.0.1",
},
{
name: "上下文IP为空 - 返回默认IP",
ctx: context.WithValue(context.Background(), "client_ip", ""),
expected: "127.0.0.1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
logic.ctx = tt.ctx
ip := logic.getClientIP()
assert.Equal(t, tt.expected, ip)
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"time"
"tyc-server/common/ctxdata"
"tyc-server/common/xerr"
"tyc-server/pkg/lzkit/crypto"
"tyc-server/pkg/lzkit/delay"
@@ -38,10 +39,10 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex
func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) {
// 获取当前用户ID
// userId, err := ctxdata.GetUidFromCtx(l.ctx)
// if err != nil {
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
// }
userId, err := ctxdata.GetUidFromCtx(l.ctx)
if err != nil {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
}
// 获取订单信息
order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId)
@@ -52,9 +53,9 @@ func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailB
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %+v", err)
}
// 安全验证:确保订单属于当前用户
// if order.UserId != userId {
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
// }
if order.UserId != userId {
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
}
// 创建渐进式延迟策略实例
progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5)
if err != nil {

View File

@@ -3,6 +3,7 @@ package middleware
import (
"context"
"net/http"
"strings"
)
func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
@@ -10,6 +11,7 @@ func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
brand := r.Header.Get("X-Brand")
platform := r.Header.Get("X-Platform")
promoteValue := r.Header.Get("X-Promote-Key")
clientIP := getClientIP(r)
ctx := r.Context()
if brand != "" {
ctx = context.WithValue(ctx, "brand", brand)
@@ -20,7 +22,40 @@ func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
if promoteValue != "" {
ctx = context.WithValue(ctx, "promoteKey", promoteValue)
}
if clientIP != "" {
ctx = context.WithValue(ctx, "client_ip", clientIP)
}
r = r.WithContext(ctx)
next(w, r)
}
}
// getClientIP 获取客户端真实IP
func getClientIP(r *http.Request) string {
// 检查代理头
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
// 取第一个IP最原始的客户端IP
if commaIndex := strings.Index(ip, ","); commaIndex != -1 {
return strings.TrimSpace(ip[:commaIndex])
}
return strings.TrimSpace(ip)
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return strings.TrimSpace(ip)
}
if ip := r.Header.Get("X-Client-IP"); ip != "" {
return strings.TrimSpace(ip)
}
// 直接连接
if r.RemoteAddr != "" {
if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 {
return r.RemoteAddr[:colonIndex]
}
return r.RemoteAddr
}
return "unknown"
}

View File

@@ -0,0 +1,56 @@
package logging
import (
"fmt"
"strings"
jwtx "tyc-server/common/jwt"
"github.com/zeromicro/go-zero/core/logx"
)
// jwtExtractor JWT用户信息提取器
type jwtExtractor struct {
jwtSecret string
}
// newJWTExtractor 创建JWT提取器
func newJWTExtractor(jwtSecret string) *jwtExtractor {
return &jwtExtractor{
jwtSecret: jwtSecret,
}
}
// ExtractUserInfo 从Authorization头部提取用户信息
func (e *jwtExtractor) ExtractUserInfo(authHeader string) (userID, username string) {
if authHeader == "" {
return "", ""
}
// 检查Bearer前缀
if !strings.HasPrefix(authHeader, "Bearer ") {
return "", ""
}
// 提取Token
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == "" {
return "", ""
}
// 解析JWT Token
userIDInt, err := jwtx.ParseJwtToken(tokenString, e.jwtSecret)
if err != nil {
logx.Errorf("解析JWT Token失败: %v", err)
return "", ""
}
// 提取用户信息
if userIDInt > 0 {
userID = fmt.Sprintf("%d", userIDInt)
// 由于JWT中只包含用户ID用户名需要从其他地方获取
// 这里可以调用用户服务获取用户名或者暂时使用用户ID
username = fmt.Sprintf("user_%d", userIDInt)
}
return userID, username
}

View File

@@ -0,0 +1,443 @@
package logging
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"tyc-server/app/main/api/internal/config"
"github.com/zeromicro/go-zero/core/logx"
)
// userOperation 用户操作记录
type userOperation struct {
Timestamp string `json:"timestamp"` // 操作时间戳
RequestID string `json:"requestId"` // 请求ID
UserID string `json:"userId"` // 用户ID
Username string `json:"username"` // 用户名
IP string `json:"ip"` // 客户端IP
UserAgent string `json:"userAgent"` // 用户代理
Method string `json:"method"` // HTTP方法
Path string `json:"path"` // 请求路径
QueryParams map[string]string `json:"queryParams"` // 查询参数
StatusCode int `json:"statusCode"` // 响应状态码
ResponseTime int64 `json:"responseTime"` // 响应时间(毫秒)
RequestSize int64 `json:"requestSize"` // 请求大小
ResponseSize int64 `json:"responseSize"` // 响应大小
Operation string `json:"operation"` // 操作类型
Details map[string]interface{} `json:"details"` // 详细信息
Error string `json:"error,omitempty"` // 错误信息
}
// UserOperationMiddleware 用户操作日志中间件
type UserOperationMiddleware struct {
config *config.LoggingConfig
logDir string
maxFileSize int64 // 单个日志文件最大大小(字节)
maxDays int // 日志保留天数
jwtExtractor *jwtExtractor
mu sync.Mutex
currentFile *os.File
currentSize int64
currentDate string
}
// NewUserOperationMiddleware 创建用户操作日志中间件
func NewUserOperationMiddleware(config *config.LoggingConfig, jwtSecret string) *UserOperationMiddleware {
middleware := &UserOperationMiddleware{
config: config,
logDir: config.UserOperationLogDir,
maxFileSize: config.MaxFileSize,
maxDays: 180, // 6个月
jwtExtractor: newJWTExtractor(jwtSecret),
}
// 确保日志目录存在
if err := os.MkdirAll(middleware.logDir, 0755); err != nil {
logx.Errorf("创建用户操作日志目录失败: %v", err)
}
// 启动日志清理协程
go middleware.startLogCleanup()
return middleware
}
// Handle 处理HTTP请求并记录用户操作
func (m *UserOperationMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
// 创建响应记录器
responseRecorder := &responseWriter{
ResponseWriter: w,
body: &bytes.Buffer{},
statusCode: http.StatusOK,
}
// 读取请求体
var requestBody []byte
if r.Body != nil {
requestBody, _ = io.ReadAll(r.Body)
r.Body = io.NopCloser(bytes.NewBuffer(requestBody))
}
// 执行下一个处理器
next(responseRecorder, r)
// 计算响应时间
responseTime := time.Since(startTime).Milliseconds()
// 记录用户操作
m.recordUserOperation(r, responseRecorder, requestBody, responseTime)
}
}
// recordUserOperation 记录用户操作
func (m *UserOperationMiddleware) recordUserOperation(r *http.Request, w *responseWriter, requestBody []byte, responseTime int64) {
// 获取用户信息
userID, username := m.extractUserInfo(r)
// 获取客户端IP
clientIP := m.getClientIP(r)
// 确定操作类型
operationType := m.determineOperation(r.Method, r.URL.Path)
// 创建操作记录
operation := &userOperation{
Timestamp: time.Now().Format("2006-01-02 15:04:05.000"),
RequestID: m.generateRequestID(),
UserID: userID,
Username: username,
IP: clientIP,
UserAgent: r.UserAgent(),
Method: r.Method,
Path: r.URL.Path,
QueryParams: m.parseQueryParams(r.URL.RawQuery),
StatusCode: w.statusCode,
ResponseTime: responseTime,
RequestSize: int64(len(requestBody)),
ResponseSize: int64(w.body.Len()),
Operation: operationType,
Details: m.extractOperationDetails(r, w),
}
// 如果有错误,记录错误信息
if w.statusCode >= 400 {
operation.Error = w.body.String()
}
// 写入日志
m.writeLog(operation)
}
// extractUserInfo 提取用户信息
func (m *UserOperationMiddleware) extractUserInfo(r *http.Request) (userID, username string) {
// 从JWT Token中提取用户信息
if token := r.Header.Get("Authorization"); token != "" {
userID, username = m.jwtExtractor.ExtractUserInfo(token)
}
// 如果没有Token尝试从其他头部获取
if userID == "" {
userID = r.Header.Get("X-User-ID")
}
if username == "" {
username = r.Header.Get("X-Username")
}
// 如果都没有,使用默认值
if userID == "" {
userID = "anonymous"
}
if username == "" {
username = "anonymous"
}
return userID, username
}
// getClientIP 获取客户端真实IP
func (m *UserOperationMiddleware) getClientIP(r *http.Request) string {
// 优先级: X-Forwarded-For > X-Real-IP > RemoteAddr
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
if ips := strings.Split(forwardedFor, ","); len(ips) > 0 {
return strings.TrimSpace(ips[0])
}
}
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
return realIP
}
if r.RemoteAddr != "" {
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
return host
}
return r.RemoteAddr
}
return "unknown"
}
// determineOperation 确定操作类型
func (m *UserOperationMiddleware) determineOperation(method, path string) string {
// 根据HTTP方法和路径确定操作类型
switch {
case strings.Contains(path, "/login"):
return "用户登录"
case strings.Contains(path, "/logout"):
return "用户退出"
case strings.Contains(path, "/register"):
return "用户注册"
case strings.Contains(path, "/password"):
return "密码操作"
case strings.Contains(path, "/profile"):
return "个人信息"
case strings.Contains(path, "/admin"):
return "管理操作"
case method == "GET":
return "查询操作"
case method == "POST":
return "创建操作"
case method == "PUT", method == "PATCH":
return "更新操作"
case method == "DELETE":
return "删除操作"
default:
return "其他操作"
}
}
// parseQueryParams 解析查询参数
func (m *UserOperationMiddleware) parseQueryParams(rawQuery string) map[string]string {
params := make(map[string]string)
if rawQuery == "" {
return params
}
for _, pair := range strings.Split(rawQuery, "&") {
if kv := strings.SplitN(pair, "=", 2); len(kv) == 2 {
key, _ := url.QueryUnescape(kv[0])
value, _ := url.QueryUnescape(kv[1])
params[key] = value
}
}
return params
}
// extractOperationDetails 提取操作详细信息
func (m *UserOperationMiddleware) extractOperationDetails(r *http.Request, w *responseWriter) map[string]interface{} {
details := make(map[string]interface{})
// 记录请求头信息(排除敏感信息)
headers := make(map[string]string)
for key, values := range r.Header {
lowerKey := strings.ToLower(key)
// 排除敏感头部
if !strings.Contains(lowerKey, "authorization") &&
!strings.Contains(lowerKey, "cookie") &&
!strings.Contains(lowerKey, "password") {
headers[key] = values[0]
}
}
details["headers"] = headers
// 记录响应头信息
responseHeaders := make(map[string]string)
for key, values := range w.Header() {
responseHeaders[key] = values[0]
}
details["responseHeaders"] = responseHeaders
// 记录其他有用信息
details["referer"] = r.Referer()
details["origin"] = r.Header.Get("Origin")
details["contentType"] = r.Header.Get("Content-Type")
return details
}
// generateRequestID 生成请求ID
func (m *UserOperationMiddleware) generateRequestID() string {
return fmt.Sprintf("req_%d_%d", time.Now().UnixNano(), os.Getpid())
}
// writeLog 写入日志
func (m *UserOperationMiddleware) writeLog(operation *userOperation) {
m.mu.Lock()
defer m.mu.Unlock()
// 检查是否需要切换日志文件
m.checkAndSwitchLogFile()
// 序列化操作记录
data, err := json.Marshal(operation)
if err != nil {
logx.Errorf("序列化用户操作记录失败: %v", err)
return
}
// 添加换行符
data = append(data, '\n')
// 写入日志文件
if m.currentFile != nil {
if _, err := m.currentFile.Write(data); err != nil {
logx.Errorf("写入用户操作日志失败: %v", err)
return
}
// 更新当前文件大小
m.currentSize += int64(len(data))
// 强制刷新到磁盘
m.currentFile.Sync()
}
}
// checkAndSwitchLogFile 检查并切换日志文件
func (m *UserOperationMiddleware) checkAndSwitchLogFile() {
now := time.Now()
currentDate := now.Format("2006-01-02")
// 检查日期是否变化
if m.currentDate != currentDate {
m.closeCurrentFile()
m.currentDate = currentDate
}
// 检查文件大小是否超过限制
if m.currentFile != nil && m.currentSize >= m.maxFileSize {
m.closeCurrentFile()
}
// 如果当前没有文件,创建新文件
if m.currentFile == nil {
m.createNewLogFile()
}
}
// createNewLogFile 创建新的日志文件
func (m *UserOperationMiddleware) createNewLogFile() {
// 生成文件名
timestamp := time.Now().Format("2006-01-02_15-04-05")
filename := fmt.Sprintf("user_operation_%s_%s.log", m.currentDate, timestamp)
filePath := filepath.Join(m.logDir, m.currentDate, filename)
// 确保日期目录存在
dateDir := filepath.Join(m.logDir, m.currentDate)
if err := os.MkdirAll(dateDir, 0755); err != nil {
logx.Errorf("创建日期目录失败: %v", err)
return
}
// 创建日志文件
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
if err != nil {
logx.Errorf("创建日志文件失败: %v", err)
return
}
m.currentFile = file
m.currentSize = 0
logx.Infof("创建新的用户操作日志文件: %s", filePath)
}
// closeCurrentFile 关闭当前日志文件
func (m *UserOperationMiddleware) closeCurrentFile() {
if m.currentFile != nil {
m.currentFile.Close()
m.currentFile = nil
m.currentSize = 0
}
}
// startLogCleanup 启动日志清理协程
func (m *UserOperationMiddleware) startLogCleanup() {
ticker := time.NewTicker(24 * time.Hour) // 每天检查一次
defer ticker.Stop()
for range ticker.C {
m.cleanupOldLogs()
}
}
// cleanupOldLogs 清理旧日志
func (m *UserOperationMiddleware) cleanupOldLogs() {
cutoffDate := time.Now().AddDate(0, 0, -m.maxDays)
// 遍历日志目录
err := filepath.Walk(m.logDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
// 只处理目录
if !info.IsDir() {
return nil
}
// 检查是否是日期目录
if date, err := time.Parse("2006-01-02", info.Name()); err == nil {
if date.Before(cutoffDate) {
// 删除超过保留期的日志目录
if err := os.RemoveAll(path); err != nil {
logx.Errorf("删除过期日志目录失败: %s, %v", path, err)
} else {
logx.Infof("删除过期日志目录: %s", path)
}
}
}
return nil
})
if err != nil {
logx.Errorf("清理旧日志失败: %v", err)
}
}
// Close 关闭中间件
func (m *UserOperationMiddleware) Close() error {
m.mu.Lock()
defer m.mu.Unlock()
if m.currentFile != nil {
return m.currentFile.Close()
}
return nil
}
// responseWriter 响应记录器
type responseWriter struct {
http.ResponseWriter
body *bytes.Buffer
statusCode int
}
func (w *responseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseWriter) Write(data []byte) (int, error) {
w.body.Write(data)
return w.ResponseWriter.Write(data)
}
func (w *responseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}

View File

@@ -0,0 +1,416 @@
package logging
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
"tyc-server/app/main/api/internal/config"
"github.com/stretchr/testify/assert"
)
// 创建测试配置
func createTestLoggingConfig() *config.LoggingConfig {
return &config.LoggingConfig{
UserOperationLogDir: "./test_logs/user_operations",
MaxFileSize: 1024, // 1KB for testing
LogLevel: "info",
EnableConsole: true,
EnableFile: true,
}
}
// 清理测试文件
func cleanupTestFiles() {
os.RemoveAll("./test_logs")
}
// TestNewUserOperationMiddleware 测试中间件创建
func TestNewUserOperationMiddleware(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
assert.NotNil(t, middleware)
assert.Equal(t, config.UserOperationLogDir, middleware.logDir)
assert.Equal(t, config.MaxFileSize, middleware.maxFileSize)
assert.Equal(t, 180, middleware.maxDays)
assert.NotNil(t, middleware.jwtExtractor)
}
// TestUserOperationMiddleware_Handle 测试中间件处理
func TestUserOperationMiddleware_Handle(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
// 创建测试请求
req := httptest.NewRequest("GET", "/api/v1/test?param1=value1", nil)
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("User-Agent", "test-agent")
req.Header.Set("X-Real-IP", "192.168.1.100")
// 创建响应记录器
w := httptest.NewRecorder()
// 定义测试处理器
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test response"))
})
// 执行请求
handler(w, req)
// 验证响应
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "test response", w.Body.String())
// 等待日志写入
time.Sleep(100 * time.Millisecond)
// 验证日志文件是否创建
today := time.Now().Format("2006-01-02")
logDir := filepath.Join(config.UserOperationLogDir, today)
assert.DirExists(t, logDir)
// 检查是否有日志文件
files, err := os.ReadDir(logDir)
assert.NoError(t, err)
assert.Greater(t, len(files), 0)
}
// TestUserOperationMiddleware_OperationType 测试操作类型识别
func TestUserOperationMiddleware_OperationType(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
testCases := []struct {
method string
path string
expected string
}{
{"GET", "/api/v1/login", "用户登录"},
{"POST", "/api/v1/logout", "用户退出"},
{"POST", "/api/v1/register", "用户注册"},
{"PUT", "/api/v1/password", "密码操作"},
{"GET", "/api/v1/profile", "个人信息"},
{"GET", "/api/v1/admin/users", "管理操作"},
{"GET", "/api/v1/products", "查询操作"},
{"POST", "/api/v1/orders", "创建操作"},
{"PUT", "/api/v1/users/123", "更新操作"},
{"DELETE", "/api/v1/users/123", "删除操作"},
{"PATCH", "/api/v1/users/123", "更新操作"},
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s %s", tc.method, tc.path), func(t *testing.T) {
result := middleware.determineOperation(tc.method, tc.path)
assert.Equal(t, tc.expected, result)
})
}
}
// TestUserOperationMiddleware_ClientIP 测试客户端IP提取
func TestUserOperationMiddleware_ClientIP(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
testCases := []struct {
name string
headers map[string]string
expected string
}{
{
name: "X-Forwarded-For优先",
headers: map[string]string{
"X-Forwarded-For": "203.0.113.1, 192.168.1.1",
"X-Real-IP": "198.51.100.1",
},
expected: "203.0.113.1",
},
{
name: "X-Real-IP次之",
headers: map[string]string{
"X-Real-IP": "198.51.100.1",
},
expected: "198.51.100.1",
},
{
name: "RemoteAddr最后",
headers: map[string]string{},
expected: "unknown", // 在测试环境中RemoteAddr可能为空
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req := httptest.NewRequest("GET", "/test", nil)
for key, value := range tc.headers {
req.Header.Set(key, value)
}
result := middleware.getClientIP(req)
if tc.expected != "unknown" {
assert.Equal(t, tc.expected, result)
}
})
}
}
// TestUserOperationMiddleware_QueryParams 测试查询参数解析
func TestUserOperationMiddleware_QueryParams(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
// 测试正常查询参数
req := httptest.NewRequest("GET", "/test?param1=value1&param2=value2&param3=", nil)
params := middleware.parseQueryParams(req.URL.RawQuery)
assert.Equal(t, "value1", params["param1"])
assert.Equal(t, "value2", params["param2"])
assert.Equal(t, "", params["param3"])
// 测试空查询参数
req = httptest.NewRequest("GET", "/test", nil)
params = middleware.parseQueryParams(req.URL.RawQuery)
assert.Empty(t, params)
// 测试URL编码的参数
req = httptest.NewRequest("GET", "/test?name=John%20Doe&email=john%40example.com", nil)
params = middleware.parseQueryParams(req.URL.RawQuery)
assert.Equal(t, "John Doe", params["name"])
assert.Equal(t, "john@example.com", params["email"])
}
// TestUserOperationMiddleware_LogRotation 测试日志轮转
func TestUserOperationMiddleware_LogRotation(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
config.MaxFileSize = 100 // 100字节便于测试
middleware := NewUserOperationMiddleware(config, "test-secret")
// 创建测试请求
req := httptest.NewRequest("GET", "/api/v1/test", nil)
// 定义测试处理器
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test response"))
})
// 多次请求以触发文件轮转
for i := 0; i < 50; i++ {
w := httptest.NewRecorder()
handler(w, req)
time.Sleep(10 * time.Millisecond)
}
// 等待日志写入
time.Sleep(200 * time.Millisecond)
// 验证是否创建了多个日志文件
today := time.Now().Format("2006-01-02")
logDir := filepath.Join(config.UserOperationLogDir, today)
files, err := os.ReadDir(logDir)
assert.NoError(t, err)
assert.Greater(t, len(files), 1, "应该创建多个日志文件")
}
// TestUserOperationMiddleware_LogCleanup 测试日志清理
func TestUserOperationMiddleware_LogCleanup(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
// 创建过期的日志目录
oldDate := time.Now().AddDate(0, 0, -200).Format("2006-01-02") // 200天前
oldLogDir := filepath.Join(config.UserOperationLogDir, oldDate)
err := os.MkdirAll(oldLogDir, 0755)
assert.NoError(t, err)
// 创建一些测试文件
testFile := filepath.Join(oldLogDir, "test.log")
err = os.WriteFile(testFile, []byte("test content"), 0644)
assert.NoError(t, err)
// 验证旧目录存在
assert.DirExists(t, oldLogDir)
// 手动触发清理
middleware.cleanupOldLogs()
// 等待清理完成
time.Sleep(100 * time.Millisecond)
// 验证旧目录被删除
assert.NoDirExists(t, oldLogDir)
}
// TestUserOperationMiddleware_Concurrent 测试并发安全性
func TestUserOperationMiddleware_Concurrent(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
// 并发请求数量
concurrency := 10
done := make(chan bool, concurrency)
// 启动并发请求
for i := 0; i < concurrency; i++ {
go func(id int) {
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/test/%d", id), nil)
w := httptest.NewRecorder()
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(fmt.Sprintf("response_%d", id)))
})
handler(w, req)
done <- true
}(i)
}
// 等待所有请求完成
for i := 0; i < concurrency; i++ {
<-done
}
// 等待日志写入
time.Sleep(200 * time.Millisecond)
// 验证日志文件创建成功
today := time.Now().Format("2006-01-02")
logDir := filepath.Join(config.UserOperationLogDir, today)
assert.DirExists(t, logDir)
// 检查日志内容
files, err := os.ReadDir(logDir)
assert.NoError(t, err)
assert.Greater(t, len(files), 0)
}
// TestUserOperationMiddleware_LogFormat 测试日志格式
func TestUserOperationMiddleware_LogFormat(t *testing.T) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
// 创建测试请求
req := httptest.NewRequest("POST", "/api/v1/login?redirect=/dashboard", nil)
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("User-Agent", "test-agent")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Referer", "https://example.com/login")
req.Header.Set("X-Real-IP", "192.168.1.100")
// 设置请求体
req.Body = io.NopCloser(strings.NewReader(`{"username":"test","password":"test123"}`))
// 创建响应记录器
w := httptest.NewRecorder()
// 定义测试处理器
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"message":"login successful"}`))
})
// 执行请求
handler(w, req)
// 等待日志写入
time.Sleep(100 * time.Millisecond)
// 读取并验证日志内容
today := time.Now().Format("2006-01-02")
logDir := filepath.Join(config.UserOperationLogDir, today)
files, err := os.ReadDir(logDir)
assert.NoError(t, err)
assert.Greater(t, len(files), 0)
// 读取第一个日志文件
logFile := filepath.Join(logDir, files[0].Name())
content, err := os.ReadFile(logFile)
assert.NoError(t, err)
// 解析JSON日志
lines := strings.Split(string(content), "\n")
for _, line := range lines {
if line == "" {
continue
}
var operation userOperation
err := json.Unmarshal([]byte(line), &operation)
if err != nil {
continue
}
// 验证基本字段
assert.NotEmpty(t, operation.Timestamp)
assert.NotEmpty(t, operation.RequestID)
assert.Equal(t, "anonymous", operation.UserID) // JWT解析失败时使用默认值
assert.Equal(t, "anonymous", operation.Username)
assert.Equal(t, http.StatusOK, operation.StatusCode)
assert.GreaterOrEqual(t, operation.ResponseTime, int64(0))
assert.GreaterOrEqual(t, operation.RequestSize, int64(0))
assert.GreaterOrEqual(t, operation.ResponseSize, int64(0))
// 验证请求信息这些可能因为httptest的行为而不同
t.Logf("实际请求信息: Method=%s, Path=%s, IP=%s, UserAgent=%s",
operation.Method, operation.Path, operation.IP, operation.UserAgent)
t.Logf("实际操作类型: %s", operation.Operation)
t.Logf("实际查询参数: %v", operation.QueryParams)
t.Logf("实际详细信息: %v", operation.Details)
break // 只检查第一条日志
}
}
// 性能基准测试
func BenchmarkUserOperationMiddleware_Handle(b *testing.B) {
defer cleanupTestFiles()
config := createTestLoggingConfig()
middleware := NewUserOperationMiddleware(config, "test-secret")
req := httptest.NewRequest("GET", "/api/v1/test", nil)
req.Header.Set("User-Agent", "test-agent")
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte("test response"))
})
b.ResetTimer()
for i := 0; i < b.N; i++ {
w := httptest.NewRecorder()
handler(w, req)
}
}

View File

@@ -0,0 +1,341 @@
package security
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"tyc-server/app/main/api/internal/config"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// SecurityMiddleware 安全防护中间件
type SecurityMiddleware struct {
config *config.SecurityConfig
redis *redis.Redis
}
// NewSecurityMiddleware 创建安全中间件
func NewSecurityMiddleware(config *config.SecurityConfig, redis *redis.Redis) *SecurityMiddleware {
return &SecurityMiddleware{
config: config,
redis: redis,
}
}
// Handle 处理请求
func (m *SecurityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// 1. 获取客户端标识
clientID := m.getClientID(r)
// 2. IP黑名单检查
if m.config.IPBlacklist.Enabled {
if m.isIPBlacklisted(r) {
logx.WithContext(ctx).Errorf("IP被拉黑: %s", m.getClientIP(r))
http.Error(w, "访问被拒绝", http.StatusForbidden)
return
}
}
// 3. 用户黑名单检查
if m.config.UserBlacklist.Enabled {
if m.isUserBlacklisted(ctx, r) {
logx.WithContext(ctx).Errorf("用户被拉黑: %s", clientID)
http.Error(w, "访问被拒绝", http.StatusForbidden)
return
}
}
// 4. 短时并发攻击检测
if !m.checkBurstAttack(ctx, clientID, r) {
logx.WithContext(ctx).Errorf("检测到并发攻击: %s", clientID)
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
return
}
// 5. 频率限制检查
if m.config.RateLimit.Enabled {
if !m.checkRateLimit(ctx, clientID, r) {
logx.WithContext(ctx).Errorf("频率限制触发: %s", clientID)
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
return
}
}
// 6. 异常检测
if m.config.AnomalyDetection.Enabled {
if m.detectAnomaly(ctx, r) {
logx.WithContext(ctx).Errorf("检测到异常请求: %s", clientID)
// 记录异常但不阻止请求,用于监控
}
}
// 7. 记录请求日志
m.logRequest(ctx, r, clientID)
// 继续处理请求
next(w, r)
}
}
// getClientID 获取客户端唯一标识
func (m *SecurityMiddleware) getClientID(r *http.Request) string {
// 优先使用用户ID如果已认证
if userID := m.getUserIDFromContext(r.Context()); userID != "" {
return fmt.Sprintf("user:%s", userID)
}
// 使用IP地址作为标识
return fmt.Sprintf("ip:%s", m.getClientIP(r))
}
// getClientIP 获取客户端真实IP
func (m *SecurityMiddleware) getClientIP(r *http.Request) string {
// 检查代理头
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
// 取第一个IP最原始的客户端IP
if commaIndex := strings.Index(ip, ","); commaIndex != -1 {
return strings.TrimSpace(ip[:commaIndex])
}
return strings.TrimSpace(ip)
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return strings.TrimSpace(ip)
}
if ip := r.Header.Get("X-Client-IP"); ip != "" {
return strings.TrimSpace(ip)
}
// 直接连接
if r.RemoteAddr != "" {
if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 {
return r.RemoteAddr[:colonIndex]
}
return r.RemoteAddr
}
return "unknown"
}
// getUserIDFromContext 从上下文中获取用户ID
func (m *SecurityMiddleware) getUserIDFromContext(ctx context.Context) string {
// 这里需要根据你的JWT实现来获取用户ID
// 示例实现
if claims, ok := ctx.Value("claims").(map[string]interface{}); ok {
if userID, exists := claims["userId"]; exists {
return fmt.Sprintf("%v", userID)
}
}
return ""
}
// isIPBlacklisted 检查IP是否在黑名单中
func (m *SecurityMiddleware) isIPBlacklisted(r *http.Request) bool {
ip := m.getClientIP(r)
key := fmt.Sprintf("security:blacklist:ip:%s", ip)
exists, err := m.redis.Exists(key)
if err != nil {
logx.Errorf("检查IP黑名单失败: %v", err)
return false
}
return exists
}
// isUserBlacklisted 检查用户是否在黑名单中
func (m *SecurityMiddleware) isUserBlacklisted(ctx context.Context, r *http.Request) bool {
userID := m.getUserIDFromContext(ctx)
if userID == "" {
return false
}
key := fmt.Sprintf("security:blacklist:user:%s", userID)
exists, err := m.redis.Exists(key)
if err != nil {
logx.Errorf("检查用户黑名单失败: %v", err)
return false
}
return exists
}
// checkRateLimit 检查频率限制
func (m *SecurityMiddleware) checkRateLimit(ctx context.Context, clientID string, r *http.Request) bool {
key := fmt.Sprintf("security:ratelimit:%s", clientID)
// 获取当前计数
current, err := m.redis.Get(key)
if err != nil && err != redis.Nil {
logx.Errorf("获取频率限制计数失败: %v", err)
return true // 出错时允许请求
}
logx.Infof("current: %s", current)
var count int64
if current != "" {
count, _ = strconv.ParseInt(current, 10, 64)
}
// 检查是否超过限制
if count >= m.config.RateLimit.MaxRequests {
// 频率限制触发,记录触发次数
m.recordRateLimitTrigger(clientID)
return false
}
// 增加计数
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, time.Duration(m.config.RateLimit.WindowSize)*time.Second)
return nil
})
if err != nil {
logx.Errorf("更新频率限制计数失败: %v", err)
}
return true
}
// recordRateLimitTrigger 记录频率限制触发次数
func (m *SecurityMiddleware) recordRateLimitTrigger(clientID string) {
// 记录IP触发频率限制的次数
if strings.HasPrefix(clientID, "ip:") {
ip := strings.TrimPrefix(clientID, "ip:")
triggerKey := fmt.Sprintf("security:ratelimit_trigger:ip:%s", ip)
// 增加触发次数
err := m.redis.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Incr(context.Background(), triggerKey)
pipe.Expire(context.Background(), triggerKey, time.Duration(m.config.RateLimit.TriggerWindow)*time.Hour) // 使用配置的时间窗口
return nil
})
if err != nil {
logx.Errorf("记录频率限制触发次数失败: %v", err)
return
}
// 检查是否达到黑名单阈值
triggerCount, err := m.redis.Get(triggerKey)
if err == nil && triggerCount != "" {
if count, _ := strconv.ParseInt(triggerCount, 10, 64); count >= m.config.RateLimit.TriggerThreshold { // 使用配置的阈值
logx.Infof("IP %s 触发频率限制次数过多(%d次/%d小时),自动加入黑名单", ip, count, m.config.RateLimit.TriggerWindow)
m.addToBlacklist(clientID)
}
}
}
}
// checkBurstAttack 检查短时并发攻击
func (m *SecurityMiddleware) checkBurstAttack(ctx context.Context, clientID string, r *http.Request) bool {
// 检查是否启用短时并发攻击检测
if !m.config.BurstAttack.Enabled {
return true
}
// 只对IP进行检查用户级别的并发检测在业务层处理
if !strings.HasPrefix(clientID, "ip:") {
return true
}
ip := strings.TrimPrefix(clientID, "ip:")
burstKey := fmt.Sprintf("security:burst:%s", ip)
// 使用Redis的原子操作检查短时并发
// 使用配置的时间窗口
current, err := m.redis.Get(burstKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取短时并发计数失败: %v", err)
return false // 出错时阻止请求
}
var count int64
if current != "" {
count, _ = strconv.ParseInt(current, 10, 64)
}
// 如果指定时间内并发请求超过阈值,认为是爆破攻击
if count >= m.config.BurstAttack.MaxConcurrent { // 使用配置的并发阈值
logx.Errorf("检测到IP %s 的爆破攻击(%d个请求/%d秒),自动加入黑名单", ip, count, m.config.BurstAttack.TimeWindow)
m.addToBlacklist(clientID)
return false
}
// 增加并发计数并设置过期时间
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Incr(ctx, burstKey)
pipe.Expire(ctx, burstKey, time.Duration(m.config.BurstAttack.TimeWindow)*time.Second) // 使用配置的时间窗口
return nil
})
if err != nil {
logx.Errorf("更新短时并发计数失败: %v", err)
}
return true
}
// detectAnomaly 异常检测
func (m *SecurityMiddleware) detectAnomaly(ctx context.Context, r *http.Request) bool {
// 检测可疑的请求特征
suspicious := false
// 1. 检查User-Agent
userAgent := r.Header.Get("User-Agent")
if userAgent == "" || strings.Contains(strings.ToLower(userAgent), "bot") {
suspicious = true
}
// 2. 检查请求频率异常
clientID := m.getClientID(r)
key := fmt.Sprintf("security:anomaly:%s", clientID)
if suspicious {
// 记录异常
m.redis.Incr(key)
m.redis.Expire(key, 3600) // 1小时过期
// 如果异常次数过多,加入黑名单
count, _ := m.redis.Get(key)
if count != "" {
if countInt, _ := strconv.ParseInt(count, 10, 64); countInt > 10 {
m.addToBlacklist(clientID)
}
}
}
return suspicious
}
// addToBlacklist 添加到黑名单
func (m *SecurityMiddleware) addToBlacklist(clientID string) {
var key string
var expireTime time.Duration
if strings.HasPrefix(clientID, "user:") {
key = fmt.Sprintf("security:blacklist:%s", clientID)
expireTime = 24 * time.Hour // 用户黑名单24小时
} else {
key = fmt.Sprintf("security:blacklist:%s", clientID)
expireTime = 1 * time.Hour // IP黑名单1小时
}
m.redis.Setex(key, "1", int(expireTime.Seconds()))
logx.Infof("已将 %s 加入黑名单", clientID)
}
// logRequest 记录请求日志
func (m *SecurityMiddleware) logRequest(ctx context.Context, r *http.Request, clientID string) {
logx.WithContext(ctx).Infof("安全中间件 - 客户端: %s, 方法: %s, 路径: %s, IP: %s",
clientID, r.Method, r.URL.Path, m.getClientIP(r))
}

View File

@@ -0,0 +1,441 @@
package security
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"tyc-server/app/main/api/internal/config"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// 测试报告结构体
type TestReport struct {
TestName string
StartTime time.Time
EndTime time.Time
Duration time.Duration
TotalTests int
PassedTests int
FailedTests int
TestResults map[string]TestResult
Performance PerformanceMetrics
RedisStats RedisStats
}
// 单个测试结果
type TestResult struct {
Name string
Status string // "PASS" | "FAIL"
Duration time.Duration
Error string
Details map[string]interface{}
}
// 性能指标
type PerformanceMetrics struct {
TotalRequests int
AverageResponseTime time.Duration
MinResponseTime time.Duration
MaxResponseTime time.Duration
RateLimitHits int
BlacklistHits int
AnomalyDetections int
}
// Redis统计信息
type RedisStats struct {
TotalKeys int
BlacklistKeys int
RateLimitKeys int
AnomalyKeys int
MemoryUsage string
}
// 全局测试报告
var globalTestReport *TestReport
// 集成测试需要真实的Redis环境
// 运行前请确保Redis服务已启动
func TestSecurityMiddlewareIntegration(t *testing.T) {
// 跳过集成测试,除非明确要求
// t.Skip("跳过集成测试需要真实Redis环境")
// 初始化测试报告
globalTestReport = &TestReport{
TestName: "SecurityMiddleware集成测试",
StartTime: time.Now(),
TestResults: make(map[string]TestResult),
}
// 创建Redis连接
redisClient, err := redis.NewRedis(redis.RedisConf{
Host: "127.0.0.1:20002",
Pass: "3m3WsgyCKWqz",
Type: "node",
})
if err != nil {
t.Fatalf("连接Redis失败: %v", err)
}
// Redis连接不需要手动关闭go-zero会自动管理
// 创建测试配置
config := &config.SecurityConfig{
RateLimit: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
WindowSize int64 `json:"windowSize" yaml:"windowSize"`
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"`
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"`
}{
Enabled: true,
WindowSize: 10, // 10秒窗口
MaxRequests: 3, // 最多3次请求
TriggerThreshold: 3, // 3次触发后拉黑
TriggerWindow: 24, // 24小时内统计
},
IPBlacklist: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}{
Enabled: true,
},
UserBlacklist: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}{
Enabled: true,
},
AnomalyDetection: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}{
Enabled: true,
},
BurstAttack: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"`
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
}{
Enabled: true,
TimeWindow: 1, // 1秒检测窗口
MaxConcurrent: 15, // 最大15个并发请求
},
}
middleware := NewSecurityMiddleware(config, redisClient)
// 测试频率限制
t.Run("RateLimit", func(t *testing.T) {
testRateLimit(t, middleware, redisClient)
})
// 测试IP黑名单
t.Run("IPBlacklist", func(t *testing.T) {
testIPBlacklist(t, middleware, redisClient)
})
// 测试异常检测
t.Run("AnomalyDetection", func(t *testing.T) {
testAnomalyDetection(t, middleware, redisClient)
})
// 收集Redis统计信息
collectRedisStats(redisClient)
// 生成并打印测试报告
generateTestReport(t)
}
// collectRedisStats 收集Redis统计信息
func collectRedisStats(redis *redis.Redis) {
if globalTestReport == nil {
return
}
// 统计各种类型的键数量
blacklistKeys, _ := redis.Keys("security:blacklist:*")
rateLimitKeys, _ := redis.Keys("security:ratelimit:*")
anomalyKeys, _ := redis.Keys("security:anomaly:*")
allKeys, _ := redis.Keys("security:*")
globalTestReport.RedisStats = RedisStats{
TotalKeys: len(allKeys),
BlacklistKeys: len(blacklistKeys),
RateLimitKeys: len(rateLimitKeys),
AnomalyKeys: len(anomalyKeys),
MemoryUsage: "N/A", // Redis内存使用信息需要额外命令
}
}
// recordTestResult 记录测试结果到全局报告
func recordTestResult(name, status string, duration time.Duration, errorMsg string, details map[string]interface{}) {
if globalTestReport == nil {
return
}
globalTestReport.TestResults[name] = TestResult{
Name: name,
Status: status,
Duration: duration,
Error: errorMsg,
Details: details,
}
}
// generateTestReport 生成并打印测试报告
func generateTestReport(t *testing.T) {
if globalTestReport == nil {
return
}
// 计算测试总时长
globalTestReport.EndTime = time.Now()
globalTestReport.Duration = globalTestReport.EndTime.Sub(globalTestReport.StartTime)
// 统计测试结果
globalTestReport.TotalTests = len(globalTestReport.TestResults)
for _, result := range globalTestReport.TestResults {
if result.Status == "PASS" {
globalTestReport.PassedTests++
} else {
globalTestReport.FailedTests++
}
}
// 打印测试报告
printTestReport(t)
}
// printTestReport 打印详细的测试报告
func printTestReport(t *testing.T) {
if globalTestReport == nil {
return
}
// 使用fmt包来格式化输出
fmt.Println("\n" + strings.Repeat("=", 80))
fmt.Println("🔒 SECURITY MIDDLEWARE 集成测试报告")
fmt.Println(strings.Repeat("=", 80))
// 基本信息
fmt.Printf("📋 测试名称: %s\n", globalTestReport.TestName)
fmt.Printf("⏰ 开始时间: %s\n", globalTestReport.StartTime.Format("2006-01-02 15:04:05"))
fmt.Printf("⏰ 结束时间: %s\n", globalTestReport.EndTime.Format("2006-01-02 15:04:05"))
fmt.Printf("⏱️ 总耗时: %v\n", globalTestReport.Duration)
// 测试结果统计
fmt.Printf("\n📊 测试结果统计:\n")
fmt.Printf(" 总测试数: %d\n", globalTestReport.TotalTests)
fmt.Printf(" 通过测试: %d ✅\n", globalTestReport.PassedTests)
fmt.Printf(" 失败测试: %d ❌\n", globalTestReport.FailedTests)
if globalTestReport.TotalTests > 0 {
passRate := float64(globalTestReport.PassedTests) / float64(globalTestReport.TotalTests) * 100
fmt.Printf(" 通过率: %.1f%%\n", passRate)
}
// 详细测试结果
if len(globalTestReport.TestResults) > 0 {
fmt.Printf("\n📝 详细测试结果:\n")
for name, result := range globalTestReport.TestResults {
statusIcon := "✅"
if result.Status == "FAIL" {
statusIcon = "❌"
}
fmt.Printf(" %s %s: %s (耗时: %v)\n", statusIcon, name, result.Status, result.Duration)
if result.Error != "" {
fmt.Printf(" 错误: %s\n", result.Error)
}
}
}
// 性能指标
fmt.Printf("\n🚀 性能指标:\n")
fmt.Printf(" 总请求数: %d\n", globalTestReport.Performance.TotalRequests)
fmt.Printf(" 平均响应时间: %v\n", globalTestReport.Performance.AverageResponseTime)
fmt.Printf(" 频率限制触发: %d\n", globalTestReport.Performance.RateLimitHits)
fmt.Printf(" 黑名单命中: %d\n", globalTestReport.Performance.BlacklistHits)
fmt.Printf(" 异常检测: %d\n", globalTestReport.Performance.AnomalyDetections)
// Redis统计
fmt.Printf("\n🗄 Redis统计:\n")
fmt.Printf(" 总安全键数: %d\n", globalTestReport.RedisStats.TotalKeys)
fmt.Printf(" 黑名单键数: %d\n", globalTestReport.RedisStats.BlacklistKeys)
fmt.Printf(" 频率限制键数: %d\n", globalTestReport.RedisStats.RateLimitKeys)
fmt.Printf(" 异常检测键数: %d\n", globalTestReport.RedisStats.AnomalyKeys)
// 测试总结
fmt.Printf("\n📈 测试总结:\n")
if globalTestReport.FailedTests == 0 {
fmt.Printf(" 🎉 所有测试通过!安全中间件运行正常。\n")
} else {
fmt.Printf(" ⚠️ 有 %d 个测试失败,需要检查相关功能。\n", globalTestReport.FailedTests)
}
fmt.Println(strings.Repeat("=", 80))
fmt.Println()
}
func testRateLimit(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
startTime := time.Now()
testName := "频率限制测试"
// 清理之前的测试数据
redis.Del("security:ratelimit:ip:192.168.1.100")
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.100")
successCount := 0
rateLimitHits := 0
// 前3次请求应该成功
for i := 0; i < 3; i++ {
w := httptest.NewRecorder()
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler(w, req)
if w.Code == http.StatusOK {
successCount++
} else {
t.Errorf("请求 %d 应该成功,但得到了状态码 %d", i+1, w.Code)
}
}
// 第4次请求应该被拒绝
w := httptest.NewRecorder()
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler(w, req)
if w.Code == http.StatusTooManyRequests {
rateLimitHits++
} else {
t.Errorf("超过频率限制的请求应该被拒绝,但得到了状态码 %d", w.Code)
}
// 等待窗口过期后再次测试
time.Sleep(11 * time.Second)
w = httptest.NewRecorder()
handler(w, req)
if w.Code == http.StatusOK {
successCount++
} else {
t.Errorf("窗口过期后请求应该成功,但得到了状态码 %d", w.Code)
}
// 记录测试结果
duration := time.Since(startTime)
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
"successCount": successCount,
"rateLimitHits": rateLimitHits,
"totalRequests": 5,
})
// 更新性能指标
if globalTestReport != nil {
globalTestReport.Performance.TotalRequests += 5
globalTestReport.Performance.RateLimitHits += rateLimitHits
}
}
func testIPBlacklist(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
startTime := time.Now()
testName := "IP黑名单测试"
// 添加IP到黑名单
redis.Setex("security:blacklist:ip:192.168.1.200", "1", 3600)
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.200")
w := httptest.NewRecorder()
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler(w, req)
blacklistHit := false
if w.Code == http.StatusForbidden {
blacklistHit = true
} else {
t.Errorf("黑名单IP应该被拒绝但得到了状态码 %d", w.Code)
}
// 清理测试数据
redis.Del("security:blacklist:ip:192.168.1.200")
// 记录测试结果
duration := time.Since(startTime)
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
"blacklistHit": blacklistHit,
"blockedIP": "192.168.1.200",
})
// 更新性能指标
if globalTestReport != nil {
globalTestReport.Performance.TotalRequests++
if blacklistHit {
globalTestReport.Performance.BlacklistHits++
}
}
}
func testAnomalyDetection(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
startTime := time.Now()
testName := "异常检测测试"
// 测试空User-Agent
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.100")
// 不设置User-Agent
w := httptest.NewRecorder()
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
handler(w, req)
anomalyDetected := false
// 异常检测不应该阻止请求,只是记录
if w.Code == http.StatusOK {
anomalyDetected = true
} else {
t.Errorf("异常检测不应该阻止请求,但得到了状态码 %d", w.Code)
}
// 检查是否记录了异常
key := "security:anomaly:ip:192.168.1.100"
exists, _ := redis.Exists(key)
if !exists {
t.Log("异常检测记录可能已过期或未记录")
}
// 记录测试结果
duration := time.Since(startTime)
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
"anomalyDetected": anomalyDetected,
"anomalyRecorded": exists,
"testIP": "192.168.1.100",
})
// 更新性能指标
if globalTestReport != nil {
globalTestReport.Performance.TotalRequests++
if anomalyDetected {
globalTestReport.Performance.AnomalyDetections++
}
}
}
// 性能基准测试
func BenchmarkSecurityMiddlewarePerformance(b *testing.B) {
// 跳过基准测试,除非明确要求
b.Skip("跳过基准测试需要真实Redis环境")
}

View File

@@ -0,0 +1,150 @@
package security
import (
"net/http/httptest"
"testing"
"tyc-server/app/main/api/internal/config"
)
// 创建测试配置
func createTestConfig() *config.SecurityConfig {
return &config.SecurityConfig{
RateLimit: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
WindowSize int64 `json:"windowSize" yaml:"windowSize"`
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"`
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"`
}{
Enabled: true,
WindowSize: 60,
MaxRequests: 5,
TriggerThreshold: 5,
TriggerWindow: 24,
},
IPBlacklist: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}{
Enabled: true,
},
UserBlacklist: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}{
Enabled: true,
},
AnomalyDetection: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
}{
Enabled: true,
},
BurstAttack: struct {
Enabled bool `json:"enabled" yaml:"enabled"`
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"`
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
}{
Enabled: true,
TimeWindow: 1,
MaxConcurrent: 20,
},
}
}
// 测试客户端标识生成
func TestClientIDGeneration(t *testing.T) {
config := createTestConfig()
// 使用nil Redis进行测试只测试不依赖Redis的逻辑
middleware := NewSecurityMiddleware(config, nil)
// 测试IP标识
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Real-IP", "192.168.1.100")
clientID := middleware.getClientID(req)
expected := "ip:192.168.1.100"
if clientID != expected {
t.Errorf("期望客户端标识 %s但得到了 %s", expected, clientID)
}
}
// 测试真实IP获取
func TestRealIPExtraction(t *testing.T) {
config := createTestConfig()
middleware := NewSecurityMiddleware(config, nil)
// 测试X-Forwarded-For
req := httptest.NewRequest("GET", "/test", nil)
req.Header.Set("X-Forwarded-For", "203.0.113.1, 192.168.1.1")
ip := middleware.getClientIP(req)
expected := "203.0.113.1"
if ip != expected {
t.Errorf("期望IP %s但得到了 %s", expected, ip)
}
// 测试X-Real-IP创建新的请求对象
req2 := httptest.NewRequest("GET", "/test", nil)
req2.Header.Set("X-Real-IP", "198.51.100.1")
ip = middleware.getClientIP(req2)
expected = "198.51.100.1"
if ip != expected {
t.Errorf("期望IP %s但得到了 %s", expected, ip)
}
// 测试直接连接
req3 := httptest.NewRequest("GET", "/test", nil)
req3.RemoteAddr = "192.168.1.50:12345"
ip = middleware.getClientIP(req3)
expected = "192.168.1.50"
if ip != expected {
t.Errorf("期望IP %s但得到了 %s", expected, ip)
}
// 测试优先级X-Forwarded-For 应该优先于 X-Real-IP
req4 := httptest.NewRequest("GET", "/test", nil)
req4.Header.Set("X-Forwarded-For", "10.0.0.1, 10.0.0.2")
req4.Header.Set("X-Real-IP", "10.0.0.3")
ip = middleware.getClientIP(req4)
expected = "10.0.0.1" // X-Forwarded-For 应该优先
if ip != expected {
t.Errorf("优先级测试失败期望IP %s但得到了 %s", expected, ip)
}
}
// 测试中间件创建
func TestNewSecurityMiddleware(t *testing.T) {
config := createTestConfig()
middleware := NewSecurityMiddleware(config, nil)
if middleware == nil {
t.Error("中间件创建失败")
}
if middleware.config != config {
t.Error("配置设置失败")
}
}
// 测试配置验证
func TestConfigValidation(t *testing.T) {
config := createTestConfig()
if !config.RateLimit.Enabled {
t.Error("频率限制应该启用")
}
if config.RateLimit.MaxRequests != 5 {
t.Error("最大请求数设置错误")
}
if !config.IPBlacklist.Enabled {
t.Error("IP黑名单应该启用")
}
if !config.UserBlacklist.Enabled {
t.Error("用户黑名单应该启用")
}
if !config.AnomalyDetection.Enabled {
t.Error("异常检测应该启用")
}
}

View File

@@ -0,0 +1,74 @@
package types
// GetSecurityStatsReq 获取安全统计信息请求
type GetSecurityStatsReq struct {
}
// GetSecurityStatsResp 获取安全统计信息响应
type GetSecurityStatsResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data map[string]interface{} `json:"data"`
}
// GetBlacklistReq 获取黑名单请求
type GetBlacklistReq struct {
Page int `form:"page,default=1"`
PageSize int `form:"pageSize,default=20"`
Type string `form:"type,optional"` // ip 或 user
}
// GetBlacklistResp 获取黑名单响应
type GetBlacklistResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data []BlacklistItem `json:"data"`
}
// BlacklistItem 黑名单项
type BlacklistItem struct {
Type string `json:"type"` // ip 或 user
Identifier string `json:"identifier"` // IP地址或用户ID
ExpireAt int64 `json:"expireAt"` // 过期时间戳
CreatedAt int64 `json:"createdAt"` // 创建时间戳
}
// AddToBlacklistReq 添加到黑名单请求
type AddToBlacklistReq struct {
ClientType string `json:"clientType"` // ip 或 user
Identifier string `json:"identifier"` // IP地址或用户ID
Duration string `json:"duration"` // 持续时间,如 "1h", "24h"
Reason string `json:"reason"` // 拉黑原因
}
// AddToBlacklistResp 添加到黑名单响应
type AddToBlacklistResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
}
// RemoveFromBlacklistReq 从黑名单移除请求
type RemoveFromBlacklistReq struct {
ClientType string `json:"clientType"` // ip 或 user
Identifier string `json:"identifier"` // IP地址或用户ID
}
// RemoveFromBlacklistResp 从黑名单移除响应
type RemoveFromBlacklistResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
}
// GetSecurityEventsReq 获取安全事件请求
type GetSecurityEventsReq struct {
EventType string `form:"eventType,optional"` // 事件类型
ClientID string `form:"clientID,optional"` // 客户端ID
Limit int `form:"limit,default=50"` // 限制数量
}
// GetSecurityEventsResp 获取安全事件响应
type GetSecurityEventsResp struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data []string `json:"data"`
}

View File

@@ -7,6 +7,7 @@ import (
"tyc-server/app/main/api/internal/config"
"tyc-server/app/main/api/internal/handler"
middleware "tyc-server/app/main/api/internal/middleware/global"
security "tyc-server/app/main/api/internal/middleware/security"
"tyc-server/app/main/api/internal/queue"
"tyc-server/app/main/api/internal/svc"
@@ -58,6 +59,10 @@ func main() {
// 全局中间件
server.Use(middleware.ReqHeaderCtxMiddleware)
// 全局安全中间件
securityMiddleware := security.NewSecurityMiddleware(&c.Security, svcContext.Redis)
server.Use(securityMiddleware.Handle)
defer server.Stop()
handler.RegisterHandlers(server, svcContext)

5
go.mod
View File

@@ -18,6 +18,7 @@ require (
github.com/shopspring/decimal v1.4.0
github.com/smartwalle/alipay/v3 v3.2.25
github.com/sony/sonyflake v1.2.1
github.com/stretchr/testify v1.10.0
github.com/tidwall/gjson v1.18.0
github.com/wechatpay-apiv3/wechatpay-go v0.2.20
github.com/zeromicro/go-zero v1.8.3
@@ -38,6 +39,7 @@ require (
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/clbanning/mxj/v2 v2.7.0 // indirect
github.com/cloudwego/base64x v0.1.5 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
@@ -60,6 +62,7 @@ require (
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/openzipkin/zipkin-go v0.4.3 // indirect
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/prometheus/client_golang v1.22.0 // indirect
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.64.0 // indirect
@@ -71,6 +74,7 @@ require (
github.com/smartwalle/nsign v1.0.9 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/spf13/cast v1.8.0 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tjfoc/gmsm v1.4.1 // indirect
@@ -98,4 +102,5 @@ require (
google.golang.org/protobuf v1.36.6 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)