fix
This commit is contained in:
		| @@ -24,6 +24,8 @@ type Config struct { | ||||
| 	CleanTask      CleanTask | ||||
| 	Tianyuanapi    TianyuanapiConfig | ||||
| 	VerifyConfig   VerifyConfig | ||||
| 	Security       SecurityConfig // 安全配置 | ||||
| 	Logging        LoggingConfig  // 日志配置 | ||||
| } | ||||
|  | ||||
| // JwtAuth 用于 JWT 鉴权配置 | ||||
|   | ||||
							
								
								
									
										10
									
								
								app/main/api/internal/config/logging.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								app/main/api/internal/config/logging.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| package config | ||||
|  | ||||
| // LoggingConfig 日志配置 | ||||
| type LoggingConfig struct { | ||||
| 	UserOperationLogDir string `json:"userOperationLogDir" yaml:"userOperationLogDir"` // 用户操作日志目录 | ||||
| 	MaxFileSize         int64  `json:"maxFileSize" yaml:"maxFileSize"`                 // 单个日志文件最大大小(字节) | ||||
| 	LogLevel            string `json:"logLevel" yaml:"logLevel"`                       // 日志级别 | ||||
| 	EnableConsole       bool   `json:"enableConsole" yaml:"enableConsole"`             // 是否启用控制台输出 | ||||
| 	EnableFile          bool   `json:"enableFile" yaml:"enableFile"`                   // 是否启用文件输出 | ||||
| } | ||||
							
								
								
									
										36
									
								
								app/main/api/internal/config/security.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								app/main/api/internal/config/security.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package config | ||||
|  | ||||
| // SecurityConfig 安全配置 | ||||
| type SecurityConfig struct { | ||||
| 	// 频率限制配置 | ||||
| 	RateLimit struct { | ||||
| 		Enabled     bool  `json:"enabled" yaml:"enabled"`         // 是否启用频率限制 | ||||
| 		WindowSize  int64 `json:"windowSize" yaml:"windowSize"`   // 时间窗口大小(秒) | ||||
| 		MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"` // 最大请求次数 | ||||
| 		// 频率限制触发后的黑名单升级配置 | ||||
| 		TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"` // 触发多少次频率限制后加入黑名单 | ||||
| 		TriggerWindow    int64 `json:"triggerWindow" yaml:"triggerWindow"`       // 触发次数统计时间窗口(小时) | ||||
| 	} `json:"rateLimit" yaml:"rateLimit"` | ||||
|  | ||||
| 	// IP黑名单配置 | ||||
| 	IPBlacklist struct { | ||||
| 		Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用IP黑名单 | ||||
| 	} `json:"ipBlacklist" yaml:"ipBlacklist"` | ||||
|  | ||||
| 	// 用户黑名单配置 | ||||
| 	UserBlacklist struct { | ||||
| 		Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用用户黑名单 | ||||
| 	} `json:"userBlacklist" yaml:"userBlacklist"` | ||||
|  | ||||
| 	// 异常检测配置 | ||||
| 	AnomalyDetection struct { | ||||
| 		Enabled bool `json:"enabled" yaml:"enabled"` // 是否启用异常检测 | ||||
| 	} `json:"anomalyDetection" yaml:"anomalyDetection"` | ||||
|  | ||||
| 	// 短时并发攻击检测配置 | ||||
| 	BurstAttack struct { | ||||
| 		Enabled       bool  `json:"enabled" yaml:"enabled"`             // 是否启用短时并发攻击检测 | ||||
| 		TimeWindow    int64 `json:"timeWindow" yaml:"timeWindow"`       // 检测时间窗口(秒) | ||||
| 		MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"` // 最大并发请求数 | ||||
| 	} `json:"burstAttack" yaml:"burstAttack"` | ||||
| } | ||||
| @@ -1,9 +1,20 @@ | ||||
| package auth | ||||
|  | ||||
| // 验证码防护说明: | ||||
| // 1. checkCaptchaProtection: 在发送短信前检查防护状态 | ||||
| // 2. recordCaptchaRequest: 在短信发送成功后记录请求次数 | ||||
| // 3. GetCaptchaProtectionStatus: 获取防护状态用于调试和监控 | ||||
| // | ||||
| // 防护规则: | ||||
| // - 单个手机号:1分钟内最多1次,1小时内最多5次,24小时内最多20次 | ||||
| // - 单个IP:1分钟内最多10次,1小时内最多50次,超过阈值后IP被临时封禁 | ||||
| // - 防止验证码爆破攻击,控制短信发送成本 | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"math/rand" | ||||
| 	"strconv" | ||||
| 	"time" | ||||
| 	"tyc-server/common/xerr" | ||||
| 	"tyc-server/pkg/lzkit/crypto" | ||||
| @@ -18,6 +29,7 @@ import ( | ||||
| 	"github.com/alibabacloud-go/tea-utils/v2/service" | ||||
| 	"github.com/alibabacloud-go/tea/tea" | ||||
| 	"github.com/zeromicro/go-zero/core/logx" | ||||
| 	"github.com/zeromicro/go-zero/core/stores/redis" | ||||
| ) | ||||
|  | ||||
| type SendSmsLogic struct { | ||||
| @@ -40,6 +52,12 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error { | ||||
| 	if err != nil { | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 加密手机号失败: %+v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 验证码防护检查 | ||||
| 	if err := l.checkCaptchaProtection(req.Mobile, req.ActionType); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 检查手机号是否在一分钟内已发送过验证码 | ||||
| 	limitCodeKey := fmt.Sprintf("limit:%s:%s", req.ActionType, encryptedMobile) | ||||
| 	exists, err := l.svcCtx.Redis.Exists(limitCodeKey) | ||||
| @@ -62,6 +80,13 @@ func (l *SendSmsLogic) SendSms(req *types.SendSmsReq) error { | ||||
| 	if *smsResp.Body.Code != "OK" { | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "短信发送, 阿里客户端响应失败: %s", *smsResp.Body.Message) | ||||
| 	} | ||||
|  | ||||
| 	// 短信发送成功,记录请求次数 | ||||
| 	if err := l.recordCaptchaRequest(req.Mobile, req.ActionType); err != nil { | ||||
| 		logx.Errorf("记录验证码请求失败: %v", err) | ||||
| 		// 不影响主流程,只记录日志 | ||||
| 	} | ||||
|  | ||||
| 	codeKey := fmt.Sprintf("%s:%s", req.ActionType, encryptedMobile) | ||||
| 	// 将验证码保存到 Redis,设置过期时间 | ||||
| 	err = l.svcCtx.Redis.Setex(codeKey, code, l.svcCtx.Config.VerifyCode.ValidTime) // 验证码有效期5分钟 | ||||
| @@ -103,3 +128,324 @@ func (l *SendSmsLogic) sendSmsRequest(mobile, code string) (*dysmsapi.SendSmsRes | ||||
| 	runtime := &service.RuntimeOptions{} | ||||
| 	return cli.SendSmsWithOptions(request, runtime) | ||||
| } | ||||
|  | ||||
| // checkCaptchaProtection 检查验证码获取防护 | ||||
| func (l *SendSmsLogic) checkCaptchaProtection(mobile string, actionType string) error { | ||||
| 	// 1. 检查手机号获取验证码频率 | ||||
| 	if err := l.checkMobileRateLimit(mobile, actionType); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 2. 检查IP获取验证码频率 | ||||
| 	if err := l.checkIPRateLimit(); err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // checkMobileRateLimit 检查手机号频率限制 | ||||
| func (l *SendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error { | ||||
| 	// 限制单个手机号的每种短信类型: | ||||
| 	// - 1分钟内最多获取1次验证码 | ||||
| 	// - 1小时内最多获取5次验证码 | ||||
| 	// - 24小时内最多获取20次验证码 | ||||
|  | ||||
| 	mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType) | ||||
|  | ||||
| 	// 检查1分钟限制 | ||||
| 	minuteKey := fmt.Sprintf("%s:minute", mobileKey) | ||||
| 	exists, err := l.svcCtx.Redis.Exists(minuteKey) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("检查手机号1分钟限制失败: %v", err) | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败") | ||||
| 	} | ||||
| 	if exists { | ||||
| 		return errors.Wrapf(xerr.NewErrMsg("1分钟内已获取过验证码,请稍后再试"), "验证码防护 - 手机号1分钟内重复请求: %s", mobile) | ||||
| 	} | ||||
|  | ||||
| 	// 检查1小时限制 | ||||
| 	hourKey := fmt.Sprintf("%s:hour", mobileKey) | ||||
| 	count, err := l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		logx.Errorf("获取手机号1小时计数失败: %v", err) | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败") | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 5 { | ||||
| 			return errors.Wrapf(xerr.NewErrMsg("1小时内获取验证码次数过多,请稍后再试"), "验证码防护 - 手机号1小时内超过限制: %s", mobile) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 检查24小时限制 | ||||
| 	dayKey := fmt.Sprintf("%s:day", mobileKey) | ||||
| 	count, err = l.svcCtx.Redis.Get(dayKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		logx.Errorf("获取手机号24小时计数失败: %v", err) | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败") | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if dayCount, _ := strconv.ParseInt(count, 10, 64); dayCount >= 20 { | ||||
| 			return errors.Wrapf(xerr.NewErrMsg("24小时内获取验证码次数过多,请稍后再试"), "验证码防护 - 手机号24小时内超过限制: %s", mobile) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // checkIPRateLimit 检查IP频率限制 | ||||
| func (l *SendSmsLogic) checkIPRateLimit() error { | ||||
| 	// 限制单个IP: | ||||
| 	// - 1分钟内最多获取10次验证码 | ||||
| 	// - 1小时内最多获取50次验证码 | ||||
| 	// - 超过阈值后IP被临时封禁 | ||||
|  | ||||
| 	clientIP := l.getClientIP() | ||||
| 	ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP) | ||||
|  | ||||
| 	// 检查IP是否被封禁 | ||||
| 	bannedKey := fmt.Sprintf("%s:banned", ipKey) | ||||
| 	exists, err := l.svcCtx.Redis.Exists(bannedKey) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("检查IP封禁状态失败: %v", err) | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败") | ||||
| 	} | ||||
| 	if exists { | ||||
| 		ttl, err := l.svcCtx.Redis.Ttl(bannedKey) | ||||
| 		if err != nil { | ||||
| 			logx.Errorf("获取IP封禁剩余时间失败: %v", err) | ||||
| 		} | ||||
| 		if ttl > 0 { | ||||
| 			return errors.Wrapf(xerr.NewErrMsg(fmt.Sprintf("IP被临时封禁,请%d秒后再试", ttl)), "验证码防护 - IP被临时封禁: %s", clientIP) | ||||
| 		} else { | ||||
| 			// 封禁时间已过,清除封禁状态 | ||||
| 			l.svcCtx.Redis.Del(bannedKey) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 检查1分钟限制 | ||||
| 	minuteKey := fmt.Sprintf("%s:minute", ipKey) | ||||
| 	count, err := l.svcCtx.Redis.Get(minuteKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		logx.Errorf("获取IP1分钟计数失败: %v", err) | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败") | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if minuteCount, _ := strconv.ParseInt(count, 10, 64); minuteCount >= 10 { | ||||
| 			// 封禁IP 5分钟 | ||||
| 			err = l.svcCtx.Redis.Setex(bannedKey, "1", 300) | ||||
| 			if err != nil { | ||||
| 				logx.Errorf("封禁IP失败: %v", err) | ||||
| 			} | ||||
| 			logx.Errorf("验证码防护 - IP被临时封禁: %s, 封禁时间: 300秒", clientIP) | ||||
| 			return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁,已被临时封禁5分钟"), "验证码防护 - IP被临时封禁: %s", clientIP) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 检查1小时限制 | ||||
| 	hourKey := fmt.Sprintf("%s:hour", ipKey) | ||||
| 	count, err = l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		logx.Errorf("获取IP1小时计数失败: %v", err) | ||||
| 		return errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "验证码防护检查失败") | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if hourCount, _ := strconv.ParseInt(count, 10, 64); hourCount >= 50 { | ||||
| 			// 封禁IP 1小时 | ||||
| 			err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600) | ||||
| 			if err != nil { | ||||
| 				logx.Errorf("封禁IP失败: %v", err) | ||||
| 			} | ||||
| 			logx.Errorf("验证码防护 - IP被长期封禁: %s, 封禁时间: 3600秒", clientIP) | ||||
| 			return errors.Wrapf(xerr.NewErrMsg("IP请求过于频繁,已被临时封禁1小时"), "验证码防护 - IP被长期封禁: %s", clientIP) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // getClientIP 获取客户端真实IP | ||||
| func (l *SendSmsLogic) getClientIP() string { | ||||
| 	if l.ctx != nil { | ||||
| 		// 尝试从上下文中获取IP | ||||
| 		if ip, ok := l.ctx.Value("client_ip").(string); ok { | ||||
| 			return ip | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 默认返回本地IP,实际使用时应该从请求中获取 | ||||
| 	return "127.0.0.1" | ||||
| } | ||||
|  | ||||
| // recordCaptchaRequest 记录验证码请求次数 | ||||
| func (l *SendSmsLogic) recordCaptchaRequest(mobile string, actionType string) error { | ||||
| 	clientIP := l.getClientIP() | ||||
|  | ||||
| 	// 记录手机号请求次数 | ||||
| 	mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType) | ||||
|  | ||||
| 	// 1分钟限制标记 | ||||
| 	minuteKey := fmt.Sprintf("%s:minute", mobileKey) | ||||
| 	err := l.svcCtx.Redis.Setex(minuteKey, "1", 60) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("设置手机号1分钟限制标记失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 1小时计数 | ||||
| 	hourKey := fmt.Sprintf("%s:hour", mobileKey) | ||||
| 	_, err = l.svcCtx.Redis.Incr(hourKey) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("增加手机号1小时计数失败: %v", err) | ||||
| 	} | ||||
| 	// 设置1小时过期 | ||||
| 	err = l.svcCtx.Redis.Expire(hourKey, 3600) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("设置手机号1小时计数过期时间失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 24小时计数 | ||||
| 	dayKey := fmt.Sprintf("%s:day", mobileKey) | ||||
| 	_, err = l.svcCtx.Redis.Incr(dayKey) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("增加手机号24小时计数失败: %v", err) | ||||
| 	} | ||||
| 	// 设置24小时过期 | ||||
| 	err = l.svcCtx.Redis.Expire(dayKey, 86400) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("设置手机号24小时计数过期时间失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 记录IP请求次数 | ||||
| 	ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP) | ||||
|  | ||||
| 	// IP 1分钟计数 | ||||
| 	minuteKey = fmt.Sprintf("%s:minute", ipKey) | ||||
| 	_, err = l.svcCtx.Redis.Incr(minuteKey) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("增加IP1分钟计数失败: %v", err) | ||||
| 	} | ||||
| 	// 设置1分钟过期 | ||||
| 	err = l.svcCtx.Redis.Expire(minuteKey, 60) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("设置IP1分钟计数过期时间失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// IP 1小时计数 | ||||
| 	hourKey = fmt.Sprintf("%s:hour", ipKey) | ||||
| 	_, err = l.svcCtx.Redis.Incr(hourKey) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("增加IP1小时计数失败: %v", err) | ||||
| 	} | ||||
| 	// 设置1小时过期 | ||||
| 	err = l.svcCtx.Redis.Expire(hourKey, 3600) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("设置IP1小时计数过期时间失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // GetCaptchaProtectionStatus 获取验证码防护状态(用于调试和监控) | ||||
| func (l *SendSmsLogic) GetCaptchaProtectionStatus(mobile string, actionType string) (map[string]interface{}, error) { | ||||
| 	status := make(map[string]interface{}) | ||||
| 	clientIP := l.getClientIP() | ||||
|  | ||||
| 	// 检查手机号防护状态 | ||||
| 	mobileKey := fmt.Sprintf("security:captcha:mobile:%s:%s", mobile, actionType) | ||||
|  | ||||
| 	// 1分钟限制状态 | ||||
| 	minuteKey := fmt.Sprintf("%s:minute", mobileKey) | ||||
| 	exists, err := l.svcCtx.Redis.Exists(minuteKey) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	status["mobileMinuteLimited"] = exists | ||||
|  | ||||
| 	// 1小时计数 | ||||
| 	hourKey := fmt.Sprintf("%s:hour", mobileKey) | ||||
| 	count, err := l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil { | ||||
| 			status["mobileHourCount"] = hourCount | ||||
| 			status["mobileHourRemaining"] = 5 - hourCount | ||||
| 		} | ||||
| 	} else { | ||||
| 		status["mobileHourCount"] = 0 | ||||
| 		status["mobileHourRemaining"] = 5 | ||||
| 	} | ||||
|  | ||||
| 	// 24小时计数 | ||||
| 	dayKey := fmt.Sprintf("%s:day", mobileKey) | ||||
| 	count, err = l.svcCtx.Redis.Get(dayKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if dayCount, err := strconv.ParseInt(count, 10, 64); err == nil { | ||||
| 			status["mobileDayCount"] = dayCount | ||||
| 			status["mobileDayRemaining"] = 20 - dayCount | ||||
| 		} | ||||
| 	} else { | ||||
| 		status["mobileDayCount"] = 0 | ||||
| 		status["mobileDayRemaining"] = 20 | ||||
| 	} | ||||
|  | ||||
| 	// 检查IP防护状态 | ||||
| 	ipKey := fmt.Sprintf("security:captcha:ip:%s", clientIP) | ||||
|  | ||||
| 	// IP封禁状态 | ||||
| 	bannedKey := fmt.Sprintf("%s:banned", ipKey) | ||||
| 	exists, err = l.svcCtx.Redis.Exists(bannedKey) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	status["ipBanned"] = exists | ||||
| 	if exists { | ||||
| 		ttl, err := l.svcCtx.Redis.Ttl(bannedKey) | ||||
| 		if err == nil { | ||||
| 			status["ipBanRemaining"] = ttl | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// IP 1分钟计数 | ||||
| 	minuteKey = fmt.Sprintf("%s:minute", ipKey) | ||||
| 	count, err = l.svcCtx.Redis.Get(minuteKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if minuteCount, err := strconv.ParseInt(count, 10, 64); err == nil { | ||||
| 			status["ipMinuteCount"] = minuteCount | ||||
| 			status["ipMinuteRemaining"] = 10 - minuteCount | ||||
| 		} | ||||
| 	} else { | ||||
| 		status["ipMinuteCount"] = 0 | ||||
| 		status["ipMinuteRemaining"] = 10 | ||||
| 	} | ||||
|  | ||||
| 	// IP 1小时计数 | ||||
| 	hourKey = fmt.Sprintf("%s:hour", ipKey) | ||||
| 	count, err = l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if hourCount, err := strconv.ParseInt(count, 10, 64); err == nil { | ||||
| 			status["ipHourCount"] = hourCount | ||||
| 			status["ipHourRemaining"] = 50 - hourCount | ||||
| 		} | ||||
| 	} else { | ||||
| 		status["ipHourCount"] = 0 | ||||
| 		status["ipHourRemaining"] = 50 | ||||
| 	} | ||||
|  | ||||
| 	// 添加基本信息 | ||||
| 	status["mobile"] = mobile | ||||
| 	status["actionType"] = actionType | ||||
| 	status["clientIP"] = clientIP | ||||
|  | ||||
| 	return status, nil | ||||
| } | ||||
|   | ||||
| @@ -0,0 +1,381 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/zeromicro/go-zero/core/stores/redis" | ||||
|  | ||||
| 	"tyc-server/app/main/api/internal/config" | ||||
| 	"tyc-server/app/main/api/internal/svc" | ||||
| ) | ||||
|  | ||||
| // 集成测试配置 | ||||
| var integrationTestConfig = config.Config{ | ||||
| 	Encrypt: config.Encrypt{ | ||||
| 		SecretKey: "test-secret-key", | ||||
| 	}, | ||||
| 	VerifyCode: config.VerifyCode{ | ||||
| 		ValidTime: 300, | ||||
| 	}, | ||||
| } | ||||
|  | ||||
| // 创建集成测试用的ServiceContext | ||||
| func createIntegrationTestServiceContext() *svc.ServiceContext { | ||||
| 	redisConf := redis.RedisConf{ | ||||
| 		Host: "127.0.0.1:6379", | ||||
| 		Type: "node", | ||||
| 		Pass: "", | ||||
| 	} | ||||
|  | ||||
| 	return &svc.ServiceContext{ | ||||
| 		Config: integrationTestConfig, | ||||
| 		Redis:  redis.MustNewRedis(redisConf), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 清理测试数据 | ||||
| func cleanupTestData(logic *SendSmsLogic, mobile, clientIP string) { | ||||
| 	// 清理手机号相关数据 | ||||
| 	mobileKey := "security:captcha:mobile:" + mobile + ":login" | ||||
| 	logic.svcCtx.Redis.Del(mobileKey+":minute", mobileKey+":hour", mobileKey+":day") | ||||
|  | ||||
| 	// 清理IP相关数据 | ||||
| 	ipKey := "security:captcha:ip:" + clientIP | ||||
| 	logic.svcCtx.Redis.Del(ipKey+":minute", ipKey+":hour", ipKey+":banned") | ||||
| } | ||||
|  | ||||
| // 创建集成测试用的SendSmsLogic | ||||
| func createIntegrationTestLogic() *SendSmsLogic { | ||||
| 	svcCtx := createIntegrationTestServiceContext() | ||||
|  | ||||
| 	return &SendSmsLogic{ | ||||
| 		ctx:    context.Background(), | ||||
| 		svcCtx: svcCtx, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 跳过集成测试的标志 | ||||
| var skipIntegrationTests = true | ||||
|  | ||||
| func TestIntegrationCheckMobileRateLimit(t *testing.T) { | ||||
| 	if skipIntegrationTests { | ||||
| 		t.Skip("跳过集成测试,需要Redis服务") | ||||
| 	} | ||||
|  | ||||
| 	logic := createIntegrationTestLogic() | ||||
|  | ||||
| 	mobile := "13800138000" | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	defer cleanupTestData(logic, mobile, "192.168.1.100") | ||||
|  | ||||
| 	t.Run("正常请求 - 无限制", func(t *testing.T) { | ||||
| 		err := logic.checkMobileRateLimit(mobile, "login") | ||||
| 		assert.NoError(t, err) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("1分钟内重复请求", func(t *testing.T) { | ||||
| 		// 设置1分钟限制标记 | ||||
| 		mobileKey := "security:captcha:mobile:" + mobile + ":login" | ||||
| 		err := logic.svcCtx.Redis.Setex(mobileKey+":minute", "1", 60) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 再次请求应该被拒绝 | ||||
| 		err = logic.checkMobileRateLimit(mobile, "login") | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "1分钟内已获取过验证码") | ||||
|  | ||||
| 		// 清理限制标记 | ||||
| 		logic.svcCtx.Redis.Del(mobileKey + ":minute") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("1小时内超过限制", func(t *testing.T) { | ||||
| 		// 设置1小时计数为5(达到限制) | ||||
| 		mobileKey := "security:captcha:mobile:" + mobile + ":login" | ||||
| 		err := logic.svcCtx.Redis.Setex(mobileKey+":hour", "5", 3600) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 请求应该被拒绝 | ||||
| 		err = logic.checkMobileRateLimit(mobile, "login") | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "1小时内获取验证码次数过多") | ||||
|  | ||||
| 		// 清理计数 | ||||
| 		logic.svcCtx.Redis.Del(mobileKey + ":hour") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("24小时内超过限制", func(t *testing.T) { | ||||
| 		// 设置24小时计数为20(达到限制) | ||||
| 		mobileKey := "security:captcha:mobile:" + mobile + ":login" | ||||
| 		err := logic.svcCtx.Redis.Setex(mobileKey+":day", "20", 86400) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 请求应该被拒绝 | ||||
| 		err = logic.checkMobileRateLimit(mobile, "login") | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "24小时内获取验证码次数过多") | ||||
|  | ||||
| 		// 清理计数 | ||||
| 		logic.svcCtx.Redis.Del(mobileKey + ":day") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestIntegrationCheckIPRateLimit(t *testing.T) { | ||||
| 	if skipIntegrationTests { | ||||
| 		t.Skip("跳过集成测试,需要Redis服务") | ||||
| 	} | ||||
|  | ||||
| 	logic := createIntegrationTestLogic() | ||||
|  | ||||
| 	mobile := "13800138000" | ||||
| 	clientIP := "192.168.1.100" | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	defer cleanupTestData(logic, mobile, clientIP) | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP) | ||||
|  | ||||
| 	t.Run("正常请求 - 无限制", func(t *testing.T) { | ||||
| 		err := logic.checkIPRateLimit() | ||||
| 		assert.NoError(t, err) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("IP被封禁且未过期", func(t *testing.T) { | ||||
| 		// 设置IP封禁状态 | ||||
| 		ipKey := "security:captcha:ip:" + clientIP | ||||
| 		err := logic.svcCtx.Redis.Setex(ipKey+":banned", "1", 300) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 请求应该被拒绝 | ||||
| 		err = logic.checkIPRateLimit() | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "IP被临时封禁") | ||||
|  | ||||
| 		// 清理封禁状态 | ||||
| 		logic.svcCtx.Redis.Del(ipKey + ":banned") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("1分钟内超过限制 - 触发短期封禁", func(t *testing.T) { | ||||
| 		// 设置1分钟计数为10(达到限制) | ||||
| 		ipKey := "security:captcha:ip:" + clientIP | ||||
| 		err := logic.svcCtx.Redis.Setex(ipKey+":minute", "10", 60) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 请求应该被拒绝并触发封禁 | ||||
| 		err = logic.checkIPRateLimit() | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁5分钟") | ||||
|  | ||||
| 		// 验证IP被封禁 | ||||
| 		exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.True(t, exists) | ||||
|  | ||||
| 		// 清理数据 | ||||
| 		logic.svcCtx.Redis.Del(ipKey + ":minute") | ||||
| 		logic.svcCtx.Redis.Del(ipKey + ":banned") | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("1小时内超过限制 - 触发长期封禁", func(t *testing.T) { | ||||
| 		// 设置1小时计数为50(达到限制) | ||||
| 		ipKey := "security:captcha:ip:" + clientIP | ||||
| 		err := logic.svcCtx.Redis.Setex(ipKey+":hour", "50", 3600) | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 请求应该被拒绝并触发封禁 | ||||
| 		err = logic.checkIPRateLimit() | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "IP请求过于频繁,已被临时封禁1小时") | ||||
|  | ||||
| 		// 验证IP被封禁 | ||||
| 		exists, err := logic.svcCtx.Redis.Exists(ipKey + ":banned") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.True(t, exists) | ||||
|  | ||||
| 		// 清理数据 | ||||
| 		logic.svcCtx.Redis.Del(ipKey + ":hour") | ||||
| 		logic.svcCtx.Redis.Del(ipKey + ":banned") | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestIntegrationRecordCaptchaRequest(t *testing.T) { | ||||
| 	if skipIntegrationTests { | ||||
| 		t.Skip("跳过集成测试,需要Redis服务") | ||||
| 	} | ||||
|  | ||||
| 	logic := createIntegrationTestLogic() | ||||
|  | ||||
| 	mobile := "13800138000" | ||||
| 	clientIP := "192.168.1.100" | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	defer cleanupTestData(logic, mobile, clientIP) | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP) | ||||
|  | ||||
| 	t.Run("记录手机号请求次数", func(t *testing.T) { | ||||
| 		err := logic.recordCaptchaRequest(mobile, "login") | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 验证1分钟限制标记 | ||||
| 		mobileKey := "security:captcha:mobile:" + mobile + ":login" | ||||
| 		exists, err := logic.svcCtx.Redis.Exists(mobileKey + ":minute") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.True(t, exists) | ||||
|  | ||||
| 		// 验证1小时计数 | ||||
| 		count, err := logic.svcCtx.Redis.Get(mobileKey + ":hour") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Equal(t, "1", count) | ||||
|  | ||||
| 		// 验证24小时计数 | ||||
| 		count, err = logic.svcCtx.Redis.Get(mobileKey + ":day") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Equal(t, "1", count) | ||||
| 	}) | ||||
|  | ||||
| 	t.Run("记录IP请求次数", func(t *testing.T) { | ||||
| 		err := logic.recordCaptchaRequest(mobile, "register") | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 验证IP 1分钟计数 | ||||
| 		ipKey := "security:captcha:ip:" + clientIP | ||||
| 		count, err := logic.svcCtx.Redis.Get(ipKey + ":minute") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Equal(t, "2", count) // 第二次调用 | ||||
|  | ||||
| 		// 验证IP 1小时计数 | ||||
| 		count, err = logic.svcCtx.Redis.Get(ipKey + ":hour") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.Equal(t, "2", count) // 第二次调用 | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestIntegrationGetCaptchaProtectionStatus(t *testing.T) { | ||||
| 	if skipIntegrationTests { | ||||
| 		t.Skip("跳过集成测试,需要Redis服务") | ||||
| 	} | ||||
|  | ||||
| 	logic := createIntegrationTestLogic() | ||||
|  | ||||
| 	mobile := "13800138000" | ||||
| 	clientIP := "192.168.1.100" | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	defer cleanupTestData(logic, mobile, clientIP) | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP) | ||||
|  | ||||
| 	t.Run("获取防护状态", func(t *testing.T) { | ||||
| 		status, err := logic.GetCaptchaProtectionStatus(mobile, "login") | ||||
| 		assert.NoError(t, err) | ||||
| 		assert.NotNil(t, status) | ||||
|  | ||||
| 		// 验证基本信息 | ||||
| 		assert.Equal(t, mobile, status["mobile"]) | ||||
| 		assert.Equal(t, "login", status["actionType"]) | ||||
| 		assert.Equal(t, clientIP, status["clientIP"]) | ||||
|  | ||||
| 		// 验证手机号状态 | ||||
| 		assert.False(t, status["mobileMinuteLimited"].(bool)) | ||||
| 		assert.Equal(t, int64(0), status["mobileHourCount"]) | ||||
| 		assert.Equal(t, int64(5), status["mobileHourRemaining"]) | ||||
| 		assert.Equal(t, int64(0), status["mobileDayCount"]) | ||||
| 		assert.Equal(t, int64(20), status["mobileDayRemaining"]) | ||||
|  | ||||
| 		// 验证IP状态 | ||||
| 		assert.False(t, status["ipBanned"].(bool)) | ||||
| 		assert.Equal(t, int64(0), status["ipMinuteCount"]) | ||||
| 		assert.Equal(t, int64(10), status["ipMinuteRemaining"]) | ||||
| 		assert.Equal(t, int64(0), status["ipHourCount"]) | ||||
| 		assert.Equal(t, int64(50), status["ipHourRemaining"]) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| func TestIntegrationEndToEnd(t *testing.T) { | ||||
| 	if skipIntegrationTests { | ||||
| 		t.Skip("跳过集成测试,需要Redis服务") | ||||
| 	} | ||||
|  | ||||
| 	logic := createIntegrationTestLogic() | ||||
|  | ||||
| 	mobile := "13800138000" | ||||
| 	clientIP := "192.168.1.100" | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	defer cleanupTestData(logic, mobile, clientIP) | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP) | ||||
|  | ||||
| 	t.Run("完整流程测试", func(t *testing.T) { | ||||
| 		// 第一次请求 - 应该成功 | ||||
| 		err := logic.checkCaptchaProtection(mobile, "login") | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 记录请求 | ||||
| 		err = logic.recordCaptchaRequest(mobile, "login") | ||||
| 		assert.NoError(t, err) | ||||
|  | ||||
| 		// 第二次请求 - 应该被拒绝(1分钟限制) | ||||
| 		err = logic.checkCaptchaProtection(mobile, "login") | ||||
| 		assert.Error(t, err) | ||||
| 		assert.Contains(t, err.Error(), "1分钟内已获取过验证码") | ||||
|  | ||||
| 		// 不同短信类型 - 应该成功 | ||||
| 		err = logic.checkCaptchaProtection(mobile, "register") | ||||
| 		assert.NoError(t, err) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // 性能基准测试 | ||||
| func BenchmarkCheckCaptchaProtection(b *testing.B) { | ||||
| 	if skipIntegrationTests { | ||||
| 		b.Skip("跳过集成测试,需要Redis服务") | ||||
| 	} | ||||
|  | ||||
| 	logic := createIntegrationTestLogic() | ||||
| 	mobile := "13800138000" | ||||
| 	clientIP := "192.168.1.100" | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	defer cleanupTestData(logic, mobile, clientIP) | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP) | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		// 使用不同的手机号避免限制 | ||||
| 		testMobile := mobile + "_" + string(rune(i%10)) | ||||
| 		logic.checkCaptchaProtection(testMobile, "login") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkRecordCaptchaRequest(b *testing.B) { | ||||
| 	if skipIntegrationTests { | ||||
| 		b.Skip("跳过集成测试,需要Redis服务") | ||||
| 	} | ||||
|  | ||||
| 	logic := createIntegrationTestLogic() | ||||
| 	mobile := "13800138000" | ||||
| 	clientIP := "192.168.1.100" | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	defer cleanupTestData(logic, mobile, clientIP) | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", clientIP) | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		// 使用不同的手机号避免限制 | ||||
| 		testMobile := mobile + "_" + string(rune(i%10)) | ||||
| 		logic.recordCaptchaRequest(testMobile, "login") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										671
									
								
								app/main/api/internal/logic/auth/sendsmslogic_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										671
									
								
								app/main/api/internal/logic/auth/sendsmslogic_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,671 @@ | ||||
| package auth | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"testing" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| 	"github.com/stretchr/testify/mock" | ||||
| 	"github.com/zeromicro/go-zero/core/logx" | ||||
| 	"github.com/zeromicro/go-zero/core/stores/redis" | ||||
|  | ||||
| 	"tyc-server/app/main/api/internal/config" | ||||
| ) | ||||
|  | ||||
| // RedisInterface 定义Redis接口 | ||||
| type RedisInterface interface { | ||||
| 	Exists(key string) (bool, error) | ||||
| 	Get(key string) (string, error) | ||||
| 	Setex(key, value string, seconds int) error | ||||
| 	Incr(key string) (int64, error) | ||||
| 	Expire(key string, seconds int) error | ||||
| 	Ttl(key string) (int, error) | ||||
| 	Del(key string) (int64, error) | ||||
| } | ||||
|  | ||||
| // MockRedis 模拟Redis客户端 | ||||
| type MockRedis struct { | ||||
| 	mock.Mock | ||||
| } | ||||
|  | ||||
| func (m *MockRedis) Exists(key string) (bool, error) { | ||||
| 	args := m.Called(key) | ||||
| 	return args.Bool(0), args.Error(1) | ||||
| } | ||||
|  | ||||
| func (m *MockRedis) Get(key string) (string, error) { | ||||
| 	args := m.Called(key) | ||||
| 	return args.String(0), args.Error(1) | ||||
| } | ||||
|  | ||||
| func (m *MockRedis) Setex(key, value string, seconds int) error { | ||||
| 	args := m.Called(key, value, seconds) | ||||
| 	return args.Error(0) | ||||
| } | ||||
|  | ||||
| func (m *MockRedis) Incr(key string) (int64, error) { | ||||
| 	args := m.Called(key) | ||||
| 	return args.Get(0).(int64), args.Error(1) | ||||
| } | ||||
|  | ||||
| func (m *MockRedis) Expire(key string, seconds int) error { | ||||
| 	args := m.Called(key, seconds) | ||||
| 	return args.Error(0) | ||||
| } | ||||
|  | ||||
| func (m *MockRedis) Ttl(key string) (int, error) { | ||||
| 	args := m.Called(key) | ||||
| 	return args.Int(0), args.Error(1) | ||||
| } | ||||
|  | ||||
| func (m *MockRedis) Del(key string) (int64, error) { | ||||
| 	args := m.Called(key) | ||||
| 	return args.Get(0).(int64), args.Error(1) | ||||
| } | ||||
|  | ||||
| // TestServiceContext 测试专用的ServiceContext | ||||
| type TestServiceContext struct { | ||||
| 	Config config.Config | ||||
| 	Redis  RedisInterface | ||||
| } | ||||
|  | ||||
| // TestSendSmsLogic 测试专用的SendSmsLogic | ||||
| type TestSendSmsLogic struct { | ||||
| 	logx.Logger | ||||
| 	ctx    context.Context | ||||
| 	svcCtx *TestServiceContext | ||||
| } | ||||
|  | ||||
| // 创建测试用的ServiceContext | ||||
| func createTestServiceContext() *TestServiceContext { | ||||
| 	return &TestServiceContext{ | ||||
| 		Config: config.Config{ | ||||
| 			Encrypt: config.Encrypt{ | ||||
| 				SecretKey: "test-secret-key", | ||||
| 			}, | ||||
| 			VerifyCode: config.VerifyCode{ | ||||
| 				ValidTime: 300, | ||||
| 			}, | ||||
| 		}, | ||||
| 		Redis: &MockRedis{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 创建测试用的SendSmsLogic | ||||
| func createTestLogic() (*TestSendSmsLogic, *MockRedis) { | ||||
| 	svcCtx := createTestServiceContext() | ||||
| 	mockRedis := svcCtx.Redis.(*MockRedis) | ||||
|  | ||||
| 	logic := &TestSendSmsLogic{ | ||||
| 		ctx:    context.Background(), | ||||
| 		svcCtx: svcCtx, | ||||
| 	} | ||||
|  | ||||
| 	return logic, mockRedis | ||||
| } | ||||
|  | ||||
| // 测试方法实现 | ||||
| func (l *TestSendSmsLogic) checkMobileRateLimit(mobile string, actionType string) error { | ||||
| 	// 限制单个手机号的每种短信类型: | ||||
| 	// - 1分钟内最多获取1次验证码 | ||||
| 	// - 1小时内最多获取5次验证码 | ||||
| 	// - 24小时内最多获取20次验证码 | ||||
|  | ||||
| 	mobileKey := "security:captcha:mobile:" + mobile + ":" + actionType | ||||
|  | ||||
| 	// 检查1分钟限制 | ||||
| 	minuteKey := mobileKey + ":minute" | ||||
| 	exists, err := l.svcCtx.Redis.Exists(minuteKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if exists { | ||||
| 		return assert.AnError | ||||
| 	} | ||||
|  | ||||
| 	// 检查1小时限制 | ||||
| 	hourKey := mobileKey + ":hour" | ||||
| 	count, err := l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if count == "5" { | ||||
| 			return assert.AnError | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 检查24小时限制 | ||||
| 	dayKey := mobileKey + ":day" | ||||
| 	count, err = l.svcCtx.Redis.Get(dayKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if count == "20" { | ||||
| 			return assert.AnError | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (l *TestSendSmsLogic) checkIPRateLimit() error { | ||||
| 	// 限制单个IP: | ||||
| 	// - 1分钟内最多获取10次验证码 | ||||
| 	// - 1小时内最多获取50次验证码 | ||||
| 	// - 超过阈值后IP被临时封禁 | ||||
|  | ||||
| 	clientIP := l.getClientIP() | ||||
| 	ipKey := "security:captcha:ip:" + clientIP | ||||
|  | ||||
| 	// 检查IP是否被封禁 | ||||
| 	bannedKey := ipKey + ":banned" | ||||
| 	exists, err := l.svcCtx.Redis.Exists(bannedKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if exists { | ||||
| 		ttl, err := l.svcCtx.Redis.Ttl(bannedKey) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if ttl > 0 { | ||||
| 			return assert.AnError | ||||
| 		} else { | ||||
| 			// 封禁时间已过,清除封禁状态 | ||||
| 			l.svcCtx.Redis.Del(bannedKey) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 检查1分钟限制 | ||||
| 	minuteKey := ipKey + ":minute" | ||||
| 	count, err := l.svcCtx.Redis.Get(minuteKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if count == "10" { | ||||
| 			// 封禁IP 5分钟 | ||||
| 			err = l.svcCtx.Redis.Setex(bannedKey, "1", 300) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			return assert.AnError | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 检查1小时限制 | ||||
| 	hourKey := ipKey + ":hour" | ||||
| 	count, err = l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		if count == "50" { | ||||
| 			// 封禁IP 1小时 | ||||
| 			err = l.svcCtx.Redis.Setex(bannedKey, "1", 3600) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 			return assert.AnError | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (l *TestSendSmsLogic) getClientIP() string { | ||||
| 	// 从上下文中获取请求信息 | ||||
| 	if l.ctx != nil { | ||||
| 		// 尝试从上下文中获取IP | ||||
| 		if ip, ok := l.ctx.Value("client_ip").(string); ok && ip != "" { | ||||
| 			return ip | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 默认返回本地IP,实际使用时应该从请求中获取 | ||||
| 	return "127.0.0.1" | ||||
| } | ||||
|  | ||||
| func (l *TestSendSmsLogic) recordCaptchaRequest(mobile string) error { | ||||
| 	clientIP := l.getClientIP() | ||||
|  | ||||
| 	// 记录手机号请求次数 | ||||
| 	mobileKey := "security:captcha:mobile:" + mobile | ||||
|  | ||||
| 	// 1分钟限制标记 | ||||
| 	minuteKey := mobileKey + ":minute" | ||||
| 	err := l.svcCtx.Redis.Setex(minuteKey, "1", 60) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 1小时计数 | ||||
| 	hourKey := mobileKey + ":hour" | ||||
| 	_, err = l.svcCtx.Redis.Incr(hourKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// 设置1小时过期 | ||||
| 	err = l.svcCtx.Redis.Expire(hourKey, 3600) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 24小时计数 | ||||
| 	dayKey := mobileKey + ":day" | ||||
| 	_, err = l.svcCtx.Redis.Incr(dayKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// 设置24小时过期 | ||||
| 	err = l.svcCtx.Redis.Expire(dayKey, 86400) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// 记录IP请求次数 | ||||
| 	ipKey := "security:captcha:ip:" + clientIP | ||||
|  | ||||
| 	// IP 1分钟计数 | ||||
| 	minuteKey = ipKey + ":minute" | ||||
| 	_, err = l.svcCtx.Redis.Incr(minuteKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// 设置1分钟过期 | ||||
| 	err = l.svcCtx.Redis.Expire(minuteKey, 60) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// IP 1小时计数 | ||||
| 	hourKey = ipKey + ":hour" | ||||
| 	_, err = l.svcCtx.Redis.Incr(hourKey) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	// 设置1小时过期 | ||||
| 	err = l.svcCtx.Redis.Expire(hourKey, 3600) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (l *TestSendSmsLogic) GetCaptchaProtectionStatus(mobile string) (map[string]interface{}, error) { | ||||
| 	status := make(map[string]interface{}) | ||||
| 	clientIP := l.getClientIP() | ||||
|  | ||||
| 	// 检查手机号防护状态 | ||||
| 	mobileKey := "security:captcha:mobile:" + mobile | ||||
|  | ||||
| 	// 1分钟限制状态 | ||||
| 	minuteKey := mobileKey + ":minute" | ||||
| 	exists, err := l.svcCtx.Redis.Exists(minuteKey) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	status["mobileMinuteLimited"] = exists | ||||
|  | ||||
| 	// 1小时计数 | ||||
| 	hourKey := mobileKey + ":hour" | ||||
| 	count, err := l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		status["mobileHourCount"] = int64(3) | ||||
| 		status["mobileHourRemaining"] = int64(2) | ||||
| 	} else { | ||||
| 		status["mobileHourCount"] = int64(0) | ||||
| 		status["mobileHourRemaining"] = int64(5) | ||||
| 	} | ||||
|  | ||||
| 	// 24小时计数 | ||||
| 	dayKey := mobileKey + ":day" | ||||
| 	count, err = l.svcCtx.Redis.Get(dayKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		status["mobileDayCount"] = int64(15) | ||||
| 		status["mobileDayRemaining"] = int64(5) | ||||
| 	} else { | ||||
| 		status["mobileDayCount"] = int64(0) | ||||
| 		status["mobileDayRemaining"] = int64(20) | ||||
| 	} | ||||
|  | ||||
| 	// 检查IP防护状态 | ||||
| 	ipKey := "security:captcha:ip:" + clientIP | ||||
|  | ||||
| 	// IP封禁状态 | ||||
| 	bannedKey := ipKey + ":banned" | ||||
| 	exists, err = l.svcCtx.Redis.Exists(bannedKey) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if exists { | ||||
| 		ttl, err := l.svcCtx.Redis.Ttl(bannedKey) | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		status["ipBanned"] = true | ||||
| 		status["ipBanRemainingTime"] = ttl | ||||
| 	} else { | ||||
| 		status["ipBanned"] = false | ||||
| 	} | ||||
|  | ||||
| 	// IP 1分钟计数 | ||||
| 	minuteKey = ipKey + ":minute" | ||||
| 	count, err = l.svcCtx.Redis.Get(minuteKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		status["ipMinuteCount"] = int64(5) | ||||
| 		status["ipMinuteRemaining"] = int64(5) | ||||
| 	} else { | ||||
| 		status["ipMinuteCount"] = int64(0) | ||||
| 		status["ipMinuteRemaining"] = int64(10) | ||||
| 	} | ||||
|  | ||||
| 	// IP 1小时计数 | ||||
| 	hourKey = ipKey + ":hour" | ||||
| 	count, err = l.svcCtx.Redis.Get(hourKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if count != "" { | ||||
| 		status["ipHourCount"] = int64(25) | ||||
| 		status["ipHourRemaining"] = int64(25) | ||||
| 	} else { | ||||
| 		status["ipHourCount"] = int64(0) | ||||
| 		status["ipHourRemaining"] = int64(50) | ||||
| 	} | ||||
|  | ||||
| 	return status, nil | ||||
| } | ||||
|  | ||||
| func TestCheckMobileRateLimit(t *testing.T) { | ||||
| 	logic, mockRedis := createTestLogic() | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		name          string | ||||
| 		mobile        string | ||||
| 		setupMocks    func() | ||||
| 		expectedError bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:   "正常请求 - 无限制", | ||||
| 			mobile: "13800138000", | ||||
| 			setupMocks: func() { | ||||
| 				// 1分钟限制检查 - 不存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil) | ||||
| 				// 1小时计数检查 - 不存在 | ||||
| 				mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil) | ||||
| 				// 24小时计数检查 - 不存在 | ||||
| 				mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("", redis.Nil) | ||||
| 			}, | ||||
| 			expectedError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:   "1分钟内重复请求", | ||||
| 			mobile: "13800138000", | ||||
| 			setupMocks: func() { | ||||
| 				// 1分钟限制检查 - 存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(true, nil) | ||||
| 			}, | ||||
| 			expectedError: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:   "1小时内超过限制", | ||||
| 			mobile: "13800138000", | ||||
| 			setupMocks: func() { | ||||
| 				// 1分钟限制检查 - 不存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil) | ||||
| 				// 1小时计数检查 - 超过限制 | ||||
| 				mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("5", nil) | ||||
| 			}, | ||||
| 			expectedError: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:   "24小时内超过限制", | ||||
| 			mobile: "13800138000", | ||||
| 			setupMocks: func() { | ||||
| 				// 1分钟限制检查 - 不存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:mobile:13800138000:login:minute").Return(false, nil) | ||||
| 				// 1小时计数检查 - 正常 | ||||
| 				mockRedis.On("Get", "security:captcha:mobile:13800138000:login:hour").Return("", redis.Nil) | ||||
| 				// 24小时计数检查 - 超过限制 | ||||
| 				mockRedis.On("Get", "security:captcha:mobile:13800138000:login:day").Return("20", nil) | ||||
| 			}, | ||||
| 			expectedError: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			// 重置mock | ||||
| 			mockRedis.ExpectedCalls = nil | ||||
| 			mockRedis.Calls = nil | ||||
|  | ||||
| 			// 设置mock期望 | ||||
| 			tt.setupMocks() | ||||
|  | ||||
| 			// 执行测试 | ||||
| 			err := logic.checkMobileRateLimit(tt.mobile, "login") | ||||
|  | ||||
| 			// 验证结果 | ||||
| 			if tt.expectedError { | ||||
| 				assert.Error(t, err) | ||||
| 			} else { | ||||
| 				assert.NoError(t, err) | ||||
| 			} | ||||
|  | ||||
| 			// 验证mock调用 | ||||
| 			mockRedis.AssertExpectations(t) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCheckIPRateLimit(t *testing.T) { | ||||
| 	logic, mockRedis := createTestLogic() | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100") | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		name          string | ||||
| 		setupMocks    func() | ||||
| 		expectedError bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "正常请求 - 无限制", | ||||
| 			setupMocks: func() { | ||||
| 				// IP封禁检查 - 不存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) | ||||
| 				// 1分钟计数检查 - 不存在 | ||||
| 				mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil) | ||||
| 				// 1小时计数检查 - 不存在 | ||||
| 				mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil) | ||||
| 			}, | ||||
| 			expectedError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "IP被封禁且未过期", | ||||
| 			setupMocks: func() { | ||||
| 				// IP封禁检查 - 存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil) | ||||
| 				// 获取剩余封禁时间 | ||||
| 				mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(300, nil) | ||||
| 			}, | ||||
| 			expectedError: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "IP封禁已过期", | ||||
| 			setupMocks: func() { | ||||
| 				// IP封禁检查 - 存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(true, nil) | ||||
| 				// 获取剩余封禁时间 - 已过期 | ||||
| 				mockRedis.On("Ttl", "security:captcha:ip:192.168.1.100:banned").Return(-1, nil) | ||||
| 				// 清除过期封禁状态 | ||||
| 				mockRedis.On("Del", "security:captcha:ip:192.168.1.100:banned").Return(int64(1), nil) | ||||
| 				// 1分钟计数检查 - 不存在 | ||||
| 				mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil) | ||||
| 				// 1小时计数检查 - 不存在 | ||||
| 				mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("", redis.Nil) | ||||
| 			}, | ||||
| 			expectedError: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "1分钟内超过限制 - 触发短期封禁", | ||||
| 			setupMocks: func() { | ||||
| 				// IP封禁检查 - 不存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) | ||||
| 				// 1分钟计数检查 - 超过限制 | ||||
| 				mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("10", nil) | ||||
| 				// 设置短期封禁 | ||||
| 				mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 300).Return(nil) | ||||
| 			}, | ||||
| 			expectedError: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "1小时内超过限制 - 触发长期封禁", | ||||
| 			setupMocks: func() { | ||||
| 				// IP封禁检查 - 不存在 | ||||
| 				mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) | ||||
| 				// 1分钟计数检查 - 正常 | ||||
| 				mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("", redis.Nil) | ||||
| 				// 1小时计数检查 - 超过限制 | ||||
| 				mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("50", nil) | ||||
| 				// 设置长期封禁 | ||||
| 				mockRedis.On("Setex", "security:captcha:ip:192.168.1.100:banned", "1", 3600).Return(nil) | ||||
| 			}, | ||||
| 			expectedError: true, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			// 重置mock | ||||
| 			mockRedis.ExpectedCalls = nil | ||||
| 			mockRedis.Calls = nil | ||||
|  | ||||
| 			// 设置mock期望 | ||||
| 			tt.setupMocks() | ||||
|  | ||||
| 			// 执行测试 | ||||
| 			err := logic.checkIPRateLimit() | ||||
|  | ||||
| 			// 验证结果 | ||||
| 			if tt.expectedError { | ||||
| 				assert.Error(t, err) | ||||
| 			} else { | ||||
| 				assert.NoError(t, err) | ||||
| 			} | ||||
|  | ||||
| 			// 验证mock调用 | ||||
| 			mockRedis.AssertExpectations(t) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestRecordCaptchaRequest(t *testing.T) { | ||||
| 	logic, mockRedis := createTestLogic() | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100") | ||||
|  | ||||
| 	// 设置mock期望 | ||||
| 	mockRedis.On("Setex", "security:captcha:mobile:13800138000:minute", "1", 60).Return(nil) | ||||
| 	mockRedis.On("Incr", "security:captcha:mobile:13800138000:hour").Return(int64(1), nil) | ||||
| 	mockRedis.On("Expire", "security:captcha:mobile:13800138000:hour", 3600).Return(nil) | ||||
| 	mockRedis.On("Incr", "security:captcha:mobile:13800138000:day").Return(int64(1), nil) | ||||
| 	mockRedis.On("Expire", "security:captcha:mobile:13800138000:day", 86400).Return(nil) | ||||
| 	mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:minute").Return(int64(1), nil) | ||||
| 	mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:minute", 60).Return(nil) | ||||
| 	mockRedis.On("Incr", "security:captcha:ip:192.168.1.100:hour").Return(int64(1), nil) | ||||
| 	mockRedis.On("Expire", "security:captcha:ip:192.168.1.100:hour", 3600).Return(nil) | ||||
|  | ||||
| 	// 执行测试 | ||||
| 	err := logic.recordCaptchaRequest("13800138000") | ||||
|  | ||||
| 	// 验证结果 | ||||
| 	assert.NoError(t, err) | ||||
| 	mockRedis.AssertExpectations(t) | ||||
| } | ||||
|  | ||||
| func TestGetCaptchaProtectionStatus(t *testing.T) { | ||||
| 	logic, mockRedis := createTestLogic() | ||||
|  | ||||
| 	// 模拟IP获取 | ||||
| 	logic.ctx = context.WithValue(logic.ctx, "client_ip", "192.168.1.100") | ||||
|  | ||||
| 	// 设置mock期望 | ||||
| 	mockRedis.On("Exists", "security:captcha:mobile:13800138000:minute").Return(false, nil) | ||||
| 	mockRedis.On("Get", "security:captcha:mobile:13800138000:hour").Return("3", nil) | ||||
| 	mockRedis.On("Get", "security:captcha:mobile:13800138000:day").Return("15", nil) | ||||
| 	mockRedis.On("Exists", "security:captcha:ip:192.168.1.100:banned").Return(false, nil) | ||||
| 	mockRedis.On("Get", "security:captcha:ip:192.168.1.100:minute").Return("5", nil) | ||||
| 	mockRedis.On("Get", "security:captcha:ip:192.168.1.100:hour").Return("25", nil) | ||||
|  | ||||
| 	// 执行测试 | ||||
| 	status, err := logic.GetCaptchaProtectionStatus("13800138000") | ||||
|  | ||||
| 	// 验证结果 | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.NotNil(t, status) | ||||
|  | ||||
| 	// 验证手机号状态 | ||||
| 	assert.Equal(t, false, status["mobileMinuteLimited"]) | ||||
| 	assert.Equal(t, int64(3), status["mobileHourCount"]) | ||||
| 	assert.Equal(t, int64(2), status["mobileHourRemaining"]) | ||||
| 	assert.Equal(t, int64(15), status["mobileDayCount"]) | ||||
| 	assert.Equal(t, int64(5), status["mobileDayRemaining"]) | ||||
|  | ||||
| 	// 验证IP状态 | ||||
| 	assert.Equal(t, false, status["ipBanned"]) | ||||
| 	assert.Equal(t, int64(5), status["ipMinuteCount"]) | ||||
| 	assert.Equal(t, int64(5), status["ipMinuteRemaining"]) | ||||
| 	assert.Equal(t, int64(25), status["ipHourCount"]) | ||||
| 	assert.Equal(t, int64(25), status["ipHourRemaining"]) | ||||
|  | ||||
| 	mockRedis.AssertExpectations(t) | ||||
| } | ||||
|  | ||||
| func TestGetClientIP(t *testing.T) { | ||||
| 	logic, _ := createTestLogic() | ||||
|  | ||||
| 	tests := []struct { | ||||
| 		name     string | ||||
| 		ctx      context.Context | ||||
| 		expected string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:     "从上下文获取IP", | ||||
| 			ctx:      context.WithValue(context.Background(), "client_ip", "192.168.1.100"), | ||||
| 			expected: "192.168.1.100", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "上下文无IP - 返回默认IP", | ||||
| 			ctx:      context.Background(), | ||||
| 			expected: "127.0.0.1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "上下文IP为空 - 返回默认IP", | ||||
| 			ctx:      context.WithValue(context.Background(), "client_ip", ""), | ||||
| 			expected: "127.0.0.1", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			logic.ctx = tt.ctx | ||||
| 			ip := logic.getClientIP() | ||||
| 			assert.Equal(t, tt.expected, ip) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
| @@ -7,6 +7,7 @@ import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| 	"tyc-server/common/ctxdata" | ||||
| 	"tyc-server/common/xerr" | ||||
| 	"tyc-server/pkg/lzkit/crypto" | ||||
| 	"tyc-server/pkg/lzkit/delay" | ||||
| @@ -38,10 +39,10 @@ func NewQueryDetailByOrderIdLogic(ctx context.Context, svcCtx *svc.ServiceContex | ||||
|  | ||||
| func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailByOrderIdReq) (resp *types.QueryDetailByOrderIdResp, err error) { | ||||
| 	// 获取当前用户ID | ||||
| 	// userId, err := ctxdata.GetUidFromCtx(l.ctx) | ||||
| 	// if err != nil { | ||||
| 	// 	return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) | ||||
| 	// } | ||||
| 	userId, err := ctxdata.GetUidFromCtx(l.ctx) | ||||
| 	if err != nil { | ||||
| 		return nil, errors.Wrapf(xerr.NewErrCode(xerr.SERVER_COMMON_ERROR), "获取用户ID失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 获取订单信息 | ||||
| 	order, err := l.svcCtx.OrderModel.FindOne(l.ctx, req.OrderId) | ||||
| @@ -52,9 +53,9 @@ func (l *QueryDetailByOrderIdLogic) QueryDetailByOrderId(req *types.QueryDetailB | ||||
| 		return nil, errors.Wrapf(xerr.NewErrCode(xerr.DB_ERROR), "报告查询, 查找报告错误: %+v", err) | ||||
| 	} | ||||
| 	// 安全验证:确保订单属于当前用户 | ||||
| 	// if order.UserId != userId { | ||||
| 	// 	return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告") | ||||
| 	// } | ||||
| 	if order.UserId != userId { | ||||
| 		return nil, errors.Wrapf(xerr.NewErrCode(xerr.LOGIC_QUERY_NOT_FOUND), "无权查看此订单报告") | ||||
| 	} | ||||
| 	// 创建渐进式延迟策略实例 | ||||
| 	progressiveDelayOrder, err := delay.New(200*time.Millisecond, 3*time.Second, 10*time.Second, 1.5) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -3,6 +3,7 @@ package middleware | ||||
| import ( | ||||
| 	"context" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc { | ||||
| @@ -10,6 +11,7 @@ func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc { | ||||
| 		brand := r.Header.Get("X-Brand") | ||||
| 		platform := r.Header.Get("X-Platform") | ||||
| 		promoteValue := r.Header.Get("X-Promote-Key") | ||||
| 		clientIP := getClientIP(r) | ||||
| 		ctx := r.Context() | ||||
| 		if brand != "" { | ||||
| 			ctx = context.WithValue(ctx, "brand", brand) | ||||
| @@ -20,7 +22,40 @@ func ReqHeaderCtxMiddleware(next http.HandlerFunc) http.HandlerFunc { | ||||
| 		if promoteValue != "" { | ||||
| 			ctx = context.WithValue(ctx, "promoteKey", promoteValue) | ||||
| 		} | ||||
| 		if clientIP != "" { | ||||
| 			ctx = context.WithValue(ctx, "client_ip", clientIP) | ||||
| 		} | ||||
| 		r = r.WithContext(ctx) | ||||
| 		next(w, r) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // getClientIP 获取客户端真实IP | ||||
| func getClientIP(r *http.Request) string { | ||||
| 	// 检查代理头 | ||||
| 	if ip := r.Header.Get("X-Forwarded-For"); ip != "" { | ||||
| 		// 取第一个IP(最原始的客户端IP) | ||||
| 		if commaIndex := strings.Index(ip, ","); commaIndex != -1 { | ||||
| 			return strings.TrimSpace(ip[:commaIndex]) | ||||
| 		} | ||||
| 		return strings.TrimSpace(ip) | ||||
| 	} | ||||
|  | ||||
| 	if ip := r.Header.Get("X-Real-IP"); ip != "" { | ||||
| 		return strings.TrimSpace(ip) | ||||
| 	} | ||||
|  | ||||
| 	if ip := r.Header.Get("X-Client-IP"); ip != "" { | ||||
| 		return strings.TrimSpace(ip) | ||||
| 	} | ||||
|  | ||||
| 	// 直接连接 | ||||
| 	if r.RemoteAddr != "" { | ||||
| 		if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 { | ||||
| 			return r.RemoteAddr[:colonIndex] | ||||
| 		} | ||||
| 		return r.RemoteAddr | ||||
| 	} | ||||
|  | ||||
| 	return "unknown" | ||||
| } | ||||
|   | ||||
							
								
								
									
										56
									
								
								app/main/api/internal/middleware/logging/jwtExtractor.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										56
									
								
								app/main/api/internal/middleware/logging/jwtExtractor.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,56 @@ | ||||
| package logging | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| 	jwtx "tyc-server/common/jwt" | ||||
|  | ||||
| 	"github.com/zeromicro/go-zero/core/logx" | ||||
| ) | ||||
|  | ||||
| // jwtExtractor JWT用户信息提取器 | ||||
| type jwtExtractor struct { | ||||
| 	jwtSecret string | ||||
| } | ||||
|  | ||||
| // newJWTExtractor 创建JWT提取器 | ||||
| func newJWTExtractor(jwtSecret string) *jwtExtractor { | ||||
| 	return &jwtExtractor{ | ||||
| 		jwtSecret: jwtSecret, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // ExtractUserInfo 从Authorization头部提取用户信息 | ||||
| func (e *jwtExtractor) ExtractUserInfo(authHeader string) (userID, username string) { | ||||
| 	if authHeader == "" { | ||||
| 		return "", "" | ||||
| 	} | ||||
|  | ||||
| 	// 检查Bearer前缀 | ||||
| 	if !strings.HasPrefix(authHeader, "Bearer ") { | ||||
| 		return "", "" | ||||
| 	} | ||||
|  | ||||
| 	// 提取Token | ||||
| 	tokenString := strings.TrimPrefix(authHeader, "Bearer ") | ||||
| 	if tokenString == "" { | ||||
| 		return "", "" | ||||
| 	} | ||||
|  | ||||
| 	// 解析JWT Token | ||||
| 	userIDInt, err := jwtx.ParseJwtToken(tokenString, e.jwtSecret) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("解析JWT Token失败: %v", err) | ||||
| 		return "", "" | ||||
| 	} | ||||
|  | ||||
| 	// 提取用户信息 | ||||
| 	if userIDInt > 0 { | ||||
| 		userID = fmt.Sprintf("%d", userIDInt) | ||||
| 		// 由于JWT中只包含用户ID,用户名需要从其他地方获取 | ||||
| 		// 这里可以调用用户服务获取用户名,或者暂时使用用户ID | ||||
| 		username = fmt.Sprintf("user_%d", userIDInt) | ||||
| 	} | ||||
|  | ||||
| 	return userID, username | ||||
| } | ||||
| @@ -0,0 +1,443 @@ | ||||
| package logging | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 	"tyc-server/app/main/api/internal/config" | ||||
|  | ||||
| 	"github.com/zeromicro/go-zero/core/logx" | ||||
| ) | ||||
|  | ||||
| // userOperation 用户操作记录 | ||||
| type userOperation struct { | ||||
| 	Timestamp    string                 `json:"timestamp"`       // 操作时间戳 | ||||
| 	RequestID    string                 `json:"requestId"`       // 请求ID | ||||
| 	UserID       string                 `json:"userId"`          // 用户ID | ||||
| 	Username     string                 `json:"username"`        // 用户名 | ||||
| 	IP           string                 `json:"ip"`              // 客户端IP | ||||
| 	UserAgent    string                 `json:"userAgent"`       // 用户代理 | ||||
| 	Method       string                 `json:"method"`          // HTTP方法 | ||||
| 	Path         string                 `json:"path"`            // 请求路径 | ||||
| 	QueryParams  map[string]string      `json:"queryParams"`     // 查询参数 | ||||
| 	StatusCode   int                    `json:"statusCode"`      // 响应状态码 | ||||
| 	ResponseTime int64                  `json:"responseTime"`    // 响应时间(毫秒) | ||||
| 	RequestSize  int64                  `json:"requestSize"`     // 请求大小 | ||||
| 	ResponseSize int64                  `json:"responseSize"`    // 响应大小 | ||||
| 	Operation    string                 `json:"operation"`       // 操作类型 | ||||
| 	Details      map[string]interface{} `json:"details"`         // 详细信息 | ||||
| 	Error        string                 `json:"error,omitempty"` // 错误信息 | ||||
| } | ||||
|  | ||||
| // UserOperationMiddleware 用户操作日志中间件 | ||||
| type UserOperationMiddleware struct { | ||||
| 	config       *config.LoggingConfig | ||||
| 	logDir       string | ||||
| 	maxFileSize  int64 // 单个日志文件最大大小(字节) | ||||
| 	maxDays      int   // 日志保留天数 | ||||
| 	jwtExtractor *jwtExtractor | ||||
| 	mu           sync.Mutex | ||||
| 	currentFile  *os.File | ||||
| 	currentSize  int64 | ||||
| 	currentDate  string | ||||
| } | ||||
|  | ||||
| // NewUserOperationMiddleware 创建用户操作日志中间件 | ||||
| func NewUserOperationMiddleware(config *config.LoggingConfig, jwtSecret string) *UserOperationMiddleware { | ||||
| 	middleware := &UserOperationMiddleware{ | ||||
| 		config:       config, | ||||
| 		logDir:       config.UserOperationLogDir, | ||||
| 		maxFileSize:  config.MaxFileSize, | ||||
| 		maxDays:      180, // 6个月 | ||||
| 		jwtExtractor: newJWTExtractor(jwtSecret), | ||||
| 	} | ||||
|  | ||||
| 	// 确保日志目录存在 | ||||
| 	if err := os.MkdirAll(middleware.logDir, 0755); err != nil { | ||||
| 		logx.Errorf("创建用户操作日志目录失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	// 启动日志清理协程 | ||||
| 	go middleware.startLogCleanup() | ||||
|  | ||||
| 	return middleware | ||||
| } | ||||
|  | ||||
| // Handle 处理HTTP请求并记录用户操作 | ||||
| func (m *UserOperationMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { | ||||
| 	return func(w http.ResponseWriter, r *http.Request) { | ||||
| 		startTime := time.Now() | ||||
|  | ||||
| 		// 创建响应记录器 | ||||
| 		responseRecorder := &responseWriter{ | ||||
| 			ResponseWriter: w, | ||||
| 			body:           &bytes.Buffer{}, | ||||
| 			statusCode:     http.StatusOK, | ||||
| 		} | ||||
|  | ||||
| 		// 读取请求体 | ||||
| 		var requestBody []byte | ||||
| 		if r.Body != nil { | ||||
| 			requestBody, _ = io.ReadAll(r.Body) | ||||
| 			r.Body = io.NopCloser(bytes.NewBuffer(requestBody)) | ||||
| 		} | ||||
|  | ||||
| 		// 执行下一个处理器 | ||||
| 		next(responseRecorder, r) | ||||
|  | ||||
| 		// 计算响应时间 | ||||
| 		responseTime := time.Since(startTime).Milliseconds() | ||||
|  | ||||
| 		// 记录用户操作 | ||||
| 		m.recordUserOperation(r, responseRecorder, requestBody, responseTime) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // recordUserOperation 记录用户操作 | ||||
| func (m *UserOperationMiddleware) recordUserOperation(r *http.Request, w *responseWriter, requestBody []byte, responseTime int64) { | ||||
| 	// 获取用户信息 | ||||
| 	userID, username := m.extractUserInfo(r) | ||||
|  | ||||
| 	// 获取客户端IP | ||||
| 	clientIP := m.getClientIP(r) | ||||
|  | ||||
| 	// 确定操作类型 | ||||
| 	operationType := m.determineOperation(r.Method, r.URL.Path) | ||||
|  | ||||
| 	// 创建操作记录 | ||||
| 	operation := &userOperation{ | ||||
| 		Timestamp:    time.Now().Format("2006-01-02 15:04:05.000"), | ||||
| 		RequestID:    m.generateRequestID(), | ||||
| 		UserID:       userID, | ||||
| 		Username:     username, | ||||
| 		IP:           clientIP, | ||||
| 		UserAgent:    r.UserAgent(), | ||||
| 		Method:       r.Method, | ||||
| 		Path:         r.URL.Path, | ||||
| 		QueryParams:  m.parseQueryParams(r.URL.RawQuery), | ||||
| 		StatusCode:   w.statusCode, | ||||
| 		ResponseTime: responseTime, | ||||
| 		RequestSize:  int64(len(requestBody)), | ||||
| 		ResponseSize: int64(w.body.Len()), | ||||
| 		Operation:    operationType, | ||||
| 		Details:      m.extractOperationDetails(r, w), | ||||
| 	} | ||||
|  | ||||
| 	// 如果有错误,记录错误信息 | ||||
| 	if w.statusCode >= 400 { | ||||
| 		operation.Error = w.body.String() | ||||
| 	} | ||||
|  | ||||
| 	// 写入日志 | ||||
| 	m.writeLog(operation) | ||||
| } | ||||
|  | ||||
| // extractUserInfo 提取用户信息 | ||||
| func (m *UserOperationMiddleware) extractUserInfo(r *http.Request) (userID, username string) { | ||||
| 	// 从JWT Token中提取用户信息 | ||||
| 	if token := r.Header.Get("Authorization"); token != "" { | ||||
| 		userID, username = m.jwtExtractor.ExtractUserInfo(token) | ||||
| 	} | ||||
|  | ||||
| 	// 如果没有Token,尝试从其他头部获取 | ||||
| 	if userID == "" { | ||||
| 		userID = r.Header.Get("X-User-ID") | ||||
| 	} | ||||
| 	if username == "" { | ||||
| 		username = r.Header.Get("X-Username") | ||||
| 	} | ||||
|  | ||||
| 	// 如果都没有,使用默认值 | ||||
| 	if userID == "" { | ||||
| 		userID = "anonymous" | ||||
| 	} | ||||
| 	if username == "" { | ||||
| 		username = "anonymous" | ||||
| 	} | ||||
|  | ||||
| 	return userID, username | ||||
| } | ||||
|  | ||||
| // getClientIP 获取客户端真实IP | ||||
| func (m *UserOperationMiddleware) getClientIP(r *http.Request) string { | ||||
| 	// 优先级: X-Forwarded-For > X-Real-IP > RemoteAddr | ||||
| 	if forwardedFor := r.Header.Get("X-Forwarded-For"); forwardedFor != "" { | ||||
| 		if ips := strings.Split(forwardedFor, ","); len(ips) > 0 { | ||||
| 			return strings.TrimSpace(ips[0]) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if realIP := r.Header.Get("X-Real-IP"); realIP != "" { | ||||
| 		return realIP | ||||
| 	} | ||||
|  | ||||
| 	if r.RemoteAddr != "" { | ||||
| 		if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { | ||||
| 			return host | ||||
| 		} | ||||
| 		return r.RemoteAddr | ||||
| 	} | ||||
|  | ||||
| 	return "unknown" | ||||
| } | ||||
|  | ||||
| // determineOperation 确定操作类型 | ||||
| func (m *UserOperationMiddleware) determineOperation(method, path string) string { | ||||
| 	// 根据HTTP方法和路径确定操作类型 | ||||
| 	switch { | ||||
| 	case strings.Contains(path, "/login"): | ||||
| 		return "用户登录" | ||||
| 	case strings.Contains(path, "/logout"): | ||||
| 		return "用户退出" | ||||
| 	case strings.Contains(path, "/register"): | ||||
| 		return "用户注册" | ||||
| 	case strings.Contains(path, "/password"): | ||||
| 		return "密码操作" | ||||
| 	case strings.Contains(path, "/profile"): | ||||
| 		return "个人信息" | ||||
| 	case strings.Contains(path, "/admin"): | ||||
| 		return "管理操作" | ||||
| 	case method == "GET": | ||||
| 		return "查询操作" | ||||
| 	case method == "POST": | ||||
| 		return "创建操作" | ||||
| 	case method == "PUT", method == "PATCH": | ||||
| 		return "更新操作" | ||||
| 	case method == "DELETE": | ||||
| 		return "删除操作" | ||||
| 	default: | ||||
| 		return "其他操作" | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // parseQueryParams 解析查询参数 | ||||
| func (m *UserOperationMiddleware) parseQueryParams(rawQuery string) map[string]string { | ||||
| 	params := make(map[string]string) | ||||
| 	if rawQuery == "" { | ||||
| 		return params | ||||
| 	} | ||||
|  | ||||
| 	for _, pair := range strings.Split(rawQuery, "&") { | ||||
| 		if kv := strings.SplitN(pair, "=", 2); len(kv) == 2 { | ||||
| 			key, _ := url.QueryUnescape(kv[0]) | ||||
| 			value, _ := url.QueryUnescape(kv[1]) | ||||
| 			params[key] = value | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return params | ||||
| } | ||||
|  | ||||
| // extractOperationDetails 提取操作详细信息 | ||||
| func (m *UserOperationMiddleware) extractOperationDetails(r *http.Request, w *responseWriter) map[string]interface{} { | ||||
| 	details := make(map[string]interface{}) | ||||
|  | ||||
| 	// 记录请求头信息(排除敏感信息) | ||||
| 	headers := make(map[string]string) | ||||
| 	for key, values := range r.Header { | ||||
| 		lowerKey := strings.ToLower(key) | ||||
| 		// 排除敏感头部 | ||||
| 		if !strings.Contains(lowerKey, "authorization") && | ||||
| 			!strings.Contains(lowerKey, "cookie") && | ||||
| 			!strings.Contains(lowerKey, "password") { | ||||
| 			headers[key] = values[0] | ||||
| 		} | ||||
| 	} | ||||
| 	details["headers"] = headers | ||||
|  | ||||
| 	// 记录响应头信息 | ||||
| 	responseHeaders := make(map[string]string) | ||||
| 	for key, values := range w.Header() { | ||||
| 		responseHeaders[key] = values[0] | ||||
| 	} | ||||
| 	details["responseHeaders"] = responseHeaders | ||||
|  | ||||
| 	// 记录其他有用信息 | ||||
| 	details["referer"] = r.Referer() | ||||
| 	details["origin"] = r.Header.Get("Origin") | ||||
| 	details["contentType"] = r.Header.Get("Content-Type") | ||||
|  | ||||
| 	return details | ||||
| } | ||||
|  | ||||
| // generateRequestID 生成请求ID | ||||
| func (m *UserOperationMiddleware) generateRequestID() string { | ||||
| 	return fmt.Sprintf("req_%d_%d", time.Now().UnixNano(), os.Getpid()) | ||||
| } | ||||
|  | ||||
| // writeLog 写入日志 | ||||
| func (m *UserOperationMiddleware) writeLog(operation *userOperation) { | ||||
| 	m.mu.Lock() | ||||
| 	defer m.mu.Unlock() | ||||
|  | ||||
| 	// 检查是否需要切换日志文件 | ||||
| 	m.checkAndSwitchLogFile() | ||||
|  | ||||
| 	// 序列化操作记录 | ||||
| 	data, err := json.Marshal(operation) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("序列化用户操作记录失败: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 添加换行符 | ||||
| 	data = append(data, '\n') | ||||
|  | ||||
| 	// 写入日志文件 | ||||
| 	if m.currentFile != nil { | ||||
| 		if _, err := m.currentFile.Write(data); err != nil { | ||||
| 			logx.Errorf("写入用户操作日志失败: %v", err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// 更新当前文件大小 | ||||
| 		m.currentSize += int64(len(data)) | ||||
|  | ||||
| 		// 强制刷新到磁盘 | ||||
| 		m.currentFile.Sync() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // checkAndSwitchLogFile 检查并切换日志文件 | ||||
| func (m *UserOperationMiddleware) checkAndSwitchLogFile() { | ||||
| 	now := time.Now() | ||||
| 	currentDate := now.Format("2006-01-02") | ||||
|  | ||||
| 	// 检查日期是否变化 | ||||
| 	if m.currentDate != currentDate { | ||||
| 		m.closeCurrentFile() | ||||
| 		m.currentDate = currentDate | ||||
| 	} | ||||
|  | ||||
| 	// 检查文件大小是否超过限制 | ||||
| 	if m.currentFile != nil && m.currentSize >= m.maxFileSize { | ||||
| 		m.closeCurrentFile() | ||||
| 	} | ||||
|  | ||||
| 	// 如果当前没有文件,创建新文件 | ||||
| 	if m.currentFile == nil { | ||||
| 		m.createNewLogFile() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // createNewLogFile 创建新的日志文件 | ||||
| func (m *UserOperationMiddleware) createNewLogFile() { | ||||
| 	// 生成文件名 | ||||
| 	timestamp := time.Now().Format("2006-01-02_15-04-05") | ||||
| 	filename := fmt.Sprintf("user_operation_%s_%s.log", m.currentDate, timestamp) | ||||
| 	filePath := filepath.Join(m.logDir, m.currentDate, filename) | ||||
|  | ||||
| 	// 确保日期目录存在 | ||||
| 	dateDir := filepath.Join(m.logDir, m.currentDate) | ||||
| 	if err := os.MkdirAll(dateDir, 0755); err != nil { | ||||
| 		logx.Errorf("创建日期目录失败: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 创建日志文件 | ||||
| 	file, err := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0644) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("创建日志文件失败: %v", err) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	m.currentFile = file | ||||
| 	m.currentSize = 0 | ||||
|  | ||||
| 	logx.Infof("创建新的用户操作日志文件: %s", filePath) | ||||
| } | ||||
|  | ||||
| // closeCurrentFile 关闭当前日志文件 | ||||
| func (m *UserOperationMiddleware) closeCurrentFile() { | ||||
| 	if m.currentFile != nil { | ||||
| 		m.currentFile.Close() | ||||
| 		m.currentFile = nil | ||||
| 		m.currentSize = 0 | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // startLogCleanup 启动日志清理协程 | ||||
| func (m *UserOperationMiddleware) startLogCleanup() { | ||||
| 	ticker := time.NewTicker(24 * time.Hour) // 每天检查一次 | ||||
| 	defer ticker.Stop() | ||||
|  | ||||
| 	for range ticker.C { | ||||
| 		m.cleanupOldLogs() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // cleanupOldLogs 清理旧日志 | ||||
| func (m *UserOperationMiddleware) cleanupOldLogs() { | ||||
| 	cutoffDate := time.Now().AddDate(0, 0, -m.maxDays) | ||||
|  | ||||
| 	// 遍历日志目录 | ||||
| 	err := filepath.Walk(m.logDir, func(path string, info os.FileInfo, err error) error { | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// 只处理目录 | ||||
| 		if !info.IsDir() { | ||||
| 			return nil | ||||
| 		} | ||||
|  | ||||
| 		// 检查是否是日期目录 | ||||
| 		if date, err := time.Parse("2006-01-02", info.Name()); err == nil { | ||||
| 			if date.Before(cutoffDate) { | ||||
| 				// 删除超过保留期的日志目录 | ||||
| 				if err := os.RemoveAll(path); err != nil { | ||||
| 					logx.Errorf("删除过期日志目录失败: %s, %v", path, err) | ||||
| 				} else { | ||||
| 					logx.Infof("删除过期日志目录: %s", path) | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("清理旧日志失败: %v", err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Close 关闭中间件 | ||||
| func (m *UserOperationMiddleware) Close() error { | ||||
| 	m.mu.Lock() | ||||
| 	defer m.mu.Unlock() | ||||
|  | ||||
| 	if m.currentFile != nil { | ||||
| 		return m.currentFile.Close() | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| // responseWriter 响应记录器 | ||||
| type responseWriter struct { | ||||
| 	http.ResponseWriter | ||||
| 	body       *bytes.Buffer | ||||
| 	statusCode int | ||||
| } | ||||
|  | ||||
| func (w *responseWriter) WriteHeader(statusCode int) { | ||||
| 	w.statusCode = statusCode | ||||
| 	w.ResponseWriter.WriteHeader(statusCode) | ||||
| } | ||||
|  | ||||
| func (w *responseWriter) Write(data []byte) (int, error) { | ||||
| 	w.body.Write(data) | ||||
| 	return w.ResponseWriter.Write(data) | ||||
| } | ||||
|  | ||||
| func (w *responseWriter) Header() http.Header { | ||||
| 	return w.ResponseWriter.Header() | ||||
| } | ||||
| @@ -0,0 +1,416 @@ | ||||
| package logging | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 	"tyc-server/app/main/api/internal/config" | ||||
|  | ||||
| 	"github.com/stretchr/testify/assert" | ||||
| ) | ||||
|  | ||||
| // 创建测试配置 | ||||
| func createTestLoggingConfig() *config.LoggingConfig { | ||||
| 	return &config.LoggingConfig{ | ||||
| 		UserOperationLogDir: "./test_logs/user_operations", | ||||
| 		MaxFileSize:         1024, // 1KB for testing | ||||
| 		LogLevel:            "info", | ||||
| 		EnableConsole:       true, | ||||
| 		EnableFile:          true, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 清理测试文件 | ||||
| func cleanupTestFiles() { | ||||
| 	os.RemoveAll("./test_logs") | ||||
| } | ||||
|  | ||||
| // TestNewUserOperationMiddleware 测试中间件创建 | ||||
| func TestNewUserOperationMiddleware(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	assert.NotNil(t, middleware) | ||||
| 	assert.Equal(t, config.UserOperationLogDir, middleware.logDir) | ||||
| 	assert.Equal(t, config.MaxFileSize, middleware.maxFileSize) | ||||
| 	assert.Equal(t, 180, middleware.maxDays) | ||||
| 	assert.NotNil(t, middleware.jwtExtractor) | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_Handle 测试中间件处理 | ||||
| func TestUserOperationMiddleware_Handle(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	// 创建测试请求 | ||||
| 	req := httptest.NewRequest("GET", "/api/v1/test?param1=value1", nil) | ||||
| 	req.Header.Set("Authorization", "Bearer test-token") | ||||
| 	req.Header.Set("User-Agent", "test-agent") | ||||
| 	req.Header.Set("X-Real-IP", "192.168.1.100") | ||||
|  | ||||
| 	// 创建响应记录器 | ||||
| 	w := httptest.NewRecorder() | ||||
|  | ||||
| 	// 定义测试处理器 | ||||
| 	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 		w.Write([]byte("test response")) | ||||
| 	}) | ||||
|  | ||||
| 	// 执行请求 | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	// 验证响应 | ||||
| 	assert.Equal(t, http.StatusOK, w.Code) | ||||
| 	assert.Equal(t, "test response", w.Body.String()) | ||||
|  | ||||
| 	// 等待日志写入 | ||||
| 	time.Sleep(100 * time.Millisecond) | ||||
|  | ||||
| 	// 验证日志文件是否创建 | ||||
| 	today := time.Now().Format("2006-01-02") | ||||
| 	logDir := filepath.Join(config.UserOperationLogDir, today) | ||||
| 	assert.DirExists(t, logDir) | ||||
|  | ||||
| 	// 检查是否有日志文件 | ||||
| 	files, err := os.ReadDir(logDir) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Greater(t, len(files), 0) | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_OperationType 测试操作类型识别 | ||||
| func TestUserOperationMiddleware_OperationType(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	testCases := []struct { | ||||
| 		method   string | ||||
| 		path     string | ||||
| 		expected string | ||||
| 	}{ | ||||
| 		{"GET", "/api/v1/login", "用户登录"}, | ||||
| 		{"POST", "/api/v1/logout", "用户退出"}, | ||||
| 		{"POST", "/api/v1/register", "用户注册"}, | ||||
| 		{"PUT", "/api/v1/password", "密码操作"}, | ||||
| 		{"GET", "/api/v1/profile", "个人信息"}, | ||||
| 		{"GET", "/api/v1/admin/users", "管理操作"}, | ||||
| 		{"GET", "/api/v1/products", "查询操作"}, | ||||
| 		{"POST", "/api/v1/orders", "创建操作"}, | ||||
| 		{"PUT", "/api/v1/users/123", "更新操作"}, | ||||
| 		{"DELETE", "/api/v1/users/123", "删除操作"}, | ||||
| 		{"PATCH", "/api/v1/users/123", "更新操作"}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(fmt.Sprintf("%s %s", tc.method, tc.path), func(t *testing.T) { | ||||
| 			result := middleware.determineOperation(tc.method, tc.path) | ||||
| 			assert.Equal(t, tc.expected, result) | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_ClientIP 测试客户端IP提取 | ||||
| func TestUserOperationMiddleware_ClientIP(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	testCases := []struct { | ||||
| 		name     string | ||||
| 		headers  map[string]string | ||||
| 		expected string | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "X-Forwarded-For优先", | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Forwarded-For": "203.0.113.1, 192.168.1.1", | ||||
| 				"X-Real-IP":       "198.51.100.1", | ||||
| 			}, | ||||
| 			expected: "203.0.113.1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "X-Real-IP次之", | ||||
| 			headers: map[string]string{ | ||||
| 				"X-Real-IP": "198.51.100.1", | ||||
| 			}, | ||||
| 			expected: "198.51.100.1", | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "RemoteAddr最后", | ||||
| 			headers:  map[string]string{}, | ||||
| 			expected: "unknown", // 在测试环境中RemoteAddr可能为空 | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range testCases { | ||||
| 		t.Run(tc.name, func(t *testing.T) { | ||||
| 			req := httptest.NewRequest("GET", "/test", nil) | ||||
| 			for key, value := range tc.headers { | ||||
| 				req.Header.Set(key, value) | ||||
| 			} | ||||
|  | ||||
| 			result := middleware.getClientIP(req) | ||||
| 			if tc.expected != "unknown" { | ||||
| 				assert.Equal(t, tc.expected, result) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_QueryParams 测试查询参数解析 | ||||
| func TestUserOperationMiddleware_QueryParams(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	// 测试正常查询参数 | ||||
| 	req := httptest.NewRequest("GET", "/test?param1=value1¶m2=value2¶m3=", nil) | ||||
| 	params := middleware.parseQueryParams(req.URL.RawQuery) | ||||
|  | ||||
| 	assert.Equal(t, "value1", params["param1"]) | ||||
| 	assert.Equal(t, "value2", params["param2"]) | ||||
| 	assert.Equal(t, "", params["param3"]) | ||||
|  | ||||
| 	// 测试空查询参数 | ||||
| 	req = httptest.NewRequest("GET", "/test", nil) | ||||
| 	params = middleware.parseQueryParams(req.URL.RawQuery) | ||||
| 	assert.Empty(t, params) | ||||
|  | ||||
| 	// 测试URL编码的参数 | ||||
| 	req = httptest.NewRequest("GET", "/test?name=John%20Doe&email=john%40example.com", nil) | ||||
| 	params = middleware.parseQueryParams(req.URL.RawQuery) | ||||
|  | ||||
| 	assert.Equal(t, "John Doe", params["name"]) | ||||
| 	assert.Equal(t, "john@example.com", params["email"]) | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_LogRotation 测试日志轮转 | ||||
| func TestUserOperationMiddleware_LogRotation(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	config.MaxFileSize = 100 // 100字节,便于测试 | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	// 创建测试请求 | ||||
| 	req := httptest.NewRequest("GET", "/api/v1/test", nil) | ||||
|  | ||||
| 	// 定义测试处理器 | ||||
| 	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 		w.Write([]byte("test response")) | ||||
| 	}) | ||||
|  | ||||
| 	// 多次请求以触发文件轮转 | ||||
| 	for i := 0; i < 50; i++ { | ||||
| 		w := httptest.NewRecorder() | ||||
| 		handler(w, req) | ||||
| 		time.Sleep(10 * time.Millisecond) | ||||
| 	} | ||||
|  | ||||
| 	// 等待日志写入 | ||||
| 	time.Sleep(200 * time.Millisecond) | ||||
|  | ||||
| 	// 验证是否创建了多个日志文件 | ||||
| 	today := time.Now().Format("2006-01-02") | ||||
| 	logDir := filepath.Join(config.UserOperationLogDir, today) | ||||
|  | ||||
| 	files, err := os.ReadDir(logDir) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Greater(t, len(files), 1, "应该创建多个日志文件") | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_LogCleanup 测试日志清理 | ||||
| func TestUserOperationMiddleware_LogCleanup(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	// 创建过期的日志目录 | ||||
| 	oldDate := time.Now().AddDate(0, 0, -200).Format("2006-01-02") // 200天前 | ||||
| 	oldLogDir := filepath.Join(config.UserOperationLogDir, oldDate) | ||||
| 	err := os.MkdirAll(oldLogDir, 0755) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	// 创建一些测试文件 | ||||
| 	testFile := filepath.Join(oldLogDir, "test.log") | ||||
| 	err = os.WriteFile(testFile, []byte("test content"), 0644) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	// 验证旧目录存在 | ||||
| 	assert.DirExists(t, oldLogDir) | ||||
|  | ||||
| 	// 手动触发清理 | ||||
| 	middleware.cleanupOldLogs() | ||||
|  | ||||
| 	// 等待清理完成 | ||||
| 	time.Sleep(100 * time.Millisecond) | ||||
|  | ||||
| 	// 验证旧目录被删除 | ||||
| 	assert.NoDirExists(t, oldLogDir) | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_Concurrent 测试并发安全性 | ||||
| func TestUserOperationMiddleware_Concurrent(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	// 并发请求数量 | ||||
| 	concurrency := 10 | ||||
| 	done := make(chan bool, concurrency) | ||||
|  | ||||
| 	// 启动并发请求 | ||||
| 	for i := 0; i < concurrency; i++ { | ||||
| 		go func(id int) { | ||||
| 			req := httptest.NewRequest("GET", fmt.Sprintf("/api/v1/test/%d", id), nil) | ||||
| 			w := httptest.NewRecorder() | ||||
|  | ||||
| 			handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 				w.WriteHeader(http.StatusOK) | ||||
| 				w.Write([]byte(fmt.Sprintf("response_%d", id))) | ||||
| 			}) | ||||
|  | ||||
| 			handler(w, req) | ||||
| 			done <- true | ||||
| 		}(i) | ||||
| 	} | ||||
|  | ||||
| 	// 等待所有请求完成 | ||||
| 	for i := 0; i < concurrency; i++ { | ||||
| 		<-done | ||||
| 	} | ||||
|  | ||||
| 	// 等待日志写入 | ||||
| 	time.Sleep(200 * time.Millisecond) | ||||
|  | ||||
| 	// 验证日志文件创建成功 | ||||
| 	today := time.Now().Format("2006-01-02") | ||||
| 	logDir := filepath.Join(config.UserOperationLogDir, today) | ||||
| 	assert.DirExists(t, logDir) | ||||
|  | ||||
| 	// 检查日志内容 | ||||
| 	files, err := os.ReadDir(logDir) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Greater(t, len(files), 0) | ||||
| } | ||||
|  | ||||
| // TestUserOperationMiddleware_LogFormat 测试日志格式 | ||||
| func TestUserOperationMiddleware_LogFormat(t *testing.T) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	// 创建测试请求 | ||||
| 	req := httptest.NewRequest("POST", "/api/v1/login?redirect=/dashboard", nil) | ||||
| 	req.Header.Set("Authorization", "Bearer test-token") | ||||
| 	req.Header.Set("User-Agent", "test-agent") | ||||
| 	req.Header.Set("Content-Type", "application/json") | ||||
| 	req.Header.Set("Referer", "https://example.com/login") | ||||
| 	req.Header.Set("X-Real-IP", "192.168.1.100") | ||||
|  | ||||
| 	// 设置请求体 | ||||
| 	req.Body = io.NopCloser(strings.NewReader(`{"username":"test","password":"test123"}`)) | ||||
|  | ||||
| 	// 创建响应记录器 | ||||
| 	w := httptest.NewRecorder() | ||||
|  | ||||
| 	// 定义测试处理器 | ||||
| 	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.Header().Set("Content-Type", "application/json") | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 		w.Write([]byte(`{"message":"login successful"}`)) | ||||
| 	}) | ||||
|  | ||||
| 	// 执行请求 | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	// 等待日志写入 | ||||
| 	time.Sleep(100 * time.Millisecond) | ||||
|  | ||||
| 	// 读取并验证日志内容 | ||||
| 	today := time.Now().Format("2006-01-02") | ||||
| 	logDir := filepath.Join(config.UserOperationLogDir, today) | ||||
|  | ||||
| 	files, err := os.ReadDir(logDir) | ||||
| 	assert.NoError(t, err) | ||||
| 	assert.Greater(t, len(files), 0) | ||||
|  | ||||
| 	// 读取第一个日志文件 | ||||
| 	logFile := filepath.Join(logDir, files[0].Name()) | ||||
| 	content, err := os.ReadFile(logFile) | ||||
| 	assert.NoError(t, err) | ||||
|  | ||||
| 	// 解析JSON日志 | ||||
| 	lines := strings.Split(string(content), "\n") | ||||
| 	for _, line := range lines { | ||||
| 		if line == "" { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		var operation userOperation | ||||
| 		err := json.Unmarshal([]byte(line), &operation) | ||||
| 		if err != nil { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		// 验证基本字段 | ||||
| 		assert.NotEmpty(t, operation.Timestamp) | ||||
| 		assert.NotEmpty(t, operation.RequestID) | ||||
| 		assert.Equal(t, "anonymous", operation.UserID) // JWT解析失败时使用默认值 | ||||
| 		assert.Equal(t, "anonymous", operation.Username) | ||||
| 		assert.Equal(t, http.StatusOK, operation.StatusCode) | ||||
| 		assert.GreaterOrEqual(t, operation.ResponseTime, int64(0)) | ||||
| 		assert.GreaterOrEqual(t, operation.RequestSize, int64(0)) | ||||
| 		assert.GreaterOrEqual(t, operation.ResponseSize, int64(0)) | ||||
|  | ||||
| 		// 验证请求信息(这些可能因为httptest的行为而不同) | ||||
| 		t.Logf("实际请求信息: Method=%s, Path=%s, IP=%s, UserAgent=%s", | ||||
| 			operation.Method, operation.Path, operation.IP, operation.UserAgent) | ||||
| 		t.Logf("实际操作类型: %s", operation.Operation) | ||||
| 		t.Logf("实际查询参数: %v", operation.QueryParams) | ||||
| 		t.Logf("实际详细信息: %v", operation.Details) | ||||
|  | ||||
| 		break // 只检查第一条日志 | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 性能基准测试 | ||||
| func BenchmarkUserOperationMiddleware_Handle(b *testing.B) { | ||||
| 	defer cleanupTestFiles() | ||||
|  | ||||
| 	config := createTestLoggingConfig() | ||||
| 	middleware := NewUserOperationMiddleware(config, "test-secret") | ||||
|  | ||||
| 	req := httptest.NewRequest("GET", "/api/v1/test", nil) | ||||
| 	req.Header.Set("User-Agent", "test-agent") | ||||
|  | ||||
| 	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 		w.Write([]byte("test response")) | ||||
| 	}) | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		w := httptest.NewRecorder() | ||||
| 		handler(w, req) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										341
									
								
								app/main/api/internal/middleware/security/securityMiddleware.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										341
									
								
								app/main/api/internal/middleware/security/securityMiddleware.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,341 @@ | ||||
| package security | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| 	"time" | ||||
| 	"tyc-server/app/main/api/internal/config" | ||||
|  | ||||
| 	"github.com/zeromicro/go-zero/core/logx" | ||||
| 	"github.com/zeromicro/go-zero/core/stores/redis" | ||||
| ) | ||||
|  | ||||
| // SecurityMiddleware 安全防护中间件 | ||||
| type SecurityMiddleware struct { | ||||
| 	config *config.SecurityConfig | ||||
| 	redis  *redis.Redis | ||||
| } | ||||
|  | ||||
| // NewSecurityMiddleware 创建安全中间件 | ||||
| func NewSecurityMiddleware(config *config.SecurityConfig, redis *redis.Redis) *SecurityMiddleware { | ||||
| 	return &SecurityMiddleware{ | ||||
| 		config: config, | ||||
| 		redis:  redis, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // Handle 处理请求 | ||||
| func (m *SecurityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { | ||||
| 	return func(w http.ResponseWriter, r *http.Request) { | ||||
| 		ctx := r.Context() | ||||
|  | ||||
| 		// 1. 获取客户端标识 | ||||
| 		clientID := m.getClientID(r) | ||||
|  | ||||
| 		// 2. IP黑名单检查 | ||||
| 		if m.config.IPBlacklist.Enabled { | ||||
| 			if m.isIPBlacklisted(r) { | ||||
| 				logx.WithContext(ctx).Errorf("IP被拉黑: %s", m.getClientIP(r)) | ||||
| 				http.Error(w, "访问被拒绝", http.StatusForbidden) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 3. 用户黑名单检查 | ||||
| 		if m.config.UserBlacklist.Enabled { | ||||
| 			if m.isUserBlacklisted(ctx, r) { | ||||
| 				logx.WithContext(ctx).Errorf("用户被拉黑: %s", clientID) | ||||
| 				http.Error(w, "访问被拒绝", http.StatusForbidden) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 4. 短时并发攻击检测 | ||||
| 		if !m.checkBurstAttack(ctx, clientID, r) { | ||||
| 			logx.WithContext(ctx).Errorf("检测到并发攻击: %s", clientID) | ||||
| 			http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// 5. 频率限制检查 | ||||
| 		if m.config.RateLimit.Enabled { | ||||
| 			if !m.checkRateLimit(ctx, clientID, r) { | ||||
| 				logx.WithContext(ctx).Errorf("频率限制触发: %s", clientID) | ||||
| 				http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests) | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 6. 异常检测 | ||||
| 		if m.config.AnomalyDetection.Enabled { | ||||
| 			if m.detectAnomaly(ctx, r) { | ||||
| 				logx.WithContext(ctx).Errorf("检测到异常请求: %s", clientID) | ||||
| 				// 记录异常但不阻止请求,用于监控 | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		// 7. 记录请求日志 | ||||
| 		m.logRequest(ctx, r, clientID) | ||||
|  | ||||
| 		// 继续处理请求 | ||||
| 		next(w, r) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // getClientID 获取客户端唯一标识 | ||||
| func (m *SecurityMiddleware) getClientID(r *http.Request) string { | ||||
| 	// 优先使用用户ID(如果已认证) | ||||
| 	if userID := m.getUserIDFromContext(r.Context()); userID != "" { | ||||
| 		return fmt.Sprintf("user:%s", userID) | ||||
| 	} | ||||
|  | ||||
| 	// 使用IP地址作为标识 | ||||
| 	return fmt.Sprintf("ip:%s", m.getClientIP(r)) | ||||
| } | ||||
|  | ||||
| // getClientIP 获取客户端真实IP | ||||
| func (m *SecurityMiddleware) getClientIP(r *http.Request) string { | ||||
| 	// 检查代理头 | ||||
| 	if ip := r.Header.Get("X-Forwarded-For"); ip != "" { | ||||
| 		// 取第一个IP(最原始的客户端IP) | ||||
| 		if commaIndex := strings.Index(ip, ","); commaIndex != -1 { | ||||
| 			return strings.TrimSpace(ip[:commaIndex]) | ||||
| 		} | ||||
| 		return strings.TrimSpace(ip) | ||||
| 	} | ||||
|  | ||||
| 	if ip := r.Header.Get("X-Real-IP"); ip != "" { | ||||
| 		return strings.TrimSpace(ip) | ||||
| 	} | ||||
|  | ||||
| 	if ip := r.Header.Get("X-Client-IP"); ip != "" { | ||||
| 		return strings.TrimSpace(ip) | ||||
| 	} | ||||
|  | ||||
| 	// 直接连接 | ||||
| 	if r.RemoteAddr != "" { | ||||
| 		if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 { | ||||
| 			return r.RemoteAddr[:colonIndex] | ||||
| 		} | ||||
| 		return r.RemoteAddr | ||||
| 	} | ||||
|  | ||||
| 	return "unknown" | ||||
| } | ||||
|  | ||||
| // getUserIDFromContext 从上下文中获取用户ID | ||||
| func (m *SecurityMiddleware) getUserIDFromContext(ctx context.Context) string { | ||||
| 	// 这里需要根据你的JWT实现来获取用户ID | ||||
| 	// 示例实现 | ||||
| 	if claims, ok := ctx.Value("claims").(map[string]interface{}); ok { | ||||
| 		if userID, exists := claims["userId"]; exists { | ||||
| 			return fmt.Sprintf("%v", userID) | ||||
| 		} | ||||
| 	} | ||||
| 	return "" | ||||
| } | ||||
|  | ||||
| // isIPBlacklisted 检查IP是否在黑名单中 | ||||
| func (m *SecurityMiddleware) isIPBlacklisted(r *http.Request) bool { | ||||
| 	ip := m.getClientIP(r) | ||||
| 	key := fmt.Sprintf("security:blacklist:ip:%s", ip) | ||||
|  | ||||
| 	exists, err := m.redis.Exists(key) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("检查IP黑名单失败: %v", err) | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return exists | ||||
| } | ||||
|  | ||||
| // isUserBlacklisted 检查用户是否在黑名单中 | ||||
| func (m *SecurityMiddleware) isUserBlacklisted(ctx context.Context, r *http.Request) bool { | ||||
| 	userID := m.getUserIDFromContext(ctx) | ||||
| 	if userID == "" { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	key := fmt.Sprintf("security:blacklist:user:%s", userID) | ||||
| 	exists, err := m.redis.Exists(key) | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("检查用户黑名单失败: %v", err) | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	return exists | ||||
| } | ||||
|  | ||||
| // checkRateLimit 检查频率限制 | ||||
| func (m *SecurityMiddleware) checkRateLimit(ctx context.Context, clientID string, r *http.Request) bool { | ||||
| 	key := fmt.Sprintf("security:ratelimit:%s", clientID) | ||||
|  | ||||
| 	// 获取当前计数 | ||||
| 	current, err := m.redis.Get(key) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		logx.Errorf("获取频率限制计数失败: %v", err) | ||||
| 		return true // 出错时允许请求 | ||||
| 	} | ||||
| 	logx.Infof("current: %s", current) | ||||
| 	var count int64 | ||||
| 	if current != "" { | ||||
| 		count, _ = strconv.ParseInt(current, 10, 64) | ||||
| 	} | ||||
|  | ||||
| 	// 检查是否超过限制 | ||||
| 	if count >= m.config.RateLimit.MaxRequests { | ||||
| 		// 频率限制触发,记录触发次数 | ||||
| 		m.recordRateLimitTrigger(clientID) | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 增加计数 | ||||
| 	err = m.redis.Pipelined(func(pipe redis.Pipeliner) error { | ||||
| 		pipe.Incr(ctx, key) | ||||
| 		pipe.Expire(ctx, key, time.Duration(m.config.RateLimit.WindowSize)*time.Second) | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("更新频率限制计数失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // recordRateLimitTrigger 记录频率限制触发次数 | ||||
| func (m *SecurityMiddleware) recordRateLimitTrigger(clientID string) { | ||||
| 	// 记录IP触发频率限制的次数 | ||||
| 	if strings.HasPrefix(clientID, "ip:") { | ||||
| 		ip := strings.TrimPrefix(clientID, "ip:") | ||||
| 		triggerKey := fmt.Sprintf("security:ratelimit_trigger:ip:%s", ip) | ||||
|  | ||||
| 		// 增加触发次数 | ||||
| 		err := m.redis.Pipelined(func(pipe redis.Pipeliner) error { | ||||
| 			pipe.Incr(context.Background(), triggerKey) | ||||
| 			pipe.Expire(context.Background(), triggerKey, time.Duration(m.config.RateLimit.TriggerWindow)*time.Hour) // 使用配置的时间窗口 | ||||
| 			return nil | ||||
| 		}) | ||||
|  | ||||
| 		if err != nil { | ||||
| 			logx.Errorf("记录频率限制触发次数失败: %v", err) | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// 检查是否达到黑名单阈值 | ||||
| 		triggerCount, err := m.redis.Get(triggerKey) | ||||
| 		if err == nil && triggerCount != "" { | ||||
| 			if count, _ := strconv.ParseInt(triggerCount, 10, 64); count >= m.config.RateLimit.TriggerThreshold { // 使用配置的阈值 | ||||
| 				logx.Infof("IP %s 触发频率限制次数过多(%d次/%d小时),自动加入黑名单", ip, count, m.config.RateLimit.TriggerWindow) | ||||
| 				m.addToBlacklist(clientID) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // checkBurstAttack 检查短时并发攻击 | ||||
| func (m *SecurityMiddleware) checkBurstAttack(ctx context.Context, clientID string, r *http.Request) bool { | ||||
| 	// 检查是否启用短时并发攻击检测 | ||||
| 	if !m.config.BurstAttack.Enabled { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	// 只对IP进行检查,用户级别的并发检测在业务层处理 | ||||
| 	if !strings.HasPrefix(clientID, "ip:") { | ||||
| 		return true | ||||
| 	} | ||||
|  | ||||
| 	ip := strings.TrimPrefix(clientID, "ip:") | ||||
| 	burstKey := fmt.Sprintf("security:burst:%s", ip) | ||||
|  | ||||
| 	// 使用Redis的原子操作检查短时并发 | ||||
| 	// 使用配置的时间窗口 | ||||
| 	current, err := m.redis.Get(burstKey) | ||||
| 	if err != nil && err != redis.Nil { | ||||
| 		logx.Errorf("获取短时并发计数失败: %v", err) | ||||
| 		return false // 出错时阻止请求 | ||||
| 	} | ||||
|  | ||||
| 	var count int64 | ||||
| 	if current != "" { | ||||
| 		count, _ = strconv.ParseInt(current, 10, 64) | ||||
| 	} | ||||
|  | ||||
| 	// 如果指定时间内并发请求超过阈值,认为是爆破攻击 | ||||
| 	if count >= m.config.BurstAttack.MaxConcurrent { // 使用配置的并发阈值 | ||||
| 		logx.Errorf("检测到IP %s 的爆破攻击(%d个请求/%d秒),自动加入黑名单", ip, count, m.config.BurstAttack.TimeWindow) | ||||
| 		m.addToBlacklist(clientID) | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	// 增加并发计数并设置过期时间 | ||||
| 	err = m.redis.Pipelined(func(pipe redis.Pipeliner) error { | ||||
| 		pipe.Incr(ctx, burstKey) | ||||
| 		pipe.Expire(ctx, burstKey, time.Duration(m.config.BurstAttack.TimeWindow)*time.Second) // 使用配置的时间窗口 | ||||
| 		return nil | ||||
| 	}) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		logx.Errorf("更新短时并发计数失败: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| // detectAnomaly 异常检测 | ||||
| func (m *SecurityMiddleware) detectAnomaly(ctx context.Context, r *http.Request) bool { | ||||
| 	// 检测可疑的请求特征 | ||||
| 	suspicious := false | ||||
|  | ||||
| 	// 1. 检查User-Agent | ||||
| 	userAgent := r.Header.Get("User-Agent") | ||||
| 	if userAgent == "" || strings.Contains(strings.ToLower(userAgent), "bot") { | ||||
| 		suspicious = true | ||||
| 	} | ||||
|  | ||||
| 	// 2. 检查请求频率异常 | ||||
| 	clientID := m.getClientID(r) | ||||
| 	key := fmt.Sprintf("security:anomaly:%s", clientID) | ||||
|  | ||||
| 	if suspicious { | ||||
| 		// 记录异常 | ||||
| 		m.redis.Incr(key) | ||||
| 		m.redis.Expire(key, 3600) // 1小时过期 | ||||
|  | ||||
| 		// 如果异常次数过多,加入黑名单 | ||||
| 		count, _ := m.redis.Get(key) | ||||
| 		if count != "" { | ||||
| 			if countInt, _ := strconv.ParseInt(count, 10, 64); countInt > 10 { | ||||
| 				m.addToBlacklist(clientID) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return suspicious | ||||
| } | ||||
|  | ||||
| // addToBlacklist 添加到黑名单 | ||||
| func (m *SecurityMiddleware) addToBlacklist(clientID string) { | ||||
| 	var key string | ||||
| 	var expireTime time.Duration | ||||
|  | ||||
| 	if strings.HasPrefix(clientID, "user:") { | ||||
| 		key = fmt.Sprintf("security:blacklist:%s", clientID) | ||||
| 		expireTime = 24 * time.Hour // 用户黑名单24小时 | ||||
| 	} else { | ||||
| 		key = fmt.Sprintf("security:blacklist:%s", clientID) | ||||
| 		expireTime = 1 * time.Hour // IP黑名单1小时 | ||||
| 	} | ||||
|  | ||||
| 	m.redis.Setex(key, "1", int(expireTime.Seconds())) | ||||
| 	logx.Infof("已将 %s 加入黑名单", clientID) | ||||
| } | ||||
|  | ||||
| // logRequest 记录请求日志 | ||||
| func (m *SecurityMiddleware) logRequest(ctx context.Context, r *http.Request, clientID string) { | ||||
| 	logx.WithContext(ctx).Infof("安全中间件 - 客户端: %s, 方法: %s, 路径: %s, IP: %s", | ||||
| 		clientID, r.Method, r.URL.Path, m.getClientIP(r)) | ||||
| } | ||||
| @@ -0,0 +1,441 @@ | ||||
| package security | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 	"tyc-server/app/main/api/internal/config" | ||||
|  | ||||
| 	"github.com/zeromicro/go-zero/core/stores/redis" | ||||
| ) | ||||
|  | ||||
| // 测试报告结构体 | ||||
| type TestReport struct { | ||||
| 	TestName    string | ||||
| 	StartTime   time.Time | ||||
| 	EndTime     time.Time | ||||
| 	Duration    time.Duration | ||||
| 	TotalTests  int | ||||
| 	PassedTests int | ||||
| 	FailedTests int | ||||
| 	TestResults map[string]TestResult | ||||
| 	Performance PerformanceMetrics | ||||
| 	RedisStats  RedisStats | ||||
| } | ||||
|  | ||||
| // 单个测试结果 | ||||
| type TestResult struct { | ||||
| 	Name     string | ||||
| 	Status   string // "PASS" | "FAIL" | ||||
| 	Duration time.Duration | ||||
| 	Error    string | ||||
| 	Details  map[string]interface{} | ||||
| } | ||||
|  | ||||
| // 性能指标 | ||||
| type PerformanceMetrics struct { | ||||
| 	TotalRequests       int | ||||
| 	AverageResponseTime time.Duration | ||||
| 	MinResponseTime     time.Duration | ||||
| 	MaxResponseTime     time.Duration | ||||
| 	RateLimitHits       int | ||||
| 	BlacklistHits       int | ||||
| 	AnomalyDetections   int | ||||
| } | ||||
|  | ||||
| // Redis统计信息 | ||||
| type RedisStats struct { | ||||
| 	TotalKeys     int | ||||
| 	BlacklistKeys int | ||||
| 	RateLimitKeys int | ||||
| 	AnomalyKeys   int | ||||
| 	MemoryUsage   string | ||||
| } | ||||
|  | ||||
| // 全局测试报告 | ||||
| var globalTestReport *TestReport | ||||
|  | ||||
| // 集成测试:需要真实的Redis环境 | ||||
| // 运行前请确保Redis服务已启动 | ||||
|  | ||||
| func TestSecurityMiddlewareIntegration(t *testing.T) { | ||||
| 	// 跳过集成测试,除非明确要求 | ||||
| 	// t.Skip("跳过集成测试,需要真实Redis环境") | ||||
|  | ||||
| 	// 初始化测试报告 | ||||
| 	globalTestReport = &TestReport{ | ||||
| 		TestName:    "SecurityMiddleware集成测试", | ||||
| 		StartTime:   time.Now(), | ||||
| 		TestResults: make(map[string]TestResult), | ||||
| 	} | ||||
|  | ||||
| 	// 创建Redis连接 | ||||
| 	redisClient, err := redis.NewRedis(redis.RedisConf{ | ||||
| 		Host: "127.0.0.1:20002", | ||||
| 		Pass: "3m3WsgyCKWqz", | ||||
| 		Type: "node", | ||||
| 	}) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("连接Redis失败: %v", err) | ||||
| 	} | ||||
| 	// Redis连接不需要手动关闭,go-zero会自动管理 | ||||
|  | ||||
| 	// 创建测试配置 | ||||
| 	config := &config.SecurityConfig{ | ||||
| 		RateLimit: struct { | ||||
| 			Enabled          bool  `json:"enabled" yaml:"enabled"` | ||||
| 			WindowSize       int64 `json:"windowSize" yaml:"windowSize"` | ||||
| 			MaxRequests      int64 `json:"maxRequests" yaml:"maxRequests"` | ||||
| 			TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"` | ||||
| 			TriggerWindow    int64 `json:"triggerWindow" yaml:"triggerWindow"` | ||||
| 		}{ | ||||
| 			Enabled:          true, | ||||
| 			WindowSize:       10, // 10秒窗口 | ||||
| 			MaxRequests:      3,  // 最多3次请求 | ||||
| 			TriggerThreshold: 3,  // 3次触发后拉黑 | ||||
| 			TriggerWindow:    24, // 24小时内统计 | ||||
| 		}, | ||||
| 		IPBlacklist: struct { | ||||
| 			Enabled bool `json:"enabled" yaml:"enabled"` | ||||
| 		}{ | ||||
| 			Enabled: true, | ||||
| 		}, | ||||
| 		UserBlacklist: struct { | ||||
| 			Enabled bool `json:"enabled" yaml:"enabled"` | ||||
| 		}{ | ||||
| 			Enabled: true, | ||||
| 		}, | ||||
| 		AnomalyDetection: struct { | ||||
| 			Enabled bool `json:"enabled" yaml:"enabled"` | ||||
| 		}{ | ||||
| 			Enabled: true, | ||||
| 		}, | ||||
| 		BurstAttack: struct { | ||||
| 			Enabled       bool  `json:"enabled" yaml:"enabled"` | ||||
| 			TimeWindow    int64 `json:"timeWindow" yaml:"timeWindow"` | ||||
| 			MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"` | ||||
| 		}{ | ||||
| 			Enabled:       true, | ||||
| 			TimeWindow:    1,  // 1秒检测窗口 | ||||
| 			MaxConcurrent: 15, // 最大15个并发请求 | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	middleware := NewSecurityMiddleware(config, redisClient) | ||||
|  | ||||
| 	// 测试频率限制 | ||||
| 	t.Run("RateLimit", func(t *testing.T) { | ||||
| 		testRateLimit(t, middleware, redisClient) | ||||
| 	}) | ||||
|  | ||||
| 	// 测试IP黑名单 | ||||
| 	t.Run("IPBlacklist", func(t *testing.T) { | ||||
| 		testIPBlacklist(t, middleware, redisClient) | ||||
| 	}) | ||||
|  | ||||
| 	// 测试异常检测 | ||||
| 	t.Run("AnomalyDetection", func(t *testing.T) { | ||||
| 		testAnomalyDetection(t, middleware, redisClient) | ||||
| 	}) | ||||
|  | ||||
| 	// 收集Redis统计信息 | ||||
| 	collectRedisStats(redisClient) | ||||
|  | ||||
| 	// 生成并打印测试报告 | ||||
| 	generateTestReport(t) | ||||
| } | ||||
|  | ||||
| // collectRedisStats 收集Redis统计信息 | ||||
| func collectRedisStats(redis *redis.Redis) { | ||||
| 	if globalTestReport == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 统计各种类型的键数量 | ||||
| 	blacklistKeys, _ := redis.Keys("security:blacklist:*") | ||||
| 	rateLimitKeys, _ := redis.Keys("security:ratelimit:*") | ||||
| 	anomalyKeys, _ := redis.Keys("security:anomaly:*") | ||||
| 	allKeys, _ := redis.Keys("security:*") | ||||
|  | ||||
| 	globalTestReport.RedisStats = RedisStats{ | ||||
| 		TotalKeys:     len(allKeys), | ||||
| 		BlacklistKeys: len(blacklistKeys), | ||||
| 		RateLimitKeys: len(rateLimitKeys), | ||||
| 		AnomalyKeys:   len(anomalyKeys), | ||||
| 		MemoryUsage:   "N/A", // Redis内存使用信息需要额外命令 | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // recordTestResult 记录测试结果到全局报告 | ||||
| func recordTestResult(name, status string, duration time.Duration, errorMsg string, details map[string]interface{}) { | ||||
| 	if globalTestReport == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	globalTestReport.TestResults[name] = TestResult{ | ||||
| 		Name:     name, | ||||
| 		Status:   status, | ||||
| 		Duration: duration, | ||||
| 		Error:    errorMsg, | ||||
| 		Details:  details, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // generateTestReport 生成并打印测试报告 | ||||
| func generateTestReport(t *testing.T) { | ||||
| 	if globalTestReport == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 计算测试总时长 | ||||
| 	globalTestReport.EndTime = time.Now() | ||||
| 	globalTestReport.Duration = globalTestReport.EndTime.Sub(globalTestReport.StartTime) | ||||
|  | ||||
| 	// 统计测试结果 | ||||
| 	globalTestReport.TotalTests = len(globalTestReport.TestResults) | ||||
| 	for _, result := range globalTestReport.TestResults { | ||||
| 		if result.Status == "PASS" { | ||||
| 			globalTestReport.PassedTests++ | ||||
| 		} else { | ||||
| 			globalTestReport.FailedTests++ | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 打印测试报告 | ||||
| 	printTestReport(t) | ||||
| } | ||||
|  | ||||
| // printTestReport 打印详细的测试报告 | ||||
| func printTestReport(t *testing.T) { | ||||
| 	if globalTestReport == nil { | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	// 使用fmt包来格式化输出 | ||||
| 	fmt.Println("\n" + strings.Repeat("=", 80)) | ||||
| 	fmt.Println("🔒 SECURITY MIDDLEWARE 集成测试报告") | ||||
| 	fmt.Println(strings.Repeat("=", 80)) | ||||
|  | ||||
| 	// 基本信息 | ||||
| 	fmt.Printf("📋 测试名称: %s\n", globalTestReport.TestName) | ||||
| 	fmt.Printf("⏰ 开始时间: %s\n", globalTestReport.StartTime.Format("2006-01-02 15:04:05")) | ||||
| 	fmt.Printf("⏰ 结束时间: %s\n", globalTestReport.EndTime.Format("2006-01-02 15:04:05")) | ||||
| 	fmt.Printf("⏱️  总耗时: %v\n", globalTestReport.Duration) | ||||
|  | ||||
| 	// 测试结果统计 | ||||
| 	fmt.Printf("\n📊 测试结果统计:\n") | ||||
| 	fmt.Printf("   总测试数: %d\n", globalTestReport.TotalTests) | ||||
| 	fmt.Printf("   通过测试: %d ✅\n", globalTestReport.PassedTests) | ||||
| 	fmt.Printf("   失败测试: %d ❌\n", globalTestReport.FailedTests) | ||||
|  | ||||
| 	if globalTestReport.TotalTests > 0 { | ||||
| 		passRate := float64(globalTestReport.PassedTests) / float64(globalTestReport.TotalTests) * 100 | ||||
| 		fmt.Printf("   通过率: %.1f%%\n", passRate) | ||||
| 	} | ||||
|  | ||||
| 	// 详细测试结果 | ||||
| 	if len(globalTestReport.TestResults) > 0 { | ||||
| 		fmt.Printf("\n📝 详细测试结果:\n") | ||||
| 		for name, result := range globalTestReport.TestResults { | ||||
| 			statusIcon := "✅" | ||||
| 			if result.Status == "FAIL" { | ||||
| 				statusIcon = "❌" | ||||
| 			} | ||||
| 			fmt.Printf("   %s %s: %s (耗时: %v)\n", statusIcon, name, result.Status, result.Duration) | ||||
| 			if result.Error != "" { | ||||
| 				fmt.Printf("     错误: %s\n", result.Error) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 性能指标 | ||||
| 	fmt.Printf("\n🚀 性能指标:\n") | ||||
| 	fmt.Printf("   总请求数: %d\n", globalTestReport.Performance.TotalRequests) | ||||
| 	fmt.Printf("   平均响应时间: %v\n", globalTestReport.Performance.AverageResponseTime) | ||||
| 	fmt.Printf("   频率限制触发: %d\n", globalTestReport.Performance.RateLimitHits) | ||||
| 	fmt.Printf("   黑名单命中: %d\n", globalTestReport.Performance.BlacklistHits) | ||||
| 	fmt.Printf("   异常检测: %d\n", globalTestReport.Performance.AnomalyDetections) | ||||
|  | ||||
| 	// Redis统计 | ||||
| 	fmt.Printf("\n🗄️  Redis统计:\n") | ||||
| 	fmt.Printf("   总安全键数: %d\n", globalTestReport.RedisStats.TotalKeys) | ||||
| 	fmt.Printf("   黑名单键数: %d\n", globalTestReport.RedisStats.BlacklistKeys) | ||||
| 	fmt.Printf("   频率限制键数: %d\n", globalTestReport.RedisStats.RateLimitKeys) | ||||
| 	fmt.Printf("   异常检测键数: %d\n", globalTestReport.RedisStats.AnomalyKeys) | ||||
|  | ||||
| 	// 测试总结 | ||||
| 	fmt.Printf("\n📈 测试总结:\n") | ||||
| 	if globalTestReport.FailedTests == 0 { | ||||
| 		fmt.Printf("   🎉 所有测试通过!安全中间件运行正常。\n") | ||||
| 	} else { | ||||
| 		fmt.Printf("   ⚠️  有 %d 个测试失败,需要检查相关功能。\n", globalTestReport.FailedTests) | ||||
| 	} | ||||
|  | ||||
| 	fmt.Println(strings.Repeat("=", 80)) | ||||
| 	fmt.Println() | ||||
| } | ||||
|  | ||||
| func testRateLimit(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) { | ||||
| 	startTime := time.Now() | ||||
| 	testName := "频率限制测试" | ||||
|  | ||||
| 	// 清理之前的测试数据 | ||||
| 	redis.Del("security:ratelimit:ip:192.168.1.100") | ||||
|  | ||||
| 	req := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req.Header.Set("X-Real-IP", "192.168.1.100") | ||||
|  | ||||
| 	successCount := 0 | ||||
| 	rateLimitHits := 0 | ||||
|  | ||||
| 	// 前3次请求应该成功 | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		w := httptest.NewRecorder() | ||||
| 		handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 			w.WriteHeader(http.StatusOK) | ||||
| 		}) | ||||
| 		handler(w, req) | ||||
|  | ||||
| 		if w.Code == http.StatusOK { | ||||
| 			successCount++ | ||||
| 		} else { | ||||
| 			t.Errorf("请求 %d 应该成功,但得到了状态码 %d", i+1, w.Code) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// 第4次请求应该被拒绝 | ||||
| 	w := httptest.NewRecorder() | ||||
| 	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 	}) | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	if w.Code == http.StatusTooManyRequests { | ||||
| 		rateLimitHits++ | ||||
| 	} else { | ||||
| 		t.Errorf("超过频率限制的请求应该被拒绝,但得到了状态码 %d", w.Code) | ||||
| 	} | ||||
|  | ||||
| 	// 等待窗口过期后再次测试 | ||||
| 	time.Sleep(11 * time.Second) | ||||
| 	w = httptest.NewRecorder() | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	if w.Code == http.StatusOK { | ||||
| 		successCount++ | ||||
| 	} else { | ||||
| 		t.Errorf("窗口过期后请求应该成功,但得到了状态码 %d", w.Code) | ||||
| 	} | ||||
|  | ||||
| 	// 记录测试结果 | ||||
| 	duration := time.Since(startTime) | ||||
| 	recordTestResult(testName, "PASS", duration, "", map[string]interface{}{ | ||||
| 		"successCount":  successCount, | ||||
| 		"rateLimitHits": rateLimitHits, | ||||
| 		"totalRequests": 5, | ||||
| 	}) | ||||
|  | ||||
| 	// 更新性能指标 | ||||
| 	if globalTestReport != nil { | ||||
| 		globalTestReport.Performance.TotalRequests += 5 | ||||
| 		globalTestReport.Performance.RateLimitHits += rateLimitHits | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func testIPBlacklist(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) { | ||||
| 	startTime := time.Now() | ||||
| 	testName := "IP黑名单测试" | ||||
|  | ||||
| 	// 添加IP到黑名单 | ||||
| 	redis.Setex("security:blacklist:ip:192.168.1.200", "1", 3600) | ||||
|  | ||||
| 	req := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req.Header.Set("X-Real-IP", "192.168.1.200") | ||||
|  | ||||
| 	w := httptest.NewRecorder() | ||||
| 	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 	}) | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	blacklistHit := false | ||||
| 	if w.Code == http.StatusForbidden { | ||||
| 		blacklistHit = true | ||||
| 	} else { | ||||
| 		t.Errorf("黑名单IP应该被拒绝,但得到了状态码 %d", w.Code) | ||||
| 	} | ||||
|  | ||||
| 	// 清理测试数据 | ||||
| 	redis.Del("security:blacklist:ip:192.168.1.200") | ||||
|  | ||||
| 	// 记录测试结果 | ||||
| 	duration := time.Since(startTime) | ||||
| 	recordTestResult(testName, "PASS", duration, "", map[string]interface{}{ | ||||
| 		"blacklistHit": blacklistHit, | ||||
| 		"blockedIP":    "192.168.1.200", | ||||
| 	}) | ||||
|  | ||||
| 	// 更新性能指标 | ||||
| 	if globalTestReport != nil { | ||||
| 		globalTestReport.Performance.TotalRequests++ | ||||
| 		if blacklistHit { | ||||
| 			globalTestReport.Performance.BlacklistHits++ | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func testAnomalyDetection(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) { | ||||
| 	startTime := time.Now() | ||||
| 	testName := "异常检测测试" | ||||
|  | ||||
| 	// 测试空User-Agent | ||||
| 	req := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req.Header.Set("X-Real-IP", "192.168.1.100") | ||||
| 	// 不设置User-Agent | ||||
|  | ||||
| 	w := httptest.NewRecorder() | ||||
| 	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) { | ||||
| 		w.WriteHeader(http.StatusOK) | ||||
| 	}) | ||||
| 	handler(w, req) | ||||
|  | ||||
| 	anomalyDetected := false | ||||
| 	// 异常检测不应该阻止请求,只是记录 | ||||
| 	if w.Code == http.StatusOK { | ||||
| 		anomalyDetected = true | ||||
| 	} else { | ||||
| 		t.Errorf("异常检测不应该阻止请求,但得到了状态码 %d", w.Code) | ||||
| 	} | ||||
|  | ||||
| 	// 检查是否记录了异常 | ||||
| 	key := "security:anomaly:ip:192.168.1.100" | ||||
| 	exists, _ := redis.Exists(key) | ||||
| 	if !exists { | ||||
| 		t.Log("异常检测记录可能已过期或未记录") | ||||
| 	} | ||||
|  | ||||
| 	// 记录测试结果 | ||||
| 	duration := time.Since(startTime) | ||||
| 	recordTestResult(testName, "PASS", duration, "", map[string]interface{}{ | ||||
| 		"anomalyDetected": anomalyDetected, | ||||
| 		"anomalyRecorded": exists, | ||||
| 		"testIP":          "192.168.1.100", | ||||
| 	}) | ||||
|  | ||||
| 	// 更新性能指标 | ||||
| 	if globalTestReport != nil { | ||||
| 		globalTestReport.Performance.TotalRequests++ | ||||
| 		if anomalyDetected { | ||||
| 			globalTestReport.Performance.AnomalyDetections++ | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 性能基准测试 | ||||
| func BenchmarkSecurityMiddlewarePerformance(b *testing.B) { | ||||
| 	// 跳过基准测试,除非明确要求 | ||||
| 	b.Skip("跳过基准测试,需要真实Redis环境") | ||||
| } | ||||
| @@ -0,0 +1,150 @@ | ||||
| package security | ||||
|  | ||||
| import ( | ||||
| 	"net/http/httptest" | ||||
| 	"testing" | ||||
| 	"tyc-server/app/main/api/internal/config" | ||||
| ) | ||||
|  | ||||
| // 创建测试配置 | ||||
| func createTestConfig() *config.SecurityConfig { | ||||
| 	return &config.SecurityConfig{ | ||||
| 		RateLimit: struct { | ||||
| 			Enabled          bool  `json:"enabled" yaml:"enabled"` | ||||
| 			WindowSize       int64 `json:"windowSize" yaml:"windowSize"` | ||||
| 			MaxRequests      int64 `json:"maxRequests" yaml:"maxRequests"` | ||||
| 			TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"` | ||||
| 			TriggerWindow    int64 `json:"triggerWindow" yaml:"triggerWindow"` | ||||
| 		}{ | ||||
| 			Enabled:          true, | ||||
| 			WindowSize:       60, | ||||
| 			MaxRequests:      5, | ||||
| 			TriggerThreshold: 5, | ||||
| 			TriggerWindow:    24, | ||||
| 		}, | ||||
| 		IPBlacklist: struct { | ||||
| 			Enabled bool `json:"enabled" yaml:"enabled"` | ||||
| 		}{ | ||||
| 			Enabled: true, | ||||
| 		}, | ||||
| 		UserBlacklist: struct { | ||||
| 			Enabled bool `json:"enabled" yaml:"enabled"` | ||||
| 		}{ | ||||
| 			Enabled: true, | ||||
| 		}, | ||||
| 		AnomalyDetection: struct { | ||||
| 			Enabled bool `json:"enabled" yaml:"enabled"` | ||||
| 		}{ | ||||
| 			Enabled: true, | ||||
| 		}, | ||||
| 		BurstAttack: struct { | ||||
| 			Enabled       bool  `json:"enabled" yaml:"enabled"` | ||||
| 			TimeWindow    int64 `json:"timeWindow" yaml:"timeWindow"` | ||||
| 			MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"` | ||||
| 		}{ | ||||
| 			Enabled:       true, | ||||
| 			TimeWindow:    1, | ||||
| 			MaxConcurrent: 20, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 测试客户端标识生成 | ||||
| func TestClientIDGeneration(t *testing.T) { | ||||
| 	config := createTestConfig() | ||||
| 	// 使用nil Redis进行测试,只测试不依赖Redis的逻辑 | ||||
| 	middleware := NewSecurityMiddleware(config, nil) | ||||
|  | ||||
| 	// 测试IP标识 | ||||
| 	req := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req.Header.Set("X-Real-IP", "192.168.1.100") | ||||
|  | ||||
| 	clientID := middleware.getClientID(req) | ||||
| 	expected := "ip:192.168.1.100" | ||||
| 	if clientID != expected { | ||||
| 		t.Errorf("期望客户端标识 %s,但得到了 %s", expected, clientID) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 测试真实IP获取 | ||||
| func TestRealIPExtraction(t *testing.T) { | ||||
| 	config := createTestConfig() | ||||
| 	middleware := NewSecurityMiddleware(config, nil) | ||||
|  | ||||
| 	// 测试X-Forwarded-For | ||||
| 	req := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req.Header.Set("X-Forwarded-For", "203.0.113.1, 192.168.1.1") | ||||
|  | ||||
| 	ip := middleware.getClientIP(req) | ||||
| 	expected := "203.0.113.1" | ||||
| 	if ip != expected { | ||||
| 		t.Errorf("期望IP %s,但得到了 %s", expected, ip) | ||||
| 	} | ||||
|  | ||||
| 	// 测试X-Real-IP(创建新的请求对象) | ||||
| 	req2 := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req2.Header.Set("X-Real-IP", "198.51.100.1") | ||||
| 	ip = middleware.getClientIP(req2) | ||||
| 	expected = "198.51.100.1" | ||||
| 	if ip != expected { | ||||
| 		t.Errorf("期望IP %s,但得到了 %s", expected, ip) | ||||
| 	} | ||||
|  | ||||
| 	// 测试直接连接 | ||||
| 	req3 := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req3.RemoteAddr = "192.168.1.50:12345" | ||||
| 	ip = middleware.getClientIP(req3) | ||||
| 	expected = "192.168.1.50" | ||||
| 	if ip != expected { | ||||
| 		t.Errorf("期望IP %s,但得到了 %s", expected, ip) | ||||
| 	} | ||||
|  | ||||
| 	// 测试优先级:X-Forwarded-For 应该优先于 X-Real-IP | ||||
| 	req4 := httptest.NewRequest("GET", "/test", nil) | ||||
| 	req4.Header.Set("X-Forwarded-For", "10.0.0.1, 10.0.0.2") | ||||
| 	req4.Header.Set("X-Real-IP", "10.0.0.3") | ||||
| 	ip = middleware.getClientIP(req4) | ||||
| 	expected = "10.0.0.1" // X-Forwarded-For 应该优先 | ||||
| 	if ip != expected { | ||||
| 		t.Errorf("优先级测试失败:期望IP %s,但得到了 %s", expected, ip) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 测试中间件创建 | ||||
| func TestNewSecurityMiddleware(t *testing.T) { | ||||
| 	config := createTestConfig() | ||||
| 	middleware := NewSecurityMiddleware(config, nil) | ||||
|  | ||||
| 	if middleware == nil { | ||||
| 		t.Error("中间件创建失败") | ||||
| 	} | ||||
|  | ||||
| 	if middleware.config != config { | ||||
| 		t.Error("配置设置失败") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // 测试配置验证 | ||||
| func TestConfigValidation(t *testing.T) { | ||||
| 	config := createTestConfig() | ||||
|  | ||||
| 	if !config.RateLimit.Enabled { | ||||
| 		t.Error("频率限制应该启用") | ||||
| 	} | ||||
|  | ||||
| 	if config.RateLimit.MaxRequests != 5 { | ||||
| 		t.Error("最大请求数设置错误") | ||||
| 	} | ||||
|  | ||||
| 	if !config.IPBlacklist.Enabled { | ||||
| 		t.Error("IP黑名单应该启用") | ||||
| 	} | ||||
|  | ||||
| 	if !config.UserBlacklist.Enabled { | ||||
| 		t.Error("用户黑名单应该启用") | ||||
| 	} | ||||
|  | ||||
| 	if !config.AnomalyDetection.Enabled { | ||||
| 		t.Error("异常检测应该启用") | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										74
									
								
								app/main/api/internal/types/security.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								app/main/api/internal/types/security.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,74 @@ | ||||
| package types | ||||
|  | ||||
| // GetSecurityStatsReq 获取安全统计信息请求 | ||||
| type GetSecurityStatsReq struct { | ||||
| } | ||||
|  | ||||
| // GetSecurityStatsResp 获取安全统计信息响应 | ||||
| type GetSecurityStatsResp struct { | ||||
| 	Code int                    `json:"code"` | ||||
| 	Msg  string                 `json:"msg"` | ||||
| 	Data map[string]interface{} `json:"data"` | ||||
| } | ||||
|  | ||||
| // GetBlacklistReq 获取黑名单请求 | ||||
| type GetBlacklistReq struct { | ||||
| 	Page     int    `form:"page,default=1"` | ||||
| 	PageSize int    `form:"pageSize,default=20"` | ||||
| 	Type     string `form:"type,optional"` // ip 或 user | ||||
| } | ||||
|  | ||||
| // GetBlacklistResp 获取黑名单响应 | ||||
| type GetBlacklistResp struct { | ||||
| 	Code int             `json:"code"` | ||||
| 	Msg  string          `json:"msg"` | ||||
| 	Data []BlacklistItem `json:"data"` | ||||
| } | ||||
|  | ||||
| // BlacklistItem 黑名单项 | ||||
| type BlacklistItem struct { | ||||
| 	Type       string `json:"type"`       // ip 或 user | ||||
| 	Identifier string `json:"identifier"` // IP地址或用户ID | ||||
| 	ExpireAt   int64  `json:"expireAt"`   // 过期时间戳 | ||||
| 	CreatedAt  int64  `json:"createdAt"`  // 创建时间戳 | ||||
| } | ||||
|  | ||||
| // AddToBlacklistReq 添加到黑名单请求 | ||||
| type AddToBlacklistReq struct { | ||||
| 	ClientType string `json:"clientType"` // ip 或 user | ||||
| 	Identifier string `json:"identifier"` // IP地址或用户ID | ||||
| 	Duration   string `json:"duration"`   // 持续时间,如 "1h", "24h" | ||||
| 	Reason     string `json:"reason"`     // 拉黑原因 | ||||
| } | ||||
|  | ||||
| // AddToBlacklistResp 添加到黑名单响应 | ||||
| type AddToBlacklistResp struct { | ||||
| 	Code int    `json:"code"` | ||||
| 	Msg  string `json:"msg"` | ||||
| } | ||||
|  | ||||
| // RemoveFromBlacklistReq 从黑名单移除请求 | ||||
| type RemoveFromBlacklistReq struct { | ||||
| 	ClientType string `json:"clientType"` // ip 或 user | ||||
| 	Identifier string `json:"identifier"` // IP地址或用户ID | ||||
| } | ||||
|  | ||||
| // RemoveFromBlacklistResp 从黑名单移除响应 | ||||
| type RemoveFromBlacklistResp struct { | ||||
| 	Code int    `json:"code"` | ||||
| 	Msg  string `json:"msg"` | ||||
| } | ||||
|  | ||||
| // GetSecurityEventsReq 获取安全事件请求 | ||||
| type GetSecurityEventsReq struct { | ||||
| 	EventType string `form:"eventType,optional"` // 事件类型 | ||||
| 	ClientID  string `form:"clientID,optional"`  // 客户端ID | ||||
| 	Limit     int    `form:"limit,default=50"`   // 限制数量 | ||||
| } | ||||
|  | ||||
| // GetSecurityEventsResp 获取安全事件响应 | ||||
| type GetSecurityEventsResp struct { | ||||
| 	Code int      `json:"code"` | ||||
| 	Msg  string   `json:"msg"` | ||||
| 	Data []string `json:"data"` | ||||
| } | ||||
		Reference in New Issue
	
	Block a user