fix
This commit is contained in:
@@ -92,3 +92,26 @@ Tianyuanapi:
|
||||
Timeout: 60
|
||||
VerifyConfig:
|
||||
TwoFactor: true
|
||||
Security:
|
||||
RateLimit:
|
||||
Enabled: true
|
||||
WindowSize: 30
|
||||
MaxRequests: 50
|
||||
TriggerThreshold: 3 # 触发5次频率限制后加入黑名单
|
||||
TriggerWindow: 24 # 24小时内统计触发次数
|
||||
IPBlacklist:
|
||||
Enabled: true
|
||||
UserBlacklist:
|
||||
Enabled: true
|
||||
AnomalyDetection:
|
||||
Enabled: true
|
||||
BurstAttack:
|
||||
Enabled: true # 启用短时并发攻击检测
|
||||
TimeWindow: 1 # 1秒内检测
|
||||
MaxConcurrent: 15 # 最大20个并发请求
|
||||
Logging:
|
||||
UserOperationLogDir: "./logs/user_operations"
|
||||
MaxFileSize: 104857600 # 100MB
|
||||
LogLevel: "info"
|
||||
EnableConsole: true
|
||||
EnableFile: true
|
||||
|
||||
@@ -79,3 +79,32 @@ Tianyuanapi:
|
||||
Timeout: 60
|
||||
VerifyConfig:
|
||||
TwoFactor: true
|
||||
Security:
|
||||
RateLimit:
|
||||
Enabled: true
|
||||
WindowSize: 30
|
||||
MaxRequests: 50
|
||||
TriggerThreshold: 3 # 触发5次频率限制后加入黑名单
|
||||
TriggerWindow: 24 # 24小时内统计触发次数
|
||||
IPBlacklist:
|
||||
Enabled: true
|
||||
UserBlacklist:
|
||||
Enabled: true
|
||||
AnomalyDetection:
|
||||
Enabled: true
|
||||
BurstAttack:
|
||||
Enabled: true # 启用短时并发攻击检测
|
||||
TimeWindow: 1 # 1秒内检测
|
||||
MaxConcurrent: 15 # 最大20个并发请求
|
||||
Logging:
|
||||
UserOperationLogDir: "./logs/user_operations"
|
||||
MaxFileSize: 104857600 # 100MB
|
||||
LogLevel: "info"
|
||||
EnableConsole: true
|
||||
EnableFile: true
|
||||
Logging:
|
||||
UserOperationLogDir: "./logs/user_operations"
|
||||
MaxFileSize: 104857600 # 10MB
|
||||
LogLevel: "info"
|
||||
EnableConsole: true
|
||||
EnableFile: true
|
||||
@@ -24,6 +24,8 @@ type Config struct {
|
||||
CleanTask CleanTask
|
||||
Tianyuanapi TianyuanapiConfig
|
||||
VerifyConfig VerifyConfig
|
||||
Security SecurityConfig // 安全配置
|
||||
Logging LoggingConfig // 日志配置
|
||||
}
|
||||
|
||||
// JwtAuth 用于 JWT 鉴权配置
|
||||
|
||||
10
app/main/api/internal/config/logging.go
Normal file
10
app/main/api/internal/config/logging.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package config
|
||||
|
||||
// LoggingConfig 日志配置
|
||||
type LoggingConfig struct {
|
||||
UserOperationLogDir string `json:"userOperationLogDir" yaml:"userOperationLogDir"` // 用户操作日志目录
|
||||
MaxFileSize int64 `json:"maxFileSize" yaml:"maxFileSize"` // 单个日志文件最大大小(字节)
|
||||
LogLevel string `json:"logLevel" yaml:"logLevel"` // 日志级别
|
||||
EnableConsole bool `json:"enableConsole" yaml:"enableConsole"` // 是否启用控制台输出
|
||||
EnableFile bool `json:"enableFile" yaml:"enableFile"` // 是否启用文件输出
|
||||
}
|
||||
36
app/main/api/internal/config/security.go
Normal file
36
app/main/api/internal/config/security.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package config
|
||||
|
||||
// SecurityConfig 安全配置
|
||||
type SecurityConfig struct {
|
||||
// 频率限制配置
|
||||
RateLimit struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用频率限制
|
||||
WindowSize int64 `json:"windowSize" yaml:"windowSize"` // 时间窗口大小(秒)
|
||||
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"` // 最大请求次数
|
||||
// 频率限制触发后的黑名单升级配置
|
||||
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"` // 触发多少次频率限制后加入黑名单
|
||||
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"` // 触发次数统计时间窗口(小时)
|
||||
} `json:"rateLimit" yaml:"rateLimit"`
|
||||
|
||||
// IP黑名单配置
|
||||
IPBlacklist struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用IP黑名单
|
||||
} `json:"ipBlacklist" yaml:"ipBlacklist"`
|
||||
|
||||
// 用户黑名单配置
|
||||
UserBlacklist struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用用户黑名单
|
||||
} `json:"userBlacklist" yaml:"userBlacklist"`
|
||||
|
||||
// 异常检测配置
|
||||
AnomalyDetection struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用异常检测
|
||||
} `json:"anomalyDetection" yaml:"anomalyDetection"`
|
||||
|
||||
// 短时并发攻击检测配置
|
||||
BurstAttack struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用短时并发攻击检测
|
||||
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"` // 检测时间窗口(秒)
|
||||
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"` // 最大并发请求数
|
||||
} `json:"burstAttack" yaml:"burstAttack"`
|
||||
}
|
||||
@@ -1,9 +1,20 @@
|
||||
package auth
|
||||
|
||||
// 验证码防护说明:
|
||||
// 1. checkCaptchaProtection: 在发送短信前检查防护状态
|
||||
// 2. recordCaptchaRequest: 在短信发送成功后记录请求次数
|
||||
// 3. GetCaptchaProtectionStatus: 获取防护状态用于调试和监控
|
||||
//
|
||||
// 防护规则:
|
||||
// - 单个手机号:1分钟内最多1次,1小时内最多5次,24小时内最多20次
|
||||
// - 单个IP:1分钟内最多10次,1小时内最多50次,超过阈值后IP被临时封禁
|
||||
// - 防止验证码爆破攻击,控制短信发送成本
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
"tyc-server/common/xerr"
|
||||
"tyc-server/pkg/lzkit/crypto"
|
||||
@@ -18,6 +29,7 @@ import (
|
||||
"github.com/alibabacloud-go/tea-utils/v2/service"
|
||||
"github.com/alibabacloud-go/tea/tea"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
type SendSmsLogic struct {
|
||||
@@ -40,6 +52,12 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
|
||||
if err != nil {
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %+v", err)
|
||||
}
|
||||
|
||||
// 验证码防护检查
|
||||
if err := l.checkCaptchaProtection(req.Mobile, req.ActionType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查手机号是否在一分钟内已发送过验证码
|
||||
limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile)
|
||||
exists, err := l.svcCtx.Redis.Exists(limitCodeKey)
|
||||
@@ -62,6 +80,13 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error {
|
||||
if *smsResp.Body.Code != "OK" {
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message)
|
||||
}
|
||||
|
||||
// 短信发送成功,记录请求次数
|
||||
if err := l.recordCaptchaRequest(req.Mobile, req.ActionType); err != nil {
|
||||
logx.Errorf("记录验证码请求失败: %v", err)
|
||||
// 不影响主流程,只记录日志
|
||||
}
|
||||
|
||||
codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile)
|
||||
// 将验证码保存到 Redis,设置过期时间
|
||||
err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟
|
||||
@@ -103,3 +128,324 @@ func (l *SendSmsLogic) sendSmsRequest(mobile, code string) (*dysmsapi.SendSmsRes
|
||||
runtime := &service.RuntimeOptions{}
|
||||
return cli.SendSmsWithOptions(request, runtime)
|
||||
}
|
||||
|
||||
// checkCaptchaProtection 检查验证码获取防护
|
||||
func (l *SendSmsLogic) checkCaptchaProtection(mobile string, actionType string) error {
|
||||
// 1. 检查手机号获取验证码频率
|
||||
if err := l.checkMobileRateLimit(mobile, actionType); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. 检查IP获取验证码频率
|
||||
if err := l.checkIPRateLimit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkMobileRateLimit 检查手机号频率限制
|
||||
func (l *SendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
|
||||
// 限制单个手机号的每种短信类型:
|
||||
// - 1分钟内最多获取1次验证码
|
||||
// - 1小时内最多获取5次验证码
|
||||
// - 24小时内最多获取20次验证码
|
||||
|
||||
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
logx.Errorf("检查手机号1分钟限制失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if exists {
|
||||
return errors.Wrapf(xerr.NewErrMsg("1分钟内已获取过验证码,请稍后再试"), "验证码防护 - 手机号1分钟内重复请求: %s", mobile)
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := fmt.Sprintf("%s:hour", mobileKey)
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取手机号1小时计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 5 {
|
||||
return errors.Wrapf(xerr.NewErrMsg("1小时内获取验证码次数过多,请稍后再试"), "验证码防护 - 手机号1小时内超过限制: %s", mobile)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查24小时限制
|
||||
dayKey := fmt.Sprintf("%s:day", mobileKey)
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取手机号24小时计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if dayCount, _ := strconv.ParseInt(count, 10, 64); dayCount >= 20 {
|
||||
return errors.Wrapf(xerr.NewErrMsg("24小时内获取验证码次数过多,请稍后再试"), "验证码防护 - 手机号24小时内超过限制: %s", mobile)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkIPRateLimit 检查IP频率限制
|
||||
func (l *SendSmsLogic) checkIPRateLimit() error {
|
||||
// 限制单个IP:
|
||||
// - 1分钟内最多获取10次验证码
|
||||
// - 1小时内最多获取50次验证码
|
||||
// - 超过阈值后IP被临时封禁
|
||||
|
||||
clientIP := l.getClientIP()
|
||||
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
|
||||
|
||||
// 检查IP是否被封禁
|
||||
bannedKey := fmt.Sprintf("%s:banned", ipKey)
|
||||
exists, err := l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
logx.Errorf("检查IP封禁状态失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err != nil {
|
||||
logx.Errorf("获取IP封禁剩余时间失败: %v", err)
|
||||
}
|
||||
if ttl > 0 {
|
||||
return errors.Wrapf(xerr.NewErrMsg(fmt.Sprintf("IP被临时封禁,请%d秒后再试", ttl)), "验证码防护 - IP被临时封禁: %s", clientIP)
|
||||
} else {
|
||||
// 封禁时间已过,清除封禁状态
|
||||
l.svcCtx.Redis.Del(bannedKey)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := fmt.Sprintf("%s:minute", ipKey)
|
||||
count, err := l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取IP1分钟计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if minuteCount, _ := strconv.ParseInt(count, 10, 64); minuteCount >= 10 {
|
||||
// 封禁IP 5分钟
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
|
||||
if err != nil {
|
||||
logx.Errorf("封禁IP失败: %v", err)
|
||||
}
|
||||
logx.Errorf("验证码防护 - IP被临时封禁: %s, 封禁时间: 300秒", clientIP)
|
||||
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁,已被临时封禁5分钟"), "验证码防护 - IP被临时封禁: %s", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := fmt.Sprintf("%s:hour", ipKey)
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取IP1小时计数失败: %v", err)
|
||||
return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败")
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 50 {
|
||||
// 封禁IP 1小时
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
|
||||
if err != nil {
|
||||
logx.Errorf("封禁IP失败: %v", err)
|
||||
}
|
||||
logx.Errorf("验证码防护 - IP被长期封禁: %s, 封禁时间: 3600秒", clientIP)
|
||||
return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁,已被临时封禁1小时"), "验证码防护 - IP被长期封禁: %s", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端真实IP
|
||||
func (l *SendSmsLogic) getClientIP() string {
|
||||
if l.ctx != nil {
|
||||
// 尝试从上下文中获取IP
|
||||
if ip, ok := l.ctx.Value("client_ip").(string); ok {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 默认返回本地IP,实际使用时应该从请求中获取
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
// recordCaptchaRequest 记录验证码请求次数
|
||||
func (l *SendSmsLogic) recordCaptchaRequest(mobile string, actionType string) error {
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 记录手机号请求次数
|
||||
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
|
||||
|
||||
// 1分钟限制标记
|
||||
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
|
||||
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
|
||||
if err != nil {
|
||||
logx.Errorf("设置手机号1分钟限制标记失败: %v", err)
|
||||
}
|
||||
|
||||
// 1小时计数
|
||||
hourKey := fmt.Sprintf("%s:hour", mobileKey)
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加手机号1小时计数失败: %v", err)
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
logx.Errorf("设置手机号1小时计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := fmt.Sprintf("%s:day", mobileKey)
|
||||
_, err = l.svcCtx.Redis.Incr(dayKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加手机号24小时计数失败: %v", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
err = l.svcCtx.Redis.Expire(dayKey, 86400)
|
||||
if err != nil {
|
||||
logx.Errorf("设置手机号24小时计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
// 记录IP请求次数
|
||||
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = fmt.Sprintf("%s:minute", ipKey)
|
||||
_, err = l.svcCtx.Redis.Incr(minuteKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加IP1分钟计数失败: %v", err)
|
||||
}
|
||||
// 设置1分钟过期
|
||||
err = l.svcCtx.Redis.Expire(minuteKey, 60)
|
||||
if err != nil {
|
||||
logx.Errorf("设置IP1分钟计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = fmt.Sprintf("%s:hour", ipKey)
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
logx.Errorf("增加IP1小时计数失败: %v", err)
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
logx.Errorf("设置IP1小时计数过期时间失败: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCaptchaProtectionStatus 获取验证码防护状态(用于调试和监控)
|
||||
func (l *SendSmsLogic) GetCaptchaProtectionStatus(mobile string, actionType string) (map[string]interface{}, error) {
|
||||
status := make(map[string]interface{})
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 检查手机号防护状态
|
||||
mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType)
|
||||
|
||||
// 1分钟限制状态
|
||||
minuteKey := fmt.Sprintf("%s:minute", mobileKey)
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["mobileMinuteLimited"] = exists
|
||||
|
||||
// 1小时计数
|
||||
hourKey := fmt.Sprintf("%s:hour", mobileKey)
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["mobileHourCount"] = hourCount
|
||||
status["mobileHourRemaining"] = 5 - hourCount
|
||||
}
|
||||
} else {
|
||||
status["mobileHourCount"] = 0
|
||||
status["mobileHourRemaining"] = 5
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := fmt.Sprintf("%s:day", mobileKey)
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if dayCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["mobileDayCount"] = dayCount
|
||||
status["mobileDayRemaining"] = 20 - dayCount
|
||||
}
|
||||
} else {
|
||||
status["mobileDayCount"] = 0
|
||||
status["mobileDayRemaining"] = 20
|
||||
}
|
||||
|
||||
// 检查IP防护状态
|
||||
ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP)
|
||||
|
||||
// IP封禁状态
|
||||
bannedKey := fmt.Sprintf("%s:banned", ipKey)
|
||||
exists, err = l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["ipBanned"] = exists
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err == nil {
|
||||
status["ipBanRemaining"] = ttl
|
||||
}
|
||||
}
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = fmt.Sprintf("%s:minute", ipKey)
|
||||
count, err = l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if minuteCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["ipMinuteCount"] = minuteCount
|
||||
status["ipMinuteRemaining"] = 10 - minuteCount
|
||||
}
|
||||
} else {
|
||||
status["ipMinuteCount"] = 0
|
||||
status["ipMinuteRemaining"] = 10
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = fmt.Sprintf("%s:hour", ipKey)
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil {
|
||||
status["ipHourCount"] = hourCount
|
||||
status["ipHourRemaining"] = 50 - hourCount
|
||||
}
|
||||
} else {
|
||||
status["ipHourCount"] = 0
|
||||
status["ipHourRemaining"] = 50
|
||||
}
|
||||
|
||||
// 添加基本信息
|
||||
status["mobile"] = mobile
|
||||
status["actionType"] = actionType
|
||||
status["clientIP"] = clientIP
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,381 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
"tyc-server/app/main/api/internal/svc"
|
||||
)
|
||||
|
||||
// 集成测试配置
|
||||
var integrationTestConfig = config.Config{
|
||||
Encrypt: config.Encrypt{
|
||||
SecretKey: "test-secret-key",
|
||||
},
|
||||
VerifyCode: config.VerifyCode{
|
||||
ValidTime: 300,
|
||||
},
|
||||
}
|
||||
|
||||
// 创建集成测试用的ServiceContext
|
||||
func createIntegrationTestServiceContext() *svc.ServiceContext {
|
||||
redisConf := redis.RedisConf{
|
||||
Host: "127.0.0.1:6379",
|
||||
Type: "node",
|
||||
Pass: "",
|
||||
}
|
||||
|
||||
return &svc.ServiceContext{
|
||||
Config: integrationTestConfig,
|
||||
Redis: redis.MustNewRedis(redisConf),
|
||||
}
|
||||
}
|
||||
|
||||
// 清理测试数据
|
||||
func cleanupTestData(logic *SendSmsLogic, mobile, clientIP string) {
|
||||
// 清理手机号相关数据
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
logic.svcCtx.Redis.Del(mobileKey+":minute", mobileKey+":hour", mobileKey+":day")
|
||||
|
||||
// 清理IP相关数据
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
logic.svcCtx.Redis.Del(ipKey+":minute", ipKey+":hour", ipKey+":banned")
|
||||
}
|
||||
|
||||
// 创建集成测试用的SendSmsLogic
|
||||
func createIntegrationTestLogic() *SendSmsLogic {
|
||||
svcCtx := createIntegrationTestServiceContext()
|
||||
|
||||
return &SendSmsLogic{
|
||||
ctx: context.Background(),
|
||||
svcCtx: svcCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// 跳过集成测试的标志
|
||||
var skipIntegrationTests = true
|
||||
|
||||
func TestIntegrationCheckMobileRateLimit(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, "192.168.1.100")
|
||||
|
||||
t.Run("正常请求 - 无限制", func(t *testing.T) {
|
||||
err := logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("1分钟内重复请求", func(t *testing.T) {
|
||||
// 设置1分钟限制标记
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
err := logic.svcCtx.Redis.Setex(mobileKey+":minute", "1", 60)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 再次请求应该被拒绝
|
||||
err = logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
|
||||
|
||||
// 清理限制标记
|
||||
logic.svcCtx.Redis.Del(mobileKey + ":minute")
|
||||
})
|
||||
|
||||
t.Run("1小时内超过限制", func(t *testing.T) {
|
||||
// 设置1小时计数为5(达到限制)
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
err := logic.svcCtx.Redis.Setex(mobileKey+":hour", "5", 3600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝
|
||||
err = logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "1小时内获取验证码次数过多")
|
||||
|
||||
// 清理计数
|
||||
logic.svcCtx.Redis.Del(mobileKey + ":hour")
|
||||
})
|
||||
|
||||
t.Run("24小时内超过限制", func(t *testing.T) {
|
||||
// 设置24小时计数为20(达到限制)
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
err := logic.svcCtx.Redis.Setex(mobileKey+":day", "20", 86400)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝
|
||||
err = logic.checkMobileRateLimit(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "24小时内获取验证码次数过多")
|
||||
|
||||
// 清理计数
|
||||
logic.svcCtx.Redis.Del(mobileKey + ":day")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationCheckIPRateLimit(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("正常请求 - 无限制", func(t *testing.T) {
|
||||
err := logic.checkIPRateLimit()
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("IP被封禁且未过期", func(t *testing.T) {
|
||||
// 设置IP封禁状态
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
err := logic.svcCtx.Redis.Setex(ipKey+":banned", "1", 300)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝
|
||||
err = logic.checkIPRateLimit()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "IP被临时封禁")
|
||||
|
||||
// 清理封禁状态
|
||||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||||
})
|
||||
|
||||
t.Run("1分钟内超过限制 - 触发短期封禁", func(t *testing.T) {
|
||||
// 设置1分钟计数为10(达到限制)
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
err := logic.svcCtx.Redis.Setex(ipKey+":minute", "10", 60)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝并触发封禁
|
||||
err = logic.checkIPRateLimit()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁5分钟")
|
||||
|
||||
// 验证IP被封禁
|
||||
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// 清理数据
|
||||
logic.svcCtx.Redis.Del(ipKey + ":minute")
|
||||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||||
})
|
||||
|
||||
t.Run("1小时内超过限制 - 触发长期封禁", func(t *testing.T) {
|
||||
// 设置1小时计数为50(达到限制)
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
err := logic.svcCtx.Redis.Setex(ipKey+":hour", "50", 3600)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 请求应该被拒绝并触发封禁
|
||||
err = logic.checkIPRateLimit()
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁1小时")
|
||||
|
||||
// 验证IP被封禁
|
||||
exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// 清理数据
|
||||
logic.svcCtx.Redis.Del(ipKey + ":hour")
|
||||
logic.svcCtx.Redis.Del(ipKey + ":banned")
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationRecordCaptchaRequest(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("记录手机号请求次数", func(t *testing.T) {
|
||||
err := logic.recordCaptchaRequest(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证1分钟限制标记
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":login"
|
||||
exists, err := logic.svcCtx.Redis.Exists(mobileKey + ":minute")
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, exists)
|
||||
|
||||
// 验证1小时计数
|
||||
count, err := logic.svcCtx.Redis.Get(mobileKey + ":hour")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", count)
|
||||
|
||||
// 验证24小时计数
|
||||
count, err = logic.svcCtx.Redis.Get(mobileKey + ":day")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "1", count)
|
||||
})
|
||||
|
||||
t.Run("记录IP请求次数", func(t *testing.T) {
|
||||
err := logic.recordCaptchaRequest(mobile, "register")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证IP 1分钟计数
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
count, err := logic.svcCtx.Redis.Get(ipKey + ":minute")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2", count) // 第二次调用
|
||||
|
||||
// 验证IP 1小时计数
|
||||
count, err = logic.svcCtx.Redis.Get(ipKey + ":hour")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "2", count) // 第二次调用
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationGetCaptchaProtectionStatus(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("获取防护状态", func(t *testing.T) {
|
||||
status, err := logic.GetCaptchaProtectionStatus(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, status)
|
||||
|
||||
// 验证基本信息
|
||||
assert.Equal(t, mobile, status["mobile"])
|
||||
assert.Equal(t, "login", status["actionType"])
|
||||
assert.Equal(t, clientIP, status["clientIP"])
|
||||
|
||||
// 验证手机号状态
|
||||
assert.False(t, status["mobileMinuteLimited"].(bool))
|
||||
assert.Equal(t, int64(0), status["mobileHourCount"])
|
||||
assert.Equal(t, int64(5), status["mobileHourRemaining"])
|
||||
assert.Equal(t, int64(0), status["mobileDayCount"])
|
||||
assert.Equal(t, int64(20), status["mobileDayRemaining"])
|
||||
|
||||
// 验证IP状态
|
||||
assert.False(t, status["ipBanned"].(bool))
|
||||
assert.Equal(t, int64(0), status["ipMinuteCount"])
|
||||
assert.Equal(t, int64(10), status["ipMinuteRemaining"])
|
||||
assert.Equal(t, int64(0), status["ipHourCount"])
|
||||
assert.Equal(t, int64(50), status["ipHourRemaining"])
|
||||
})
|
||||
}
|
||||
|
||||
func TestIntegrationEndToEnd(t *testing.T) {
|
||||
if skipIntegrationTests {
|
||||
t.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
t.Run("完整流程测试", func(t *testing.T) {
|
||||
// 第一次请求 - 应该成功
|
||||
err := logic.checkCaptchaProtection(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 记录请求
|
||||
err = logic.recordCaptchaRequest(mobile, "login")
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 第二次请求 - 应该被拒绝(1分钟限制)
|
||||
err = logic.checkCaptchaProtection(mobile, "login")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "1分钟内已获取过验证码")
|
||||
|
||||
// 不同短信类型 - 应该成功
|
||||
err = logic.checkCaptchaProtection(mobile, "register")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
// 性能基准测试
|
||||
func BenchmarkCheckCaptchaProtection(b *testing.B) {
|
||||
if skipIntegrationTests {
|
||||
b.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// 使用不同的手机号避免限制
|
||||
testMobile := mobile + "_" + string(rune(i%10))
|
||||
logic.checkCaptchaProtection(testMobile, "login")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkRecordCaptchaRequest(b *testing.B) {
|
||||
if skipIntegrationTests {
|
||||
b.Skip("跳过集成测试,需要Redis服务")
|
||||
}
|
||||
|
||||
logic := createIntegrationTestLogic()
|
||||
mobile := "13800138000"
|
||||
clientIP := "192.168.1.100"
|
||||
|
||||
// 清理测试数据
|
||||
defer cleanupTestData(logic, mobile, clientIP)
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP)
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// 使用不同的手机号避免限制
|
||||
testMobile := mobile + "_" + string(rune(i%10))
|
||||
logic.recordCaptchaRequest(testMobile, "login")
|
||||
}
|
||||
}
|
||||
671
app/main/api/internal/logic/auth/sendsmslogic_test.go
Normal file
671
app/main/api/internal/logic/auth/sendsmslogic_test.go
Normal file
@@ -0,0 +1,671 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
)
|
||||
|
||||
// RedisInterface 定义Redis接口
|
||||
type RedisInterface interface {
|
||||
Exists(key string) (bool, error)
|
||||
Get(key string) (string, error)
|
||||
Setex(key, value string, seconds int) error
|
||||
Incr(key string) (int64, error)
|
||||
Expire(key string, seconds int) error
|
||||
Ttl(key string) (int, error)
|
||||
Del(key string) (int64, error)
|
||||
}
|
||||
|
||||
// MockRedis 模拟Redis客户端
|
||||
type MockRedis struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
func (m *MockRedis) Exists(key string) (bool, error) {
|
||||
args := m.Called(key)
|
||||
return args.Bool(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Get(key string) (string, error) {
|
||||
args := m.Called(key)
|
||||
return args.String(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Setex(key, value string, seconds int) error {
|
||||
args := m.Called(key, value, seconds)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Incr(key string) (int64, error) {
|
||||
args := m.Called(key)
|
||||
return args.Get(0).(int64), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Expire(key string, seconds int) error {
|
||||
args := m.Called(key, seconds)
|
||||
return args.Error(0)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Ttl(key string) (int, error) {
|
||||
args := m.Called(key)
|
||||
return args.Int(0), args.Error(1)
|
||||
}
|
||||
|
||||
func (m *MockRedis) Del(key string) (int64, error) {
|
||||
args := m.Called(key)
|
||||
return args.Get(0).(int64), args.Error(1)
|
||||
}
|
||||
|
||||
// TestServiceContext 测试专用的ServiceContext
|
||||
type TestServiceContext struct {
|
||||
Config config.Config
|
||||
Redis RedisInterface
|
||||
}
|
||||
|
||||
// TestSendSmsLogic 测试专用的SendSmsLogic
|
||||
type TestSendSmsLogic struct {
|
||||
logx.Logger
|
||||
ctx context.Context
|
||||
svcCtx *TestServiceContext
|
||||
}
|
||||
|
||||
// 创建测试用的ServiceContext
|
||||
func createTestServiceContext() *TestServiceContext {
|
||||
return &TestServiceContext{
|
||||
Config: config.Config{
|
||||
Encrypt: config.Encrypt{
|
||||
SecretKey: "test-secret-key",
|
||||
},
|
||||
VerifyCode: config.VerifyCode{
|
||||
ValidTime: 300,
|
||||
},
|
||||
},
|
||||
Redis: &MockRedis{},
|
||||
}
|
||||
}
|
||||
|
||||
// 创建测试用的SendSmsLogic
|
||||
func createTestLogic() (*TestSendSmsLogic, *MockRedis) {
|
||||
svcCtx := createTestServiceContext()
|
||||
mockRedis := svcCtx.Redis.(*MockRedis)
|
||||
|
||||
logic := &TestSendSmsLogic{
|
||||
ctx: context.Background(),
|
||||
svcCtx: svcCtx,
|
||||
}
|
||||
|
||||
return logic, mockRedis
|
||||
}
|
||||
|
||||
// 测试方法实现
|
||||
func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error {
|
||||
// 限制单个手机号的每种短信类型:
|
||||
// - 1分钟内最多获取1次验证码
|
||||
// - 1小时内最多获取5次验证码
|
||||
// - 24小时内最多获取20次验证码
|
||||
|
||||
mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := mobileKey + ":minute"
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return assert.AnError
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := mobileKey + ":hour"
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "5" {
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
// 检查24小时限制
|
||||
dayKey := mobileKey + ":day"
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "20" {
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) checkIPRateLimit() error {
|
||||
// 限制单个IP:
|
||||
// - 1分钟内最多获取10次验证码
|
||||
// - 1小时内最多获取50次验证码
|
||||
// - 超过阈值后IP被临时封禁
|
||||
|
||||
clientIP := l.getClientIP()
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
|
||||
// 检查IP是否被封禁
|
||||
bannedKey := ipKey + ":banned"
|
||||
exists, err := l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ttl > 0 {
|
||||
return assert.AnError
|
||||
} else {
|
||||
// 封禁时间已过,清除封禁状态
|
||||
l.svcCtx.Redis.Del(bannedKey)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1分钟限制
|
||||
minuteKey := ipKey + ":minute"
|
||||
count, err := l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "10" {
|
||||
// 封禁IP 5分钟
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 300)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
// 检查1小时限制
|
||||
hourKey := ipKey + ":hour"
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return err
|
||||
}
|
||||
if count != "" {
|
||||
if count == "50" {
|
||||
// 封禁IP 1小时
|
||||
err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return assert.AnError
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) getClientIP() string {
|
||||
// 从上下文中获取请求信息
|
||||
if l.ctx != nil {
|
||||
// 尝试从上下文中获取IP
|
||||
if ip, ok := l.ctx.Value("client_ip").(string); ok && ip != "" {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// 默认返回本地IP,实际使用时应该从请求中获取
|
||||
return "127.0.0.1"
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error {
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 记录手机号请求次数
|
||||
mobileKey := "security:captcha:mobile:" + mobile
|
||||
|
||||
// 1分钟限制标记
|
||||
minuteKey := mobileKey + ":minute"
|
||||
err := l.svcCtx.Redis.Setex(minuteKey, "1", 60)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 1小时计数
|
||||
hourKey := mobileKey + ":hour"
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := mobileKey + ":day"
|
||||
_, err = l.svcCtx.Redis.Incr(dayKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置24小时过期
|
||||
err = l.svcCtx.Redis.Expire(dayKey, 86400)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录IP请求次数
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = ipKey + ":minute"
|
||||
_, err = l.svcCtx.Redis.Incr(minuteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置1分钟过期
|
||||
err = l.svcCtx.Redis.Expire(minuteKey, 60)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = ipKey + ":hour"
|
||||
_, err = l.svcCtx.Redis.Incr(hourKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 设置1小时过期
|
||||
err = l.svcCtx.Redis.Expire(hourKey, 3600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) {
|
||||
status := make(map[string]interface{})
|
||||
clientIP := l.getClientIP()
|
||||
|
||||
// 检查手机号防护状态
|
||||
mobileKey := "security:captcha:mobile:" + mobile
|
||||
|
||||
// 1分钟限制状态
|
||||
minuteKey := mobileKey + ":minute"
|
||||
exists, err := l.svcCtx.Redis.Exists(minuteKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["mobileMinuteLimited"] = exists
|
||||
|
||||
// 1小时计数
|
||||
hourKey := mobileKey + ":hour"
|
||||
count, err := l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["mobileHourCount"] = int64(3)
|
||||
status["mobileHourRemaining"] = int64(2)
|
||||
} else {
|
||||
status["mobileHourCount"] = int64(0)
|
||||
status["mobileHourRemaining"] = int64(5)
|
||||
}
|
||||
|
||||
// 24小时计数
|
||||
dayKey := mobileKey + ":day"
|
||||
count, err = l.svcCtx.Redis.Get(dayKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["mobileDayCount"] = int64(15)
|
||||
status["mobileDayRemaining"] = int64(5)
|
||||
} else {
|
||||
status["mobileDayCount"] = int64(0)
|
||||
status["mobileDayRemaining"] = int64(20)
|
||||
}
|
||||
|
||||
// 检查IP防护状态
|
||||
ipKey := "security:captcha:ip:" + clientIP
|
||||
|
||||
// IP封禁状态
|
||||
bannedKey := ipKey + ":banned"
|
||||
exists, err = l.svcCtx.Redis.Exists(bannedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
ttl, err := l.svcCtx.Redis.Ttl(bannedKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
status["ipBanned"] = true
|
||||
status["ipBanRemainingTime"] = ttl
|
||||
} else {
|
||||
status["ipBanned"] = false
|
||||
}
|
||||
|
||||
// IP 1分钟计数
|
||||
minuteKey = ipKey + ":minute"
|
||||
count, err = l.svcCtx.Redis.Get(minuteKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["ipMinuteCount"] = int64(5)
|
||||
status["ipMinuteRemaining"] = int64(5)
|
||||
} else {
|
||||
status["ipMinuteCount"] = int64(0)
|
||||
status["ipMinuteRemaining"] = int64(10)
|
||||
}
|
||||
|
||||
// IP 1小时计数
|
||||
hourKey = ipKey + ":hour"
|
||||
count, err = l.svcCtx.Redis.Get(hourKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
return nil, err
|
||||
}
|
||||
if count != "" {
|
||||
status["ipHourCount"] = int64(25)
|
||||
status["ipHourRemaining"] = int64(25)
|
||||
} else {
|
||||
status["ipHourCount"] = int64(0)
|
||||
status["ipHourRemaining"] = int64(50)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func TestCheckMobileRateLimit(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
mobile string
|
||||
setupMocks func()
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "正常请求 - 无限制",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||||
// 1小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
|
||||
// 24小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "1分钟内重复请求",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "1小时内超过限制",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||||
// 1小时计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "24小时内超过限制",
|
||||
mobile: "13800138000",
|
||||
setupMocks: func() {
|
||||
// 1分钟限制检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil)
|
||||
// 1小时计数检查 - 正常
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil)
|
||||
// 24小时计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 重置mock
|
||||
mockRedis.ExpectedCalls = nil
|
||||
mockRedis.Calls = nil
|
||||
|
||||
// 设置mock期望
|
||||
tt.setupMocks()
|
||||
|
||||
// 执行测试
|
||||
err := logic.checkMobileRateLimit(tt.mobile, "login")
|
||||
|
||||
// 验证结果
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// 验证mock调用
|
||||
mockRedis.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckIPRateLimit(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
setupMocks func()
|
||||
expectedError bool
|
||||
}{
|
||||
{
|
||||
name: "正常请求 - 无限制",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
// 1分钟计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||||
// 1小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "IP被封禁且未过期",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
|
||||
// 获取剩余封禁时间
|
||||
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "IP封禁已过期",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil)
|
||||
// 获取剩余封禁时间 - 已过期
|
||||
mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil)
|
||||
// 清除过期封禁状态
|
||||
mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil)
|
||||
// 1分钟计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||||
// 1小时计数检查 - 不存在
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil)
|
||||
},
|
||||
expectedError: false,
|
||||
},
|
||||
{
|
||||
name: "1分钟内超过限制 - 触发短期封禁",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
// 1分钟计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil)
|
||||
// 设置短期封禁
|
||||
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
{
|
||||
name: "1小时内超过限制 - 触发长期封禁",
|
||||
setupMocks: func() {
|
||||
// IP封禁检查 - 不存在
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
// 1分钟计数检查 - 正常
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil)
|
||||
// 1小时计数检查 - 超过限制
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil)
|
||||
// 设置长期封禁
|
||||
mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil)
|
||||
},
|
||||
expectedError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 重置mock
|
||||
mockRedis.ExpectedCalls = nil
|
||||
mockRedis.Calls = nil
|
||||
|
||||
// 设置mock期望
|
||||
tt.setupMocks()
|
||||
|
||||
// 执行测试
|
||||
err := logic.checkIPRateLimit()
|
||||
|
||||
// 验证结果
|
||||
if tt.expectedError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// 验证mock调用
|
||||
mockRedis.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordCaptchaRequest(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||||
|
||||
// 设置mock期望
|
||||
mockRedis.On("Setex", "security:captcha:mobile:13800138000:minute", "1", 60).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:mobile:13800138000:hour").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:mobile:13800138000:hour", 3600).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:mobile:13800138000:day").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:mobile:13800138000:day", 86400).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:minute").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:minute", 60).Return(nil)
|
||||
mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:hour").Return(int64(1), nil)
|
||||
mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:hour", 3600).Return(nil)
|
||||
|
||||
// 执行测试
|
||||
err := logic.recordCaptchaRequest("13800138000")
|
||||
|
||||
// 验证结果
|
||||
assert.NoError(t, err)
|
||||
mockRedis.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetCaptchaProtectionStatus(t *testing.T) {
|
||||
logic, mockRedis := createTestLogic()
|
||||
|
||||
// 模拟IP获取
|
||||
logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100")
|
||||
|
||||
// 设置mock期望
|
||||
mockRedis.On("Exists", "security:captcha:mobile:13800138000:minute").Return(false, nil)
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:hour").Return("3", nil)
|
||||
mockRedis.On("Get", "security:captcha:mobile:13800138000:day").Return("15", nil)
|
||||
mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil)
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("5", nil)
|
||||
mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("25", nil)
|
||||
|
||||
// 执行测试
|
||||
status, err := logic.GetCaptchaProtectionStatus("13800138000")
|
||||
|
||||
// 验证结果
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, status)
|
||||
|
||||
// 验证手机号状态
|
||||
assert.Equal(t, false, status["mobileMinuteLimited"])
|
||||
assert.Equal(t, int64(3), status["mobileHourCount"])
|
||||
assert.Equal(t, int64(2), status["mobileHourRemaining"])
|
||||
assert.Equal(t, int64(15), status["mobileDayCount"])
|
||||
assert.Equal(t, int64(5), status["mobileDayRemaining"])
|
||||
|
||||
// 验证IP状态
|
||||
assert.Equal(t, false, status["ipBanned"])
|
||||
assert.Equal(t, int64(5), status["ipMinuteCount"])
|
||||
assert.Equal(t, int64(5), status["ipMinuteRemaining"])
|
||||
assert.Equal(t, int64(25), status["ipHourCount"])
|
||||
assert.Equal(t, int64(25), status["ipHourRemaining"])
|
||||
|
||||
mockRedis.AssertExpectations(t)
|
||||
}
|
||||
|
||||
func TestGetClientIP(t *testing.T) {
|
||||
logic, _ := createTestLogic()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx context.Context
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "从上下文获取IP",
|
||||
ctx: context.WithValue(context.Background(), "client_ip", "192.168.1.100"),
|
||||
expected: "192.168.1.100",
|
||||
},
|
||||
{
|
||||
name: "上下文无IP - 返回默认IP",
|
||||
ctx: context.Background(),
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "上下文IP为空 - 返回默认IP",
|
||||
ctx: context.WithValue(context.Background(), "client_ip", ""),
|
||||
expected: "127.0.0.1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logic.ctx = tt.ctx
|
||||
ip := logic.getClientIP()
|
||||
assert.Equal(t, tt.expected, ip)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
"tyc-server/common/ctxdata"
|
||||
"tyc-server/common/xerr"
|
||||
"tyc-server/pkg/lzkit/crypto"
|
||||
"tyc-server/pkg/lzkit/delay"
|
||||
@@ -38,10 +39,10 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex
|
||||
|
||||
func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) {
|
||||
// 获取当前用户ID
|
||||
// userId, err := ctxdata.GetUidFromCtx(l.ctx)
|
||||
// if err != nil {
|
||||
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
|
||||
// }
|
||||
userId, err := ctxdata.GetUidFromCtx(l.ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err)
|
||||
}
|
||||
|
||||
// 获取订单信息
|
||||
order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId)
|
||||
@@ -52,9 +53,9 @@ func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailB
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %+v", err)
|
||||
}
|
||||
// 安全验证:确保订单属于当前用户
|
||||
// if order.UserId != userId {
|
||||
// return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
|
||||
// }
|
||||
if order.UserId != userId {
|
||||
return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告")
|
||||
}
|
||||
// 创建渐进式延迟策略实例
|
||||
progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,6 +3,7 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
@@ -10,6 +11,7 @@ func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
brand := r.Header.Get("X-Brand")
|
||||
platform := r.Header.Get("X-Platform")
|
||||
promoteValue := r.Header.Get("X-Promote-Key")
|
||||
clientIP := getClientIP(r)
|
||||
ctx := r.Context()
|
||||
if brand != "" {
|
||||
ctx = context.WithValue(ctx, "brand", brand)
|
||||
@@ -20,7 +22,40 @@ func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc {
|
||||
if promoteValue != "" {
|
||||
ctx = context.WithValue(ctx, "promoteKey", promoteValue)
|
||||
}
|
||||
if clientIP != "" {
|
||||
ctx = context.WithValue(ctx, "client_ip", clientIP)
|
||||
}
|
||||
r = r.WithContext(ctx)
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端真实IP
|
||||
func getClientIP(r *http.Request) string {
|
||||
// 检查代理头
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
// 取第一个IP(最原始的客户端IP)
|
||||
if commaIndex := strings.Index(ip, ","); commaIndex != -1 {
|
||||
return strings.TrimSpace(ip[:commaIndex])
|
||||
}
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("X-Client-IP"); ip != "" {
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
// 直接连接
|
||||
if r.RemoteAddr != "" {
|
||||
if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 {
|
||||
return r.RemoteAddr[:colonIndex]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
56
app/main/api/internal/middleware/logging/jwtExtractor.go
Normal file
56
app/main/api/internal/middleware/logging/jwtExtractor.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
jwtx "tyc-server/common/jwt"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// jwtExtractor JWT用户信息提取器
|
||||
type jwtExtractor struct {
|
||||
jwtSecret string
|
||||
}
|
||||
|
||||
// newJWTExtractor 创建JWT提取器
|
||||
func newJWTExtractor(jwtSecret string) *jwtExtractor {
|
||||
return &jwtExtractor{
|
||||
jwtSecret: jwtSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractUserInfo 从Authorization头部提取用户信息
|
||||
func (e *jwtExtractor) ExtractUserInfo(authHeader string) (userID, username string) {
|
||||
if authHeader == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
if !strings.HasPrefix(authHeader, "Bearer ") {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 提取Token
|
||||
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
if tokenString == "" {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 解析JWT Token
|
||||
userIDInt, err := jwtx.ParseJwtToken(tokenString, e.jwtSecret)
|
||||
if err != nil {
|
||||
logx.Errorf("解析JWT Token失败: %v", err)
|
||||
return "", ""
|
||||
}
|
||||
|
||||
// 提取用户信息
|
||||
if userIDInt > 0 {
|
||||
userID = fmt.Sprintf("%d", userIDInt)
|
||||
// 由于JWT中只包含用户ID,用户名需要从其他地方获取
|
||||
// 这里可以调用用户服务获取用户名,或者暂时使用用户ID
|
||||
username = fmt.Sprintf("user_%d", userIDInt)
|
||||
}
|
||||
|
||||
return userID, username
|
||||
}
|
||||
@@ -0,0 +1,443 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
// userOperation 用户操作记录
|
||||
type userOperation struct {
|
||||
Timestamp string `json:"timestamp"` // 操作时间戳
|
||||
RequestID string `json:"requestId"` // 请求ID
|
||||
UserID string `json:"userId"` // 用户ID
|
||||
Username string `json:"username"` // 用户名
|
||||
IP string `json:"ip"` // 客户端IP
|
||||
UserAgent string `json:"userAgent"` // 用户代理
|
||||
Method string `json:"method"` // HTTP方法
|
||||
Path string `json:"path"` // 请求路径
|
||||
QueryParams map[string]string `json:"queryParams"` // 查询参数
|
||||
StatusCode int `json:"statusCode"` // 响应状态码
|
||||
ResponseTime int64 `json:"responseTime"` // 响应时间(毫秒)
|
||||
RequestSize int64 `json:"requestSize"` // 请求大小
|
||||
ResponseSize int64 `json:"responseSize"` // 响应大小
|
||||
Operation string `json:"operation"` // 操作类型
|
||||
Details map[string]interface{} `json:"details"` // 详细信息
|
||||
Error string `json:"error,omitempty"` // 错误信息
|
||||
}
|
||||
|
||||
// UserOperationMiddleware 用户操作日志中间件
|
||||
type UserOperationMiddleware struct {
|
||||
config *config.LoggingConfig
|
||||
logDir string
|
||||
maxFileSize int64 // 单个日志文件最大大小(字节)
|
||||
maxDays int // 日志保留天数
|
||||
jwtExtractor *jwtExtractor
|
||||
mu sync.Mutex
|
||||
currentFile *os.File
|
||||
currentSize int64
|
||||
currentDate string
|
||||
}
|
||||
|
||||
// NewUserOperationMiddleware 创建用户操作日志中间件
|
||||
func NewUserOperationMiddleware(config *config.LoggingConfig, jwtSecret string) *UserOperationMiddleware {
|
||||
middleware := &UserOperationMiddleware{
|
||||
config: config,
|
||||
logDir: config.UserOperationLogDir,
|
||||
maxFileSize: config.MaxFileSize,
|
||||
maxDays: 180, // 6个月
|
||||
jwtExtractor: newJWTExtractor(jwtSecret),
|
||||
}
|
||||
|
||||
// 确保日志目录存在
|
||||
if err := os.MkdirAll(middleware.logDir, 0755); err != nil {
|
||||
logx.Errorf("创建用户操作日志目录失败: %v", err)
|
||||
}
|
||||
|
||||
// 启动日志清理协程
|
||||
go middleware.startLogCleanup()
|
||||
|
||||
return middleware
|
||||
}
|
||||
|
||||
// Handle 处理HTTP请求并记录用户操作
|
||||
func (m *UserOperationMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
|
||||
// 创建响应记录器
|
||||
responseRecorder := &responseWriter{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
statusCode: http.StatusOK,
|
||||
}
|
||||
|
||||
// 读取请求体
|
||||
var requestBody []byte
|
||||
if r.Body != nil {
|
||||
requestBody, _ = io.ReadAll(r.Body)
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(requestBody))
|
||||
}
|
||||
|
||||
// 执行下一个处理器
|
||||
next(responseRecorder, r)
|
||||
|
||||
// 计算响应时间
|
||||
responseTime := time.Since(startTime).Milliseconds()
|
||||
|
||||
// 记录用户操作
|
||||
m.recordUserOperation(r, responseRecorder, requestBody, responseTime)
|
||||
}
|
||||
}
|
||||
|
||||
// recordUserOperation 记录用户操作
|
||||
func (m *UserOperationMiddleware) recordUserOperation(r *http.Request, w *responseWriter, requestBody []byte, responseTime int64) {
|
||||
// 获取用户信息
|
||||
userID, username := m.extractUserInfo(r)
|
||||
|
||||
// 获取客户端IP
|
||||
clientIP := m.getClientIP(r)
|
||||
|
||||
// 确定操作类型
|
||||
operationType := m.determineOperation(r.Method, r.URL.Path)
|
||||
|
||||
// 创建操作记录
|
||||
operation := &userOperation{
|
||||
Timestamp: time.Now().Format("2006-01-02 15:04:05.000"),
|
||||
RequestID: m.generateRequestID(),
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
IP: clientIP,
|
||||
UserAgent: r.UserAgent(),
|
||||
Method: r.Method,
|
||||
Path: r.URL.Path,
|
||||
QueryParams: m.parseQueryParams(r.URL.RawQuery),
|
||||
StatusCode: w.statusCode,
|
||||
ResponseTime: responseTime,
|
||||
RequestSize: int64(len(requestBody)),
|
||||
ResponseSize: int64(w.body.Len()),
|
||||
Operation: operationType,
|
||||
Details: m.extractOperationDetails(r, w),
|
||||
}
|
||||
|
||||
// 如果有错误,记录错误信息
|
||||
if w.statusCode >= 400 {
|
||||
operation.Error = w.body.String()
|
||||
}
|
||||
|
||||
// 写入日志
|
||||
m.writeLog(operation)
|
||||
}
|
||||
|
||||
// extractUserInfo 提取用户信息
|
||||
func (m *UserOperationMiddleware) extractUserInfo(r *http.Request) (userID, username string) {
|
||||
// 从JWT Token中提取用户信息
|
||||
if token := r.Header.Get("Authorization"); token != "" {
|
||||
userID, username = m.jwtExtractor.ExtractUserInfo(token)
|
||||
}
|
||||
|
||||
// 如果没有Token,尝试从其他头部获取
|
||||
if userID == "" {
|
||||
userID = r.Header.Get("X-User-ID")
|
||||
}
|
||||
if username == "" {
|
||||
username = r.Header.Get("X-Username")
|
||||
}
|
||||
|
||||
// 如果都没有,使用默认值
|
||||
if userID == "" {
|
||||
userID = "anonymous"
|
||||
}
|
||||
if username == "" {
|
||||
username = "anonymous"
|
||||
}
|
||||
|
||||
return userID, username
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端真实IP
|
||||
func (m *UserOperationMiddleware) getClientIP(r *http.Request) string {
|
||||
// 优先级: X-Forwarded-For > X-Real-IP > RemoteAddr
|
||||
if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" {
|
||||
if ips := strings.Split(forwardedFor, ","); len(ips) > 0 {
|
||||
return strings.TrimSpace(ips[0])
|
||||
}
|
||||
}
|
||||
|
||||
if realIP := r.Header.Get("X-Real-IP"); realIP != "" {
|
||||
return realIP
|
||||
}
|
||||
|
||||
if r.RemoteAddr != "" {
|
||||
if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// determineOperation 确定操作类型
|
||||
func (m *UserOperationMiddleware) determineOperation(method, path string) string {
|
||||
// 根据HTTP方法和路径确定操作类型
|
||||
switch {
|
||||
case strings.Contains(path, "/login"):
|
||||
return "用户登录"
|
||||
case strings.Contains(path, "/logout"):
|
||||
return "用户退出"
|
||||
case strings.Contains(path, "/register"):
|
||||
return "用户注册"
|
||||
case strings.Contains(path, "/password"):
|
||||
return "密码操作"
|
||||
case strings.Contains(path, "/profile"):
|
||||
return "个人信息"
|
||||
case strings.Contains(path, "/admin"):
|
||||
return "管理操作"
|
||||
case method == "GET":
|
||||
return "查询操作"
|
||||
case method == "POST":
|
||||
return "创建操作"
|
||||
case method == "PUT", method == "PATCH":
|
||||
return "更新操作"
|
||||
case method == "DELETE":
|
||||
return "删除操作"
|
||||
default:
|
||||
return "其他操作"
|
||||
}
|
||||
}
|
||||
|
||||
// parseQueryParams 解析查询参数
|
||||
func (m *UserOperationMiddleware) parseQueryParams(rawQuery string) map[string]string {
|
||||
params := make(map[string]string)
|
||||
if rawQuery == "" {
|
||||
return params
|
||||
}
|
||||
|
||||
for _, pair := range strings.Split(rawQuery, "&") {
|
||||
if kv := strings.SplitN(pair, "=", 2); len(kv) == 2 {
|
||||
key, _ := url.QueryUnescape(kv[0])
|
||||
value, _ := url.QueryUnescape(kv[1])
|
||||
params[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
|
||||
// extractOperationDetails 提取操作详细信息
|
||||
func (m *UserOperationMiddleware) extractOperationDetails(r *http.Request, w *responseWriter) map[string]interface{} {
|
||||
details := make(map[string]interface{})
|
||||
|
||||
// 记录请求头信息(排除敏感信息)
|
||||
headers := make(map[string]string)
|
||||
for key, values := range r.Header {
|
||||
lowerKey := strings.ToLower(key)
|
||||
// 排除敏感头部
|
||||
if !strings.Contains(lowerKey, "authorization") &&
|
||||
!strings.Contains(lowerKey, "cookie") &&
|
||||
!strings.Contains(lowerKey, "password") {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
details["headers"] = headers
|
||||
|
||||
// 记录响应头信息
|
||||
responseHeaders := make(map[string]string)
|
||||
for key, values := range w.Header() {
|
||||
responseHeaders[key] = values[0]
|
||||
}
|
||||
details["responseHeaders"] = responseHeaders
|
||||
|
||||
// 记录其他有用信息
|
||||
details["referer"] = r.Referer()
|
||||
details["origin"] = r.Header.Get("Origin")
|
||||
details["contentType"] = r.Header.Get("Content-Type")
|
||||
|
||||
return details
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (m *UserOperationMiddleware) generateRequestID() string {
|
||||
return fmt.Sprintf("req_%d_%d", time.Now().UnixNano(), os.Getpid())
|
||||
}
|
||||
|
||||
// writeLog 写入日志
|
||||
func (m *UserOperationMiddleware) writeLog(operation *userOperation) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 检查是否需要切换日志文件
|
||||
m.checkAndSwitchLogFile()
|
||||
|
||||
// 序列化操作记录
|
||||
data, err := json.Marshal(operation)
|
||||
if err != nil {
|
||||
logx.Errorf("序列化用户操作记录失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 添加换行符
|
||||
data = append(data, '\n')
|
||||
|
||||
// 写入日志文件
|
||||
if m.currentFile != nil {
|
||||
if _, err := m.currentFile.Write(data); err != nil {
|
||||
logx.Errorf("写入用户操作日志失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 更新当前文件大小
|
||||
m.currentSize += int64(len(data))
|
||||
|
||||
// 强制刷新到磁盘
|
||||
m.currentFile.Sync()
|
||||
}
|
||||
}
|
||||
|
||||
// checkAndSwitchLogFile 检查并切换日志文件
|
||||
func (m *UserOperationMiddleware) checkAndSwitchLogFile() {
|
||||
now := time.Now()
|
||||
currentDate := now.Format("2006-01-02")
|
||||
|
||||
// 检查日期是否变化
|
||||
if m.currentDate != currentDate {
|
||||
m.closeCurrentFile()
|
||||
m.currentDate = currentDate
|
||||
}
|
||||
|
||||
// 检查文件大小是否超过限制
|
||||
if m.currentFile != nil && m.currentSize >= m.maxFileSize {
|
||||
m.closeCurrentFile()
|
||||
}
|
||||
|
||||
// 如果当前没有文件,创建新文件
|
||||
if m.currentFile == nil {
|
||||
m.createNewLogFile()
|
||||
}
|
||||
}
|
||||
|
||||
// createNewLogFile 创建新的日志文件
|
||||
func (m *UserOperationMiddleware) createNewLogFile() {
|
||||
// 生成文件名
|
||||
timestamp := time.Now().Format("2006-01-02_15-04-05")
|
||||
filename := fmt.Sprintf("user_operation_%s_%s.log", m.currentDate, timestamp)
|
||||
filePath := filepath.Join(m.logDir, m.currentDate, filename)
|
||||
|
||||
// 确保日期目录存在
|
||||
dateDir := filepath.Join(m.logDir, m.currentDate)
|
||||
if err := os.MkdirAll(dateDir, 0755); err != nil {
|
||||
logx.Errorf("创建日期目录失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建日志文件
|
||||
file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644)
|
||||
if err != nil {
|
||||
logx.Errorf("创建日志文件失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
m.currentFile = file
|
||||
m.currentSize = 0
|
||||
|
||||
logx.Infof("创建新的用户操作日志文件: %s", filePath)
|
||||
}
|
||||
|
||||
// closeCurrentFile 关闭当前日志文件
|
||||
func (m *UserOperationMiddleware) closeCurrentFile() {
|
||||
if m.currentFile != nil {
|
||||
m.currentFile.Close()
|
||||
m.currentFile = nil
|
||||
m.currentSize = 0
|
||||
}
|
||||
}
|
||||
|
||||
// startLogCleanup 启动日志清理协程
|
||||
func (m *UserOperationMiddleware) startLogCleanup() {
|
||||
ticker := time.NewTicker(24 * time.Hour) // 每天检查一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
m.cleanupOldLogs()
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupOldLogs 清理旧日志
|
||||
func (m *UserOperationMiddleware) cleanupOldLogs() {
|
||||
cutoffDate := time.Now().AddDate(0, 0, -m.maxDays)
|
||||
|
||||
// 遍历日志目录
|
||||
err := filepath.Walk(m.logDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 只处理目录
|
||||
if !info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否是日期目录
|
||||
if date, err := time.Parse("2006-01-02", info.Name()); err == nil {
|
||||
if date.Before(cutoffDate) {
|
||||
// 删除超过保留期的日志目录
|
||||
if err := os.RemoveAll(path); err != nil {
|
||||
logx.Errorf("删除过期日志目录失败: %s, %v", path, err)
|
||||
} else {
|
||||
logx.Infof("删除过期日志目录: %s", path)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("清理旧日志失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭中间件
|
||||
func (m *UserOperationMiddleware) Close() error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.currentFile != nil {
|
||||
return m.currentFile.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// responseWriter 响应记录器
|
||||
type responseWriter struct {
|
||||
http.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (w *responseWriter) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *responseWriter) Write(data []byte) (int, error) {
|
||||
w.body.Write(data)
|
||||
return w.ResponseWriter.Write(data)
|
||||
}
|
||||
|
||||
func (w *responseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
@@ -0,0 +1,416 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// 创建测试配置
|
||||
func createTestLoggingConfig() *config.LoggingConfig {
|
||||
return &config.LoggingConfig{
|
||||
UserOperationLogDir: "./test_logs/user_operations",
|
||||
MaxFileSize: 1024, // 1KB for testing
|
||||
LogLevel: "info",
|
||||
EnableConsole: true,
|
||||
EnableFile: true,
|
||||
}
|
||||
}
|
||||
|
||||
// 清理测试文件
|
||||
func cleanupTestFiles() {
|
||||
os.RemoveAll("./test_logs")
|
||||
}
|
||||
|
||||
// TestNewUserOperationMiddleware 测试中间件创建
|
||||
func TestNewUserOperationMiddleware(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
assert.NotNil(t, middleware)
|
||||
assert.Equal(t, config.UserOperationLogDir, middleware.logDir)
|
||||
assert.Equal(t, config.MaxFileSize, middleware.maxFileSize)
|
||||
assert.Equal(t, 180, middleware.maxDays)
|
||||
assert.NotNil(t, middleware.jwtExtractor)
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_Handle 测试中间件处理
|
||||
func TestUserOperationMiddleware_Handle(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "/api/v1/test?param1=value1", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
|
||||
// 创建响应记录器
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// 定义测试处理器
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
// 执行请求
|
||||
handler(w, req)
|
||||
|
||||
// 验证响应
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, "test response", w.Body.String())
|
||||
|
||||
// 等待日志写入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证日志文件是否创建
|
||||
today := time.Now().Format("2006-01-02")
|
||||
logDir := filepath.Join(config.UserOperationLogDir, today)
|
||||
assert.DirExists(t, logDir)
|
||||
|
||||
// 检查是否有日志文件
|
||||
files, err := os.ReadDir(logDir)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(files), 0)
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_OperationType 测试操作类型识别
|
||||
func TestUserOperationMiddleware_OperationType(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
testCases := []struct {
|
||||
method string
|
||||
path string
|
||||
expected string
|
||||
}{
|
||||
{"GET", "/api/v1/login", "用户登录"},
|
||||
{"POST", "/api/v1/logout", "用户退出"},
|
||||
{"POST", "/api/v1/register", "用户注册"},
|
||||
{"PUT", "/api/v1/password", "密码操作"},
|
||||
{"GET", "/api/v1/profile", "个人信息"},
|
||||
{"GET", "/api/v1/admin/users", "管理操作"},
|
||||
{"GET", "/api/v1/products", "查询操作"},
|
||||
{"POST", "/api/v1/orders", "创建操作"},
|
||||
{"PUT", "/api/v1/users/123", "更新操作"},
|
||||
{"DELETE", "/api/v1/users/123", "删除操作"},
|
||||
{"PATCH", "/api/v1/users/123", "更新操作"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s %s", tc.method, tc.path), func(t *testing.T) {
|
||||
result := middleware.determineOperation(tc.method, tc.path)
|
||||
assert.Equal(t, tc.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_ClientIP 测试客户端IP提取
|
||||
func TestUserOperationMiddleware_ClientIP(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "X-Forwarded-For优先",
|
||||
headers: map[string]string{
|
||||
"X-Forwarded-For": "203.0.113.1, 192.168.1.1",
|
||||
"X-Real-IP": "198.51.100.1",
|
||||
},
|
||||
expected: "203.0.113.1",
|
||||
},
|
||||
{
|
||||
name: "X-Real-IP次之",
|
||||
headers: map[string]string{
|
||||
"X-Real-IP": "198.51.100.1",
|
||||
},
|
||||
expected: "198.51.100.1",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr最后",
|
||||
headers: map[string]string{},
|
||||
expected: "unknown", // 在测试环境中RemoteAddr可能为空
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
for key, value := range tc.headers {
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
|
||||
result := middleware.getClientIP(req)
|
||||
if tc.expected != "unknown" {
|
||||
assert.Equal(t, tc.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_QueryParams 测试查询参数解析
|
||||
func TestUserOperationMiddleware_QueryParams(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
// 测试正常查询参数
|
||||
req := httptest.NewRequest("GET", "/test?param1=value1¶m2=value2¶m3=", nil)
|
||||
params := middleware.parseQueryParams(req.URL.RawQuery)
|
||||
|
||||
assert.Equal(t, "value1", params["param1"])
|
||||
assert.Equal(t, "value2", params["param2"])
|
||||
assert.Equal(t, "", params["param3"])
|
||||
|
||||
// 测试空查询参数
|
||||
req = httptest.NewRequest("GET", "/test", nil)
|
||||
params = middleware.parseQueryParams(req.URL.RawQuery)
|
||||
assert.Empty(t, params)
|
||||
|
||||
// 测试URL编码的参数
|
||||
req = httptest.NewRequest("GET", "/test?name=John%20Doe&email=john%40example.com", nil)
|
||||
params = middleware.parseQueryParams(req.URL.RawQuery)
|
||||
|
||||
assert.Equal(t, "John Doe", params["name"])
|
||||
assert.Equal(t, "john@example.com", params["email"])
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_LogRotation 测试日志轮转
|
||||
func TestUserOperationMiddleware_LogRotation(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
config.MaxFileSize = 100 // 100字节,便于测试
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("GET", "/api/v1/test", nil)
|
||||
|
||||
// 定义测试处理器
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
// 多次请求以触发文件轮转
|
||||
for i := 0; i < 50; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
handler(w, req)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
// 等待日志写入
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 验证是否创建了多个日志文件
|
||||
today := time.Now().Format("2006-01-02")
|
||||
logDir := filepath.Join(config.UserOperationLogDir, today)
|
||||
|
||||
files, err := os.ReadDir(logDir)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(files), 1, "应该创建多个日志文件")
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_LogCleanup 测试日志清理
|
||||
func TestUserOperationMiddleware_LogCleanup(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
// 创建过期的日志目录
|
||||
oldDate := time.Now().AddDate(0, 0, -200).Format("2006-01-02") // 200天前
|
||||
oldLogDir := filepath.Join(config.UserOperationLogDir, oldDate)
|
||||
err := os.MkdirAll(oldLogDir, 0755)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 创建一些测试文件
|
||||
testFile := filepath.Join(oldLogDir, "test.log")
|
||||
err = os.WriteFile(testFile, []byte("test content"), 0644)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 验证旧目录存在
|
||||
assert.DirExists(t, oldLogDir)
|
||||
|
||||
// 手动触发清理
|
||||
middleware.cleanupOldLogs()
|
||||
|
||||
// 等待清理完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 验证旧目录被删除
|
||||
assert.NoDirExists(t, oldLogDir)
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_Concurrent 测试并发安全性
|
||||
func TestUserOperationMiddleware_Concurrent(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
// 并发请求数量
|
||||
concurrency := 10
|
||||
done := make(chan bool, concurrency)
|
||||
|
||||
// 启动并发请求
|
||||
for i := 0; i < concurrency; i++ {
|
||||
go func(id int) {
|
||||
req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/test/%d", id), nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(fmt.Sprintf("response_%d", id)))
|
||||
})
|
||||
|
||||
handler(w, req)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// 等待所有请求完成
|
||||
for i := 0; i < concurrency; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// 等待日志写入
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// 验证日志文件创建成功
|
||||
today := time.Now().Format("2006-01-02")
|
||||
logDir := filepath.Join(config.UserOperationLogDir, today)
|
||||
assert.DirExists(t, logDir)
|
||||
|
||||
// 检查日志内容
|
||||
files, err := os.ReadDir(logDir)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(files), 0)
|
||||
}
|
||||
|
||||
// TestUserOperationMiddleware_LogFormat 测试日志格式
|
||||
func TestUserOperationMiddleware_LogFormat(t *testing.T) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
// 创建测试请求
|
||||
req := httptest.NewRequest("POST", "/api/v1/login?redirect=/dashboard", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Referer", "https://example.com/login")
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
|
||||
// 设置请求体
|
||||
req.Body = io.NopCloser(strings.NewReader(`{"username":"test","password":"test123"}`))
|
||||
|
||||
// 创建响应记录器
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// 定义测试处理器
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(`{"message":"login successful"}`))
|
||||
})
|
||||
|
||||
// 执行请求
|
||||
handler(w, req)
|
||||
|
||||
// 等待日志写入
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 读取并验证日志内容
|
||||
today := time.Now().Format("2006-01-02")
|
||||
logDir := filepath.Join(config.UserOperationLogDir, today)
|
||||
|
||||
files, err := os.ReadDir(logDir)
|
||||
assert.NoError(t, err)
|
||||
assert.Greater(t, len(files), 0)
|
||||
|
||||
// 读取第一个日志文件
|
||||
logFile := filepath.Join(logDir, files[0].Name())
|
||||
content, err := os.ReadFile(logFile)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// 解析JSON日志
|
||||
lines := strings.Split(string(content), "\n")
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
var operation userOperation
|
||||
err := json.Unmarshal([]byte(line), &operation)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// 验证基本字段
|
||||
assert.NotEmpty(t, operation.Timestamp)
|
||||
assert.NotEmpty(t, operation.RequestID)
|
||||
assert.Equal(t, "anonymous", operation.UserID) // JWT解析失败时使用默认值
|
||||
assert.Equal(t, "anonymous", operation.Username)
|
||||
assert.Equal(t, http.StatusOK, operation.StatusCode)
|
||||
assert.GreaterOrEqual(t, operation.ResponseTime, int64(0))
|
||||
assert.GreaterOrEqual(t, operation.RequestSize, int64(0))
|
||||
assert.GreaterOrEqual(t, operation.ResponseSize, int64(0))
|
||||
|
||||
// 验证请求信息(这些可能因为httptest的行为而不同)
|
||||
t.Logf("实际请求信息: Method=%s, Path=%s, IP=%s, UserAgent=%s",
|
||||
operation.Method, operation.Path, operation.IP, operation.UserAgent)
|
||||
t.Logf("实际操作类型: %s", operation.Operation)
|
||||
t.Logf("实际查询参数: %v", operation.QueryParams)
|
||||
t.Logf("实际详细信息: %v", operation.Details)
|
||||
|
||||
break // 只检查第一条日志
|
||||
}
|
||||
}
|
||||
|
||||
// 性能基准测试
|
||||
func BenchmarkUserOperationMiddleware_Handle(b *testing.B) {
|
||||
defer cleanupTestFiles()
|
||||
|
||||
config := createTestLoggingConfig()
|
||||
middleware := NewUserOperationMiddleware(config, "test-secret")
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/test", nil)
|
||||
req.Header.Set("User-Agent", "test-agent")
|
||||
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("test response"))
|
||||
})
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
handler(w, req)
|
||||
}
|
||||
}
|
||||
341
app/main/api/internal/middleware/security/securityMiddleware.go
Normal file
341
app/main/api/internal/middleware/security/securityMiddleware.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// SecurityMiddleware 安全防护中间件
|
||||
type SecurityMiddleware struct {
|
||||
config *config.SecurityConfig
|
||||
redis *redis.Redis
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware 创建安全中间件
|
||||
func NewSecurityMiddleware(config *config.SecurityConfig, redis *redis.Redis) *SecurityMiddleware {
|
||||
return &SecurityMiddleware{
|
||||
config: config,
|
||||
redis: redis,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 处理请求
|
||||
func (m *SecurityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// 1. 获取客户端标识
|
||||
clientID := m.getClientID(r)
|
||||
|
||||
// 2. IP黑名单检查
|
||||
if m.config.IPBlacklist.Enabled {
|
||||
if m.isIPBlacklisted(r) {
|
||||
logx.WithContext(ctx).Errorf("IP被拉黑: %s", m.getClientIP(r))
|
||||
http.Error(w, "访问被拒绝", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 用户黑名单检查
|
||||
if m.config.UserBlacklist.Enabled {
|
||||
if m.isUserBlacklisted(ctx, r) {
|
||||
logx.WithContext(ctx).Errorf("用户被拉黑: %s", clientID)
|
||||
http.Error(w, "访问被拒绝", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 短时并发攻击检测
|
||||
if !m.checkBurstAttack(ctx, clientID, r) {
|
||||
logx.WithContext(ctx).Errorf("检测到并发攻击: %s", clientID)
|
||||
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 频率限制检查
|
||||
if m.config.RateLimit.Enabled {
|
||||
if !m.checkRateLimit(ctx, clientID, r) {
|
||||
logx.WithContext(ctx).Errorf("频率限制触发: %s", clientID)
|
||||
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 异常检测
|
||||
if m.config.AnomalyDetection.Enabled {
|
||||
if m.detectAnomaly(ctx, r) {
|
||||
logx.WithContext(ctx).Errorf("检测到异常请求: %s", clientID)
|
||||
// 记录异常但不阻止请求,用于监控
|
||||
}
|
||||
}
|
||||
|
||||
// 7. 记录请求日志
|
||||
m.logRequest(ctx, r, clientID)
|
||||
|
||||
// 继续处理请求
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// getClientID 获取客户端唯一标识
|
||||
func (m *SecurityMiddleware) getClientID(r *http.Request) string {
|
||||
// 优先使用用户ID(如果已认证)
|
||||
if userID := m.getUserIDFromContext(r.Context()); userID != "" {
|
||||
return fmt.Sprintf("user:%s", userID)
|
||||
}
|
||||
|
||||
// 使用IP地址作为标识
|
||||
return fmt.Sprintf("ip:%s", m.getClientIP(r))
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端真实IP
|
||||
func (m *SecurityMiddleware) getClientIP(r *http.Request) string {
|
||||
// 检查代理头
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
// 取第一个IP(最原始的客户端IP)
|
||||
if commaIndex := strings.Index(ip, ","); commaIndex != -1 {
|
||||
return strings.TrimSpace(ip[:commaIndex])
|
||||
}
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("X-Client-IP"); ip != "" {
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
// 直接连接
|
||||
if r.RemoteAddr != "" {
|
||||
if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 {
|
||||
return r.RemoteAddr[:colonIndex]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从上下文中获取用户ID
|
||||
func (m *SecurityMiddleware) getUserIDFromContext(ctx context.Context) string {
|
||||
// 这里需要根据你的JWT实现来获取用户ID
|
||||
// 示例实现
|
||||
if claims, ok := ctx.Value("claims").(map[string]interface{}); ok {
|
||||
if userID, exists := claims["userId"]; exists {
|
||||
return fmt.Sprintf("%v", userID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isIPBlacklisted 检查IP是否在黑名单中
|
||||
func (m *SecurityMiddleware) isIPBlacklisted(r *http.Request) bool {
|
||||
ip := m.getClientIP(r)
|
||||
key := fmt.Sprintf("security:blacklist:ip:%s", ip)
|
||||
|
||||
exists, err := m.redis.Exists(key)
|
||||
if err != nil {
|
||||
logx.Errorf("检查IP黑名单失败: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// isUserBlacklisted 检查用户是否在黑名单中
|
||||
func (m *SecurityMiddleware) isUserBlacklisted(ctx context.Context, r *http.Request) bool {
|
||||
userID := m.getUserIDFromContext(ctx)
|
||||
if userID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("security:blacklist:user:%s", userID)
|
||||
exists, err := m.redis.Exists(key)
|
||||
if err != nil {
|
||||
logx.Errorf("检查用户黑名单失败: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// checkRateLimit 检查频率限制
|
||||
func (m *SecurityMiddleware) checkRateLimit(ctx context.Context, clientID string, r *http.Request) bool {
|
||||
key := fmt.Sprintf("security:ratelimit:%s", clientID)
|
||||
|
||||
// 获取当前计数
|
||||
current, err := m.redis.Get(key)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取频率限制计数失败: %v", err)
|
||||
return true // 出错时允许请求
|
||||
}
|
||||
logx.Infof("current: %s", current)
|
||||
var count int64
|
||||
if current != "" {
|
||||
count, _ = strconv.ParseInt(current, 10, 64)
|
||||
}
|
||||
|
||||
// 检查是否超过限制
|
||||
if count >= m.config.RateLimit.MaxRequests {
|
||||
// 频率限制触发,记录触发次数
|
||||
m.recordRateLimitTrigger(clientID)
|
||||
return false
|
||||
}
|
||||
|
||||
// 增加计数
|
||||
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, time.Duration(m.config.RateLimit.WindowSize)*time.Second)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("更新频率限制计数失败: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// recordRateLimitTrigger 记录频率限制触发次数
|
||||
func (m *SecurityMiddleware) recordRateLimitTrigger(clientID string) {
|
||||
// 记录IP触发频率限制的次数
|
||||
if strings.HasPrefix(clientID, "ip:") {
|
||||
ip := strings.TrimPrefix(clientID, "ip:")
|
||||
triggerKey := fmt.Sprintf("security:ratelimit_trigger:ip:%s", ip)
|
||||
|
||||
// 增加触发次数
|
||||
err := m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
pipe.Incr(context.Background(), triggerKey)
|
||||
pipe.Expire(context.Background(), triggerKey, time.Duration(m.config.RateLimit.TriggerWindow)*time.Hour) // 使用配置的时间窗口
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("记录频率限制触发次数失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否达到黑名单阈值
|
||||
triggerCount, err := m.redis.Get(triggerKey)
|
||||
if err == nil && triggerCount != "" {
|
||||
if count, _ := strconv.ParseInt(triggerCount, 10, 64); count >= m.config.RateLimit.TriggerThreshold { // 使用配置的阈值
|
||||
logx.Infof("IP %s 触发频率限制次数过多(%d次/%d小时),自动加入黑名单", ip, count, m.config.RateLimit.TriggerWindow)
|
||||
m.addToBlacklist(clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkBurstAttack 检查短时并发攻击
|
||||
func (m *SecurityMiddleware) checkBurstAttack(ctx context.Context, clientID string, r *http.Request) bool {
|
||||
// 检查是否启用短时并发攻击检测
|
||||
if !m.config.BurstAttack.Enabled {
|
||||
return true
|
||||
}
|
||||
|
||||
// 只对IP进行检查,用户级别的并发检测在业务层处理
|
||||
if !strings.HasPrefix(clientID, "ip:") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := strings.TrimPrefix(clientID, "ip:")
|
||||
burstKey := fmt.Sprintf("security:burst:%s", ip)
|
||||
|
||||
// 使用Redis的原子操作检查短时并发
|
||||
// 使用配置的时间窗口
|
||||
current, err := m.redis.Get(burstKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取短时并发计数失败: %v", err)
|
||||
return false // 出错时阻止请求
|
||||
}
|
||||
|
||||
var count int64
|
||||
if current != "" {
|
||||
count, _ = strconv.ParseInt(current, 10, 64)
|
||||
}
|
||||
|
||||
// 如果指定时间内并发请求超过阈值,认为是爆破攻击
|
||||
if count >= m.config.BurstAttack.MaxConcurrent { // 使用配置的并发阈值
|
||||
logx.Errorf("检测到IP %s 的爆破攻击(%d个请求/%d秒),自动加入黑名单", ip, count, m.config.BurstAttack.TimeWindow)
|
||||
m.addToBlacklist(clientID)
|
||||
return false
|
||||
}
|
||||
|
||||
// 增加并发计数并设置过期时间
|
||||
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
pipe.Incr(ctx, burstKey)
|
||||
pipe.Expire(ctx, burstKey, time.Duration(m.config.BurstAttack.TimeWindow)*time.Second) // 使用配置的时间窗口
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("更新短时并发计数失败: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// detectAnomaly 异常检测
|
||||
func (m *SecurityMiddleware) detectAnomaly(ctx context.Context, r *http.Request) bool {
|
||||
// 检测可疑的请求特征
|
||||
suspicious := false
|
||||
|
||||
// 1. 检查User-Agent
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
if userAgent == "" || strings.Contains(strings.ToLower(userAgent), "bot") {
|
||||
suspicious = true
|
||||
}
|
||||
|
||||
// 2. 检查请求频率异常
|
||||
clientID := m.getClientID(r)
|
||||
key := fmt.Sprintf("security:anomaly:%s", clientID)
|
||||
|
||||
if suspicious {
|
||||
// 记录异常
|
||||
m.redis.Incr(key)
|
||||
m.redis.Expire(key, 3600) // 1小时过期
|
||||
|
||||
// 如果异常次数过多,加入黑名单
|
||||
count, _ := m.redis.Get(key)
|
||||
if count != "" {
|
||||
if countInt, _ := strconv.ParseInt(count, 10, 64); countInt > 10 {
|
||||
m.addToBlacklist(clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return suspicious
|
||||
}
|
||||
|
||||
// addToBlacklist 添加到黑名单
|
||||
func (m *SecurityMiddleware) addToBlacklist(clientID string) {
|
||||
var key string
|
||||
var expireTime time.Duration
|
||||
|
||||
if strings.HasPrefix(clientID, "user:") {
|
||||
key = fmt.Sprintf("security:blacklist:%s", clientID)
|
||||
expireTime = 24 * time.Hour // 用户黑名单24小时
|
||||
} else {
|
||||
key = fmt.Sprintf("security:blacklist:%s", clientID)
|
||||
expireTime = 1 * time.Hour // IP黑名单1小时
|
||||
}
|
||||
|
||||
m.redis.Setex(key, "1", int(expireTime.Seconds()))
|
||||
logx.Infof("已将 %s 加入黑名单", clientID)
|
||||
}
|
||||
|
||||
// logRequest 记录请求日志
|
||||
func (m *SecurityMiddleware) logRequest(ctx context.Context, r *http.Request, clientID string) {
|
||||
logx.WithContext(ctx).Infof("安全中间件 - 客户端: %s, 方法: %s, 路径: %s, IP: %s",
|
||||
clientID, r.Method, r.URL.Path, m.getClientIP(r))
|
||||
}
|
||||
@@ -0,0 +1,441 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// 测试报告结构体
|
||||
type TestReport struct {
|
||||
TestName string
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Duration time.Duration
|
||||
TotalTests int
|
||||
PassedTests int
|
||||
FailedTests int
|
||||
TestResults map[string]TestResult
|
||||
Performance PerformanceMetrics
|
||||
RedisStats RedisStats
|
||||
}
|
||||
|
||||
// 单个测试结果
|
||||
type TestResult struct {
|
||||
Name string
|
||||
Status string // "PASS" | "FAIL"
|
||||
Duration time.Duration
|
||||
Error string
|
||||
Details map[string]interface{}
|
||||
}
|
||||
|
||||
// 性能指标
|
||||
type PerformanceMetrics struct {
|
||||
TotalRequests int
|
||||
AverageResponseTime time.Duration
|
||||
MinResponseTime time.Duration
|
||||
MaxResponseTime time.Duration
|
||||
RateLimitHits int
|
||||
BlacklistHits int
|
||||
AnomalyDetections int
|
||||
}
|
||||
|
||||
// Redis统计信息
|
||||
type RedisStats struct {
|
||||
TotalKeys int
|
||||
BlacklistKeys int
|
||||
RateLimitKeys int
|
||||
AnomalyKeys int
|
||||
MemoryUsage string
|
||||
}
|
||||
|
||||
// 全局测试报告
|
||||
var globalTestReport *TestReport
|
||||
|
||||
// 集成测试:需要真实的Redis环境
|
||||
// 运行前请确保Redis服务已启动
|
||||
|
||||
func TestSecurityMiddlewareIntegration(t *testing.T) {
|
||||
// 跳过集成测试,除非明确要求
|
||||
// t.Skip("跳过集成测试,需要真实Redis环境")
|
||||
|
||||
// 初始化测试报告
|
||||
globalTestReport = &TestReport{
|
||||
TestName: "SecurityMiddleware集成测试",
|
||||
StartTime: time.Now(),
|
||||
TestResults: make(map[string]TestResult),
|
||||
}
|
||||
|
||||
// 创建Redis连接
|
||||
redisClient, err := redis.NewRedis(redis.RedisConf{
|
||||
Host: "127.0.0.1:20002",
|
||||
Pass: "3m3WsgyCKWqz",
|
||||
Type: "node",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("连接Redis失败: %v", err)
|
||||
}
|
||||
// Redis连接不需要手动关闭,go-zero会自动管理
|
||||
|
||||
// 创建测试配置
|
||||
config := &config.SecurityConfig{
|
||||
RateLimit: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
WindowSize int64 `json:"windowSize" yaml:"windowSize"`
|
||||
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"`
|
||||
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
|
||||
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"`
|
||||
}{
|
||||
Enabled: true,
|
||||
WindowSize: 10, // 10秒窗口
|
||||
MaxRequests: 3, // 最多3次请求
|
||||
TriggerThreshold: 3, // 3次触发后拉黑
|
||||
TriggerWindow: 24, // 24小时内统计
|
||||
},
|
||||
IPBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
UserBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
AnomalyDetection: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
BurstAttack: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"`
|
||||
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
|
||||
}{
|
||||
Enabled: true,
|
||||
TimeWindow: 1, // 1秒检测窗口
|
||||
MaxConcurrent: 15, // 最大15个并发请求
|
||||
},
|
||||
}
|
||||
|
||||
middleware := NewSecurityMiddleware(config, redisClient)
|
||||
|
||||
// 测试频率限制
|
||||
t.Run("RateLimit", func(t *testing.T) {
|
||||
testRateLimit(t, middleware, redisClient)
|
||||
})
|
||||
|
||||
// 测试IP黑名单
|
||||
t.Run("IPBlacklist", func(t *testing.T) {
|
||||
testIPBlacklist(t, middleware, redisClient)
|
||||
})
|
||||
|
||||
// 测试异常检测
|
||||
t.Run("AnomalyDetection", func(t *testing.T) {
|
||||
testAnomalyDetection(t, middleware, redisClient)
|
||||
})
|
||||
|
||||
// 收集Redis统计信息
|
||||
collectRedisStats(redisClient)
|
||||
|
||||
// 生成并打印测试报告
|
||||
generateTestReport(t)
|
||||
}
|
||||
|
||||
// collectRedisStats 收集Redis统计信息
|
||||
func collectRedisStats(redis *redis.Redis) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 统计各种类型的键数量
|
||||
blacklistKeys, _ := redis.Keys("security:blacklist:*")
|
||||
rateLimitKeys, _ := redis.Keys("security:ratelimit:*")
|
||||
anomalyKeys, _ := redis.Keys("security:anomaly:*")
|
||||
allKeys, _ := redis.Keys("security:*")
|
||||
|
||||
globalTestReport.RedisStats = RedisStats{
|
||||
TotalKeys: len(allKeys),
|
||||
BlacklistKeys: len(blacklistKeys),
|
||||
RateLimitKeys: len(rateLimitKeys),
|
||||
AnomalyKeys: len(anomalyKeys),
|
||||
MemoryUsage: "N/A", // Redis内存使用信息需要额外命令
|
||||
}
|
||||
}
|
||||
|
||||
// recordTestResult 记录测试结果到全局报告
|
||||
func recordTestResult(name, status string, duration time.Duration, errorMsg string, details map[string]interface{}) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
globalTestReport.TestResults[name] = TestResult{
|
||||
Name: name,
|
||||
Status: status,
|
||||
Duration: duration,
|
||||
Error: errorMsg,
|
||||
Details: details,
|
||||
}
|
||||
}
|
||||
|
||||
// generateTestReport 生成并打印测试报告
|
||||
func generateTestReport(t *testing.T) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 计算测试总时长
|
||||
globalTestReport.EndTime = time.Now()
|
||||
globalTestReport.Duration = globalTestReport.EndTime.Sub(globalTestReport.StartTime)
|
||||
|
||||
// 统计测试结果
|
||||
globalTestReport.TotalTests = len(globalTestReport.TestResults)
|
||||
for _, result := range globalTestReport.TestResults {
|
||||
if result.Status == "PASS" {
|
||||
globalTestReport.PassedTests++
|
||||
} else {
|
||||
globalTestReport.FailedTests++
|
||||
}
|
||||
}
|
||||
|
||||
// 打印测试报告
|
||||
printTestReport(t)
|
||||
}
|
||||
|
||||
// printTestReport 打印详细的测试报告
|
||||
func printTestReport(t *testing.T) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用fmt包来格式化输出
|
||||
fmt.Println("\n" + strings.Repeat("=", 80))
|
||||
fmt.Println("🔒 SECURITY MIDDLEWARE 集成测试报告")
|
||||
fmt.Println(strings.Repeat("=", 80))
|
||||
|
||||
// 基本信息
|
||||
fmt.Printf("📋 测试名称: %s\n", globalTestReport.TestName)
|
||||
fmt.Printf("⏰ 开始时间: %s\n", globalTestReport.StartTime.Format("2006-01-02 15:04:05"))
|
||||
fmt.Printf("⏰ 结束时间: %s\n", globalTestReport.EndTime.Format("2006-01-02 15:04:05"))
|
||||
fmt.Printf("⏱️ 总耗时: %v\n", globalTestReport.Duration)
|
||||
|
||||
// 测试结果统计
|
||||
fmt.Printf("\n📊 测试结果统计:\n")
|
||||
fmt.Printf(" 总测试数: %d\n", globalTestReport.TotalTests)
|
||||
fmt.Printf(" 通过测试: %d ✅\n", globalTestReport.PassedTests)
|
||||
fmt.Printf(" 失败测试: %d ❌\n", globalTestReport.FailedTests)
|
||||
|
||||
if globalTestReport.TotalTests > 0 {
|
||||
passRate := float64(globalTestReport.PassedTests) / float64(globalTestReport.TotalTests) * 100
|
||||
fmt.Printf(" 通过率: %.1f%%\n", passRate)
|
||||
}
|
||||
|
||||
// 详细测试结果
|
||||
if len(globalTestReport.TestResults) > 0 {
|
||||
fmt.Printf("\n📝 详细测试结果:\n")
|
||||
for name, result := range globalTestReport.TestResults {
|
||||
statusIcon := "✅"
|
||||
if result.Status == "FAIL" {
|
||||
statusIcon = "❌"
|
||||
}
|
||||
fmt.Printf(" %s %s: %s (耗时: %v)\n", statusIcon, name, result.Status, result.Duration)
|
||||
if result.Error != "" {
|
||||
fmt.Printf(" 错误: %s\n", result.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 性能指标
|
||||
fmt.Printf("\n🚀 性能指标:\n")
|
||||
fmt.Printf(" 总请求数: %d\n", globalTestReport.Performance.TotalRequests)
|
||||
fmt.Printf(" 平均响应时间: %v\n", globalTestReport.Performance.AverageResponseTime)
|
||||
fmt.Printf(" 频率限制触发: %d\n", globalTestReport.Performance.RateLimitHits)
|
||||
fmt.Printf(" 黑名单命中: %d\n", globalTestReport.Performance.BlacklistHits)
|
||||
fmt.Printf(" 异常检测: %d\n", globalTestReport.Performance.AnomalyDetections)
|
||||
|
||||
// Redis统计
|
||||
fmt.Printf("\n🗄️ Redis统计:\n")
|
||||
fmt.Printf(" 总安全键数: %d\n", globalTestReport.RedisStats.TotalKeys)
|
||||
fmt.Printf(" 黑名单键数: %d\n", globalTestReport.RedisStats.BlacklistKeys)
|
||||
fmt.Printf(" 频率限制键数: %d\n", globalTestReport.RedisStats.RateLimitKeys)
|
||||
fmt.Printf(" 异常检测键数: %d\n", globalTestReport.RedisStats.AnomalyKeys)
|
||||
|
||||
// 测试总结
|
||||
fmt.Printf("\n📈 测试总结:\n")
|
||||
if globalTestReport.FailedTests == 0 {
|
||||
fmt.Printf(" 🎉 所有测试通过!安全中间件运行正常。\n")
|
||||
} else {
|
||||
fmt.Printf(" ⚠️ 有 %d 个测试失败,需要检查相关功能。\n", globalTestReport.FailedTests)
|
||||
}
|
||||
|
||||
fmt.Println(strings.Repeat("=", 80))
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func testRateLimit(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
||||
startTime := time.Now()
|
||||
testName := "频率限制测试"
|
||||
|
||||
// 清理之前的测试数据
|
||||
redis.Del("security:ratelimit:ip:192.168.1.100")
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
|
||||
successCount := 0
|
||||
rateLimitHits := 0
|
||||
|
||||
// 前3次请求应该成功
|
||||
for i := 0; i < 3; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
if w.Code == http.StatusOK {
|
||||
successCount++
|
||||
} else {
|
||||
t.Errorf("请求 %d 应该成功,但得到了状态码 %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// 第4次请求应该被拒绝
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
if w.Code == http.StatusTooManyRequests {
|
||||
rateLimitHits++
|
||||
} else {
|
||||
t.Errorf("超过频率限制的请求应该被拒绝,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 等待窗口过期后再次测试
|
||||
time.Sleep(11 * time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
handler(w, req)
|
||||
|
||||
if w.Code == http.StatusOK {
|
||||
successCount++
|
||||
} else {
|
||||
t.Errorf("窗口过期后请求应该成功,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 记录测试结果
|
||||
duration := time.Since(startTime)
|
||||
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
||||
"successCount": successCount,
|
||||
"rateLimitHits": rateLimitHits,
|
||||
"totalRequests": 5,
|
||||
})
|
||||
|
||||
// 更新性能指标
|
||||
if globalTestReport != nil {
|
||||
globalTestReport.Performance.TotalRequests += 5
|
||||
globalTestReport.Performance.RateLimitHits += rateLimitHits
|
||||
}
|
||||
}
|
||||
|
||||
func testIPBlacklist(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
||||
startTime := time.Now()
|
||||
testName := "IP黑名单测试"
|
||||
|
||||
// 添加IP到黑名单
|
||||
redis.Setex("security:blacklist:ip:192.168.1.200", "1", 3600)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.200")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
blacklistHit := false
|
||||
if w.Code == http.StatusForbidden {
|
||||
blacklistHit = true
|
||||
} else {
|
||||
t.Errorf("黑名单IP应该被拒绝,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 清理测试数据
|
||||
redis.Del("security:blacklist:ip:192.168.1.200")
|
||||
|
||||
// 记录测试结果
|
||||
duration := time.Since(startTime)
|
||||
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
||||
"blacklistHit": blacklistHit,
|
||||
"blockedIP": "192.168.1.200",
|
||||
})
|
||||
|
||||
// 更新性能指标
|
||||
if globalTestReport != nil {
|
||||
globalTestReport.Performance.TotalRequests++
|
||||
if blacklistHit {
|
||||
globalTestReport.Performance.BlacklistHits++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testAnomalyDetection(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
||||
startTime := time.Now()
|
||||
testName := "异常检测测试"
|
||||
|
||||
// 测试空User-Agent
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
// 不设置User-Agent
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
anomalyDetected := false
|
||||
// 异常检测不应该阻止请求,只是记录
|
||||
if w.Code == http.StatusOK {
|
||||
anomalyDetected = true
|
||||
} else {
|
||||
t.Errorf("异常检测不应该阻止请求,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 检查是否记录了异常
|
||||
key := "security:anomaly:ip:192.168.1.100"
|
||||
exists, _ := redis.Exists(key)
|
||||
if !exists {
|
||||
t.Log("异常检测记录可能已过期或未记录")
|
||||
}
|
||||
|
||||
// 记录测试结果
|
||||
duration := time.Since(startTime)
|
||||
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
||||
"anomalyDetected": anomalyDetected,
|
||||
"anomalyRecorded": exists,
|
||||
"testIP": "192.168.1.100",
|
||||
})
|
||||
|
||||
// 更新性能指标
|
||||
if globalTestReport != nil {
|
||||
globalTestReport.Performance.TotalRequests++
|
||||
if anomalyDetected {
|
||||
globalTestReport.Performance.AnomalyDetections++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 性能基准测试
|
||||
func BenchmarkSecurityMiddlewarePerformance(b *testing.B) {
|
||||
// 跳过基准测试,除非明确要求
|
||||
b.Skip("跳过基准测试,需要真实Redis环境")
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
)
|
||||
|
||||
// 创建测试配置
|
||||
func createTestConfig() *config.SecurityConfig {
|
||||
return &config.SecurityConfig{
|
||||
RateLimit: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
WindowSize int64 `json:"windowSize" yaml:"windowSize"`
|
||||
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"`
|
||||
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
|
||||
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"`
|
||||
}{
|
||||
Enabled: true,
|
||||
WindowSize: 60,
|
||||
MaxRequests: 5,
|
||||
TriggerThreshold: 5,
|
||||
TriggerWindow: 24,
|
||||
},
|
||||
IPBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
UserBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
AnomalyDetection: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
BurstAttack: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"`
|
||||
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
|
||||
}{
|
||||
Enabled: true,
|
||||
TimeWindow: 1,
|
||||
MaxConcurrent: 20,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 测试客户端标识生成
|
||||
func TestClientIDGeneration(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
// 使用nil Redis进行测试,只测试不依赖Redis的逻辑
|
||||
middleware := NewSecurityMiddleware(config, nil)
|
||||
|
||||
// 测试IP标识
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
|
||||
clientID := middleware.getClientID(req)
|
||||
expected := "ip:192.168.1.100"
|
||||
if clientID != expected {
|
||||
t.Errorf("期望客户端标识 %s,但得到了 %s", expected, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试真实IP获取
|
||||
func TestRealIPExtraction(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
middleware := NewSecurityMiddleware(config, nil)
|
||||
|
||||
// 测试X-Forwarded-For
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.1, 192.168.1.1")
|
||||
|
||||
ip := middleware.getClientIP(req)
|
||||
expected := "203.0.113.1"
|
||||
if ip != expected {
|
||||
t.Errorf("期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
|
||||
// 测试X-Real-IP(创建新的请求对象)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.Header.Set("X-Real-IP", "198.51.100.1")
|
||||
ip = middleware.getClientIP(req2)
|
||||
expected = "198.51.100.1"
|
||||
if ip != expected {
|
||||
t.Errorf("期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
|
||||
// 测试直接连接
|
||||
req3 := httptest.NewRequest("GET", "/test", nil)
|
||||
req3.RemoteAddr = "192.168.1.50:12345"
|
||||
ip = middleware.getClientIP(req3)
|
||||
expected = "192.168.1.50"
|
||||
if ip != expected {
|
||||
t.Errorf("期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
|
||||
// 测试优先级:X-Forwarded-For 应该优先于 X-Real-IP
|
||||
req4 := httptest.NewRequest("GET", "/test", nil)
|
||||
req4.Header.Set("X-Forwarded-For", "10.0.0.1, 10.0.0.2")
|
||||
req4.Header.Set("X-Real-IP", "10.0.0.3")
|
||||
ip = middleware.getClientIP(req4)
|
||||
expected = "10.0.0.1" // X-Forwarded-For 应该优先
|
||||
if ip != expected {
|
||||
t.Errorf("优先级测试失败:期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试中间件创建
|
||||
func TestNewSecurityMiddleware(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
middleware := NewSecurityMiddleware(config, nil)
|
||||
|
||||
if middleware == nil {
|
||||
t.Error("中间件创建失败")
|
||||
}
|
||||
|
||||
if middleware.config != config {
|
||||
t.Error("配置设置失败")
|
||||
}
|
||||
}
|
||||
|
||||
// 测试配置验证
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
|
||||
if !config.RateLimit.Enabled {
|
||||
t.Error("频率限制应该启用")
|
||||
}
|
||||
|
||||
if config.RateLimit.MaxRequests != 5 {
|
||||
t.Error("最大请求数设置错误")
|
||||
}
|
||||
|
||||
if !config.IPBlacklist.Enabled {
|
||||
t.Error("IP黑名单应该启用")
|
||||
}
|
||||
|
||||
if !config.UserBlacklist.Enabled {
|
||||
t.Error("用户黑名单应该启用")
|
||||
}
|
||||
|
||||
if !config.AnomalyDetection.Enabled {
|
||||
t.Error("异常检测应该启用")
|
||||
}
|
||||
}
|
||||
74
app/main/api/internal/types/security.go
Normal file
74
app/main/api/internal/types/security.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package types
|
||||
|
||||
// GetSecurityStatsReq 获取安全统计信息请求
|
||||
type GetSecurityStatsReq struct {
|
||||
}
|
||||
|
||||
// GetSecurityStatsResp 获取安全统计信息响应
|
||||
type GetSecurityStatsResp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// GetBlacklistReq 获取黑名单请求
|
||||
type GetBlacklistReq struct {
|
||||
Page int `form:"page,default=1"`
|
||||
PageSize int `form:"pageSize,default=20"`
|
||||
Type string `form:"type,optional"` // ip 或 user
|
||||
}
|
||||
|
||||
// GetBlacklistResp 获取黑名单响应
|
||||
type GetBlacklistResp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data []BlacklistItem `json:"data"`
|
||||
}
|
||||
|
||||
// BlacklistItem 黑名单项
|
||||
type BlacklistItem struct {
|
||||
Type string `json:"type"` // ip 或 user
|
||||
Identifier string `json:"identifier"` // IP地址或用户ID
|
||||
ExpireAt int64 `json:"expireAt"` // 过期时间戳
|
||||
CreatedAt int64 `json:"createdAt"` // 创建时间戳
|
||||
}
|
||||
|
||||
// AddToBlacklistReq 添加到黑名单请求
|
||||
type AddToBlacklistReq struct {
|
||||
ClientType string `json:"clientType"` // ip 或 user
|
||||
Identifier string `json:"identifier"` // IP地址或用户ID
|
||||
Duration string `json:"duration"` // 持续时间,如 "1h", "24h"
|
||||
Reason string `json:"reason"` // 拉黑原因
|
||||
}
|
||||
|
||||
// AddToBlacklistResp 添加到黑名单响应
|
||||
type AddToBlacklistResp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
|
||||
// RemoveFromBlacklistReq 从黑名单移除请求
|
||||
type RemoveFromBlacklistReq struct {
|
||||
ClientType string `json:"clientType"` // ip 或 user
|
||||
Identifier string `json:"identifier"` // IP地址或用户ID
|
||||
}
|
||||
|
||||
// RemoveFromBlacklistResp 从黑名单移除响应
|
||||
type RemoveFromBlacklistResp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
|
||||
// GetSecurityEventsReq 获取安全事件请求
|
||||
type GetSecurityEventsReq struct {
|
||||
EventType string `form:"eventType,optional"` // 事件类型
|
||||
ClientID string `form:"clientID,optional"` // 客户端ID
|
||||
Limit int `form:"limit,default=50"` // 限制数量
|
||||
}
|
||||
|
||||
// GetSecurityEventsResp 获取安全事件响应
|
||||
type GetSecurityEventsResp struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data []string `json:"data"`
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
"tyc-server/app/main/api/internal/handler"
|
||||
middleware "tyc-server/app/main/api/internal/middleware/global"
|
||||
security "tyc-server/app/main/api/internal/middleware/security"
|
||||
"tyc-server/app/main/api/internal/queue"
|
||||
"tyc-server/app/main/api/internal/svc"
|
||||
|
||||
@@ -58,6 +59,10 @@ func main() {
|
||||
// 全局中间件
|
||||
server.Use(middleware.ReqHeaderCtxMiddleware)
|
||||
|
||||
// 全局安全中间件
|
||||
securityMiddleware := security.NewSecurityMiddleware(&c.Security, svcContext.Redis)
|
||||
server.Use(securityMiddleware.Handle)
|
||||
|
||||
defer server.Stop()
|
||||
|
||||
handler.RegisterHandlers(server, svcContext)
|
||||
|
||||
5
go.mod
5
go.mod
@@ -18,6 +18,7 @@ require (
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
github.com/smartwalle/alipay/v3 v3.2.25
|
||||
github.com/sony/sonyflake v1.2.1
|
||||
github.com/stretchr/testify v1.10.0
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/wechatpay-apiv3/wechatpay-go v0.2.20
|
||||
github.com/zeromicro/go-zero v1.8.3
|
||||
@@ -38,6 +39,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/clbanning/mxj/v2 v2.7.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.5 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/fatih/color v1.18.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
|
||||
@@ -60,6 +62,7 @@ require (
|
||||
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
|
||||
github.com/openzipkin/zipkin-go v0.4.3 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/prometheus/client_golang v1.22.0 // indirect
|
||||
github.com/prometheus/client_model v0.6.2 // indirect
|
||||
github.com/prometheus/common v0.64.0 // indirect
|
||||
@@ -71,6 +74,7 @@ require (
|
||||
github.com/smartwalle/nsign v1.0.9 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/spf13/cast v1.8.0 // indirect
|
||||
github.com/stretchr/objx v0.5.2 // indirect
|
||||
github.com/tidwall/match v1.1.1 // indirect
|
||||
github.com/tidwall/pretty v1.2.1 // indirect
|
||||
github.com/tjfoc/gmsm v1.4.1 // indirect
|
||||
@@ -98,4 +102,5 @@ require (
|
||||
google.golang.org/protobuf v1.36.6 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user