fix
This commit is contained in:
341
app/main/api/internal/middleware/security/securityMiddleware.go
Normal file
341
app/main/api/internal/middleware/security/securityMiddleware.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// SecurityMiddleware 安全防护中间件
|
||||
type SecurityMiddleware struct {
|
||||
config *config.SecurityConfig
|
||||
redis *redis.Redis
|
||||
}
|
||||
|
||||
// NewSecurityMiddleware 创建安全中间件
|
||||
func NewSecurityMiddleware(config *config.SecurityConfig, redis *redis.Redis) *SecurityMiddleware {
|
||||
return &SecurityMiddleware{
|
||||
config: config,
|
||||
redis: redis,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 处理请求
|
||||
func (m *SecurityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// 1. 获取客户端标识
|
||||
clientID := m.getClientID(r)
|
||||
|
||||
// 2. IP黑名单检查
|
||||
if m.config.IPBlacklist.Enabled {
|
||||
if m.isIPBlacklisted(r) {
|
||||
logx.WithContext(ctx).Errorf("IP被拉黑: %s", m.getClientIP(r))
|
||||
http.Error(w, "访问被拒绝", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 用户黑名单检查
|
||||
if m.config.UserBlacklist.Enabled {
|
||||
if m.isUserBlacklisted(ctx, r) {
|
||||
logx.WithContext(ctx).Errorf("用户被拉黑: %s", clientID)
|
||||
http.Error(w, "访问被拒绝", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 4. 短时并发攻击检测
|
||||
if !m.checkBurstAttack(ctx, clientID, r) {
|
||||
logx.WithContext(ctx).Errorf("检测到并发攻击: %s", clientID)
|
||||
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
|
||||
// 5. 频率限制检查
|
||||
if m.config.RateLimit.Enabled {
|
||||
if !m.checkRateLimit(ctx, clientID, r) {
|
||||
logx.WithContext(ctx).Errorf("频率限制触发: %s", clientID)
|
||||
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 6. 异常检测
|
||||
if m.config.AnomalyDetection.Enabled {
|
||||
if m.detectAnomaly(ctx, r) {
|
||||
logx.WithContext(ctx).Errorf("检测到异常请求: %s", clientID)
|
||||
// 记录异常但不阻止请求,用于监控
|
||||
}
|
||||
}
|
||||
|
||||
// 7. 记录请求日志
|
||||
m.logRequest(ctx, r, clientID)
|
||||
|
||||
// 继续处理请求
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// getClientID 获取客户端唯一标识
|
||||
func (m *SecurityMiddleware) getClientID(r *http.Request) string {
|
||||
// 优先使用用户ID(如果已认证)
|
||||
if userID := m.getUserIDFromContext(r.Context()); userID != "" {
|
||||
return fmt.Sprintf("user:%s", userID)
|
||||
}
|
||||
|
||||
// 使用IP地址作为标识
|
||||
return fmt.Sprintf("ip:%s", m.getClientIP(r))
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端真实IP
|
||||
func (m *SecurityMiddleware) getClientIP(r *http.Request) string {
|
||||
// 检查代理头
|
||||
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
||||
// 取第一个IP(最原始的客户端IP)
|
||||
if commaIndex := strings.Index(ip, ","); commaIndex != -1 {
|
||||
return strings.TrimSpace(ip[:commaIndex])
|
||||
}
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
if ip := r.Header.Get("X-Client-IP"); ip != "" {
|
||||
return strings.TrimSpace(ip)
|
||||
}
|
||||
|
||||
// 直接连接
|
||||
if r.RemoteAddr != "" {
|
||||
if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 {
|
||||
return r.RemoteAddr[:colonIndex]
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从上下文中获取用户ID
|
||||
func (m *SecurityMiddleware) getUserIDFromContext(ctx context.Context) string {
|
||||
// 这里需要根据你的JWT实现来获取用户ID
|
||||
// 示例实现
|
||||
if claims, ok := ctx.Value("claims").(map[string]interface{}); ok {
|
||||
if userID, exists := claims["userId"]; exists {
|
||||
return fmt.Sprintf("%v", userID)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// isIPBlacklisted 检查IP是否在黑名单中
|
||||
func (m *SecurityMiddleware) isIPBlacklisted(r *http.Request) bool {
|
||||
ip := m.getClientIP(r)
|
||||
key := fmt.Sprintf("security:blacklist:ip:%s", ip)
|
||||
|
||||
exists, err := m.redis.Exists(key)
|
||||
if err != nil {
|
||||
logx.Errorf("检查IP黑名单失败: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// isUserBlacklisted 检查用户是否在黑名单中
|
||||
func (m *SecurityMiddleware) isUserBlacklisted(ctx context.Context, r *http.Request) bool {
|
||||
userID := m.getUserIDFromContext(ctx)
|
||||
if userID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("security:blacklist:user:%s", userID)
|
||||
exists, err := m.redis.Exists(key)
|
||||
if err != nil {
|
||||
logx.Errorf("检查用户黑名单失败: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
// checkRateLimit 检查频率限制
|
||||
func (m *SecurityMiddleware) checkRateLimit(ctx context.Context, clientID string, r *http.Request) bool {
|
||||
key := fmt.Sprintf("security:ratelimit:%s", clientID)
|
||||
|
||||
// 获取当前计数
|
||||
current, err := m.redis.Get(key)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取频率限制计数失败: %v", err)
|
||||
return true // 出错时允许请求
|
||||
}
|
||||
logx.Infof("current: %s", current)
|
||||
var count int64
|
||||
if current != "" {
|
||||
count, _ = strconv.ParseInt(current, 10, 64)
|
||||
}
|
||||
|
||||
// 检查是否超过限制
|
||||
if count >= m.config.RateLimit.MaxRequests {
|
||||
// 频率限制触发,记录触发次数
|
||||
m.recordRateLimitTrigger(clientID)
|
||||
return false
|
||||
}
|
||||
|
||||
// 增加计数
|
||||
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, time.Duration(m.config.RateLimit.WindowSize)*time.Second)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("更新频率限制计数失败: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// recordRateLimitTrigger 记录频率限制触发次数
|
||||
func (m *SecurityMiddleware) recordRateLimitTrigger(clientID string) {
|
||||
// 记录IP触发频率限制的次数
|
||||
if strings.HasPrefix(clientID, "ip:") {
|
||||
ip := strings.TrimPrefix(clientID, "ip:")
|
||||
triggerKey := fmt.Sprintf("security:ratelimit_trigger:ip:%s", ip)
|
||||
|
||||
// 增加触发次数
|
||||
err := m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
pipe.Incr(context.Background(), triggerKey)
|
||||
pipe.Expire(context.Background(), triggerKey, time.Duration(m.config.RateLimit.TriggerWindow)*time.Hour) // 使用配置的时间窗口
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("记录频率限制触发次数失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否达到黑名单阈值
|
||||
triggerCount, err := m.redis.Get(triggerKey)
|
||||
if err == nil && triggerCount != "" {
|
||||
if count, _ := strconv.ParseInt(triggerCount, 10, 64); count >= m.config.RateLimit.TriggerThreshold { // 使用配置的阈值
|
||||
logx.Infof("IP %s 触发频率限制次数过多(%d次/%d小时),自动加入黑名单", ip, count, m.config.RateLimit.TriggerWindow)
|
||||
m.addToBlacklist(clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// checkBurstAttack 检查短时并发攻击
|
||||
func (m *SecurityMiddleware) checkBurstAttack(ctx context.Context, clientID string, r *http.Request) bool {
|
||||
// 检查是否启用短时并发攻击检测
|
||||
if !m.config.BurstAttack.Enabled {
|
||||
return true
|
||||
}
|
||||
|
||||
// 只对IP进行检查,用户级别的并发检测在业务层处理
|
||||
if !strings.HasPrefix(clientID, "ip:") {
|
||||
return true
|
||||
}
|
||||
|
||||
ip := strings.TrimPrefix(clientID, "ip:")
|
||||
burstKey := fmt.Sprintf("security:burst:%s", ip)
|
||||
|
||||
// 使用Redis的原子操作检查短时并发
|
||||
// 使用配置的时间窗口
|
||||
current, err := m.redis.Get(burstKey)
|
||||
if err != nil && err != redis.Nil {
|
||||
logx.Errorf("获取短时并发计数失败: %v", err)
|
||||
return false // 出错时阻止请求
|
||||
}
|
||||
|
||||
var count int64
|
||||
if current != "" {
|
||||
count, _ = strconv.ParseInt(current, 10, 64)
|
||||
}
|
||||
|
||||
// 如果指定时间内并发请求超过阈值,认为是爆破攻击
|
||||
if count >= m.config.BurstAttack.MaxConcurrent { // 使用配置的并发阈值
|
||||
logx.Errorf("检测到IP %s 的爆破攻击(%d个请求/%d秒),自动加入黑名单", ip, count, m.config.BurstAttack.TimeWindow)
|
||||
m.addToBlacklist(clientID)
|
||||
return false
|
||||
}
|
||||
|
||||
// 增加并发计数并设置过期时间
|
||||
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
||||
pipe.Incr(ctx, burstKey)
|
||||
pipe.Expire(ctx, burstKey, time.Duration(m.config.BurstAttack.TimeWindow)*time.Second) // 使用配置的时间窗口
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("更新短时并发计数失败: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// detectAnomaly 异常检测
|
||||
func (m *SecurityMiddleware) detectAnomaly(ctx context.Context, r *http.Request) bool {
|
||||
// 检测可疑的请求特征
|
||||
suspicious := false
|
||||
|
||||
// 1. 检查User-Agent
|
||||
userAgent := r.Header.Get("User-Agent")
|
||||
if userAgent == "" || strings.Contains(strings.ToLower(userAgent), "bot") {
|
||||
suspicious = true
|
||||
}
|
||||
|
||||
// 2. 检查请求频率异常
|
||||
clientID := m.getClientID(r)
|
||||
key := fmt.Sprintf("security:anomaly:%s", clientID)
|
||||
|
||||
if suspicious {
|
||||
// 记录异常
|
||||
m.redis.Incr(key)
|
||||
m.redis.Expire(key, 3600) // 1小时过期
|
||||
|
||||
// 如果异常次数过多,加入黑名单
|
||||
count, _ := m.redis.Get(key)
|
||||
if count != "" {
|
||||
if countInt, _ := strconv.ParseInt(count, 10, 64); countInt > 10 {
|
||||
m.addToBlacklist(clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return suspicious
|
||||
}
|
||||
|
||||
// addToBlacklist 添加到黑名单
|
||||
func (m *SecurityMiddleware) addToBlacklist(clientID string) {
|
||||
var key string
|
||||
var expireTime time.Duration
|
||||
|
||||
if strings.HasPrefix(clientID, "user:") {
|
||||
key = fmt.Sprintf("security:blacklist:%s", clientID)
|
||||
expireTime = 24 * time.Hour // 用户黑名单24小时
|
||||
} else {
|
||||
key = fmt.Sprintf("security:blacklist:%s", clientID)
|
||||
expireTime = 1 * time.Hour // IP黑名单1小时
|
||||
}
|
||||
|
||||
m.redis.Setex(key, "1", int(expireTime.Seconds()))
|
||||
logx.Infof("已将 %s 加入黑名单", clientID)
|
||||
}
|
||||
|
||||
// logRequest 记录请求日志
|
||||
func (m *SecurityMiddleware) logRequest(ctx context.Context, r *http.Request, clientID string) {
|
||||
logx.WithContext(ctx).Infof("安全中间件 - 客户端: %s, 方法: %s, 路径: %s, IP: %s",
|
||||
clientID, r.Method, r.URL.Path, m.getClientIP(r))
|
||||
}
|
||||
@@ -0,0 +1,441 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
|
||||
"github.com/zeromicro/go-zero/core/stores/redis"
|
||||
)
|
||||
|
||||
// 测试报告结构体
|
||||
type TestReport struct {
|
||||
TestName string
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Duration time.Duration
|
||||
TotalTests int
|
||||
PassedTests int
|
||||
FailedTests int
|
||||
TestResults map[string]TestResult
|
||||
Performance PerformanceMetrics
|
||||
RedisStats RedisStats
|
||||
}
|
||||
|
||||
// 单个测试结果
|
||||
type TestResult struct {
|
||||
Name string
|
||||
Status string // "PASS" | "FAIL"
|
||||
Duration time.Duration
|
||||
Error string
|
||||
Details map[string]interface{}
|
||||
}
|
||||
|
||||
// 性能指标
|
||||
type PerformanceMetrics struct {
|
||||
TotalRequests int
|
||||
AverageResponseTime time.Duration
|
||||
MinResponseTime time.Duration
|
||||
MaxResponseTime time.Duration
|
||||
RateLimitHits int
|
||||
BlacklistHits int
|
||||
AnomalyDetections int
|
||||
}
|
||||
|
||||
// Redis统计信息
|
||||
type RedisStats struct {
|
||||
TotalKeys int
|
||||
BlacklistKeys int
|
||||
RateLimitKeys int
|
||||
AnomalyKeys int
|
||||
MemoryUsage string
|
||||
}
|
||||
|
||||
// 全局测试报告
|
||||
var globalTestReport *TestReport
|
||||
|
||||
// 集成测试:需要真实的Redis环境
|
||||
// 运行前请确保Redis服务已启动
|
||||
|
||||
func TestSecurityMiddlewareIntegration(t *testing.T) {
|
||||
// 跳过集成测试,除非明确要求
|
||||
// t.Skip("跳过集成测试,需要真实Redis环境")
|
||||
|
||||
// 初始化测试报告
|
||||
globalTestReport = &TestReport{
|
||||
TestName: "SecurityMiddleware集成测试",
|
||||
StartTime: time.Now(),
|
||||
TestResults: make(map[string]TestResult),
|
||||
}
|
||||
|
||||
// 创建Redis连接
|
||||
redisClient, err := redis.NewRedis(redis.RedisConf{
|
||||
Host: "127.0.0.1:20002",
|
||||
Pass: "3m3WsgyCKWqz",
|
||||
Type: "node",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("连接Redis失败: %v", err)
|
||||
}
|
||||
// Redis连接不需要手动关闭,go-zero会自动管理
|
||||
|
||||
// 创建测试配置
|
||||
config := &config.SecurityConfig{
|
||||
RateLimit: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
WindowSize int64 `json:"windowSize" yaml:"windowSize"`
|
||||
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"`
|
||||
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
|
||||
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"`
|
||||
}{
|
||||
Enabled: true,
|
||||
WindowSize: 10, // 10秒窗口
|
||||
MaxRequests: 3, // 最多3次请求
|
||||
TriggerThreshold: 3, // 3次触发后拉黑
|
||||
TriggerWindow: 24, // 24小时内统计
|
||||
},
|
||||
IPBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
UserBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
AnomalyDetection: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
BurstAttack: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"`
|
||||
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
|
||||
}{
|
||||
Enabled: true,
|
||||
TimeWindow: 1, // 1秒检测窗口
|
||||
MaxConcurrent: 15, // 最大15个并发请求
|
||||
},
|
||||
}
|
||||
|
||||
middleware := NewSecurityMiddleware(config, redisClient)
|
||||
|
||||
// 测试频率限制
|
||||
t.Run("RateLimit", func(t *testing.T) {
|
||||
testRateLimit(t, middleware, redisClient)
|
||||
})
|
||||
|
||||
// 测试IP黑名单
|
||||
t.Run("IPBlacklist", func(t *testing.T) {
|
||||
testIPBlacklist(t, middleware, redisClient)
|
||||
})
|
||||
|
||||
// 测试异常检测
|
||||
t.Run("AnomalyDetection", func(t *testing.T) {
|
||||
testAnomalyDetection(t, middleware, redisClient)
|
||||
})
|
||||
|
||||
// 收集Redis统计信息
|
||||
collectRedisStats(redisClient)
|
||||
|
||||
// 生成并打印测试报告
|
||||
generateTestReport(t)
|
||||
}
|
||||
|
||||
// collectRedisStats 收集Redis统计信息
|
||||
func collectRedisStats(redis *redis.Redis) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 统计各种类型的键数量
|
||||
blacklistKeys, _ := redis.Keys("security:blacklist:*")
|
||||
rateLimitKeys, _ := redis.Keys("security:ratelimit:*")
|
||||
anomalyKeys, _ := redis.Keys("security:anomaly:*")
|
||||
allKeys, _ := redis.Keys("security:*")
|
||||
|
||||
globalTestReport.RedisStats = RedisStats{
|
||||
TotalKeys: len(allKeys),
|
||||
BlacklistKeys: len(blacklistKeys),
|
||||
RateLimitKeys: len(rateLimitKeys),
|
||||
AnomalyKeys: len(anomalyKeys),
|
||||
MemoryUsage: "N/A", // Redis内存使用信息需要额外命令
|
||||
}
|
||||
}
|
||||
|
||||
// recordTestResult 记录测试结果到全局报告
|
||||
func recordTestResult(name, status string, duration time.Duration, errorMsg string, details map[string]interface{}) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
globalTestReport.TestResults[name] = TestResult{
|
||||
Name: name,
|
||||
Status: status,
|
||||
Duration: duration,
|
||||
Error: errorMsg,
|
||||
Details: details,
|
||||
}
|
||||
}
|
||||
|
||||
// generateTestReport 生成并打印测试报告
|
||||
func generateTestReport(t *testing.T) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 计算测试总时长
|
||||
globalTestReport.EndTime = time.Now()
|
||||
globalTestReport.Duration = globalTestReport.EndTime.Sub(globalTestReport.StartTime)
|
||||
|
||||
// 统计测试结果
|
||||
globalTestReport.TotalTests = len(globalTestReport.TestResults)
|
||||
for _, result := range globalTestReport.TestResults {
|
||||
if result.Status == "PASS" {
|
||||
globalTestReport.PassedTests++
|
||||
} else {
|
||||
globalTestReport.FailedTests++
|
||||
}
|
||||
}
|
||||
|
||||
// 打印测试报告
|
||||
printTestReport(t)
|
||||
}
|
||||
|
||||
// printTestReport 打印详细的测试报告
|
||||
func printTestReport(t *testing.T) {
|
||||
if globalTestReport == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 使用fmt包来格式化输出
|
||||
fmt.Println("\n" + strings.Repeat("=", 80))
|
||||
fmt.Println("🔒 SECURITY MIDDLEWARE 集成测试报告")
|
||||
fmt.Println(strings.Repeat("=", 80))
|
||||
|
||||
// 基本信息
|
||||
fmt.Printf("📋 测试名称: %s\n", globalTestReport.TestName)
|
||||
fmt.Printf("⏰ 开始时间: %s\n", globalTestReport.StartTime.Format("2006-01-02 15:04:05"))
|
||||
fmt.Printf("⏰ 结束时间: %s\n", globalTestReport.EndTime.Format("2006-01-02 15:04:05"))
|
||||
fmt.Printf("⏱️ 总耗时: %v\n", globalTestReport.Duration)
|
||||
|
||||
// 测试结果统计
|
||||
fmt.Printf("\n📊 测试结果统计:\n")
|
||||
fmt.Printf(" 总测试数: %d\n", globalTestReport.TotalTests)
|
||||
fmt.Printf(" 通过测试: %d ✅\n", globalTestReport.PassedTests)
|
||||
fmt.Printf(" 失败测试: %d ❌\n", globalTestReport.FailedTests)
|
||||
|
||||
if globalTestReport.TotalTests > 0 {
|
||||
passRate := float64(globalTestReport.PassedTests) / float64(globalTestReport.TotalTests) * 100
|
||||
fmt.Printf(" 通过率: %.1f%%\n", passRate)
|
||||
}
|
||||
|
||||
// 详细测试结果
|
||||
if len(globalTestReport.TestResults) > 0 {
|
||||
fmt.Printf("\n📝 详细测试结果:\n")
|
||||
for name, result := range globalTestReport.TestResults {
|
||||
statusIcon := "✅"
|
||||
if result.Status == "FAIL" {
|
||||
statusIcon = "❌"
|
||||
}
|
||||
fmt.Printf(" %s %s: %s (耗时: %v)\n", statusIcon, name, result.Status, result.Duration)
|
||||
if result.Error != "" {
|
||||
fmt.Printf(" 错误: %s\n", result.Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 性能指标
|
||||
fmt.Printf("\n🚀 性能指标:\n")
|
||||
fmt.Printf(" 总请求数: %d\n", globalTestReport.Performance.TotalRequests)
|
||||
fmt.Printf(" 平均响应时间: %v\n", globalTestReport.Performance.AverageResponseTime)
|
||||
fmt.Printf(" 频率限制触发: %d\n", globalTestReport.Performance.RateLimitHits)
|
||||
fmt.Printf(" 黑名单命中: %d\n", globalTestReport.Performance.BlacklistHits)
|
||||
fmt.Printf(" 异常检测: %d\n", globalTestReport.Performance.AnomalyDetections)
|
||||
|
||||
// Redis统计
|
||||
fmt.Printf("\n🗄️ Redis统计:\n")
|
||||
fmt.Printf(" 总安全键数: %d\n", globalTestReport.RedisStats.TotalKeys)
|
||||
fmt.Printf(" 黑名单键数: %d\n", globalTestReport.RedisStats.BlacklistKeys)
|
||||
fmt.Printf(" 频率限制键数: %d\n", globalTestReport.RedisStats.RateLimitKeys)
|
||||
fmt.Printf(" 异常检测键数: %d\n", globalTestReport.RedisStats.AnomalyKeys)
|
||||
|
||||
// 测试总结
|
||||
fmt.Printf("\n📈 测试总结:\n")
|
||||
if globalTestReport.FailedTests == 0 {
|
||||
fmt.Printf(" 🎉 所有测试通过!安全中间件运行正常。\n")
|
||||
} else {
|
||||
fmt.Printf(" ⚠️ 有 %d 个测试失败,需要检查相关功能。\n", globalTestReport.FailedTests)
|
||||
}
|
||||
|
||||
fmt.Println(strings.Repeat("=", 80))
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func testRateLimit(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
||||
startTime := time.Now()
|
||||
testName := "频率限制测试"
|
||||
|
||||
// 清理之前的测试数据
|
||||
redis.Del("security:ratelimit:ip:192.168.1.100")
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
|
||||
successCount := 0
|
||||
rateLimitHits := 0
|
||||
|
||||
// 前3次请求应该成功
|
||||
for i := 0; i < 3; i++ {
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
if w.Code == http.StatusOK {
|
||||
successCount++
|
||||
} else {
|
||||
t.Errorf("请求 %d 应该成功,但得到了状态码 %d", i+1, w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
// 第4次请求应该被拒绝
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
if w.Code == http.StatusTooManyRequests {
|
||||
rateLimitHits++
|
||||
} else {
|
||||
t.Errorf("超过频率限制的请求应该被拒绝,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 等待窗口过期后再次测试
|
||||
time.Sleep(11 * time.Second)
|
||||
w = httptest.NewRecorder()
|
||||
handler(w, req)
|
||||
|
||||
if w.Code == http.StatusOK {
|
||||
successCount++
|
||||
} else {
|
||||
t.Errorf("窗口过期后请求应该成功,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 记录测试结果
|
||||
duration := time.Since(startTime)
|
||||
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
||||
"successCount": successCount,
|
||||
"rateLimitHits": rateLimitHits,
|
||||
"totalRequests": 5,
|
||||
})
|
||||
|
||||
// 更新性能指标
|
||||
if globalTestReport != nil {
|
||||
globalTestReport.Performance.TotalRequests += 5
|
||||
globalTestReport.Performance.RateLimitHits += rateLimitHits
|
||||
}
|
||||
}
|
||||
|
||||
func testIPBlacklist(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
||||
startTime := time.Now()
|
||||
testName := "IP黑名单测试"
|
||||
|
||||
// 添加IP到黑名单
|
||||
redis.Setex("security:blacklist:ip:192.168.1.200", "1", 3600)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.200")
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
blacklistHit := false
|
||||
if w.Code == http.StatusForbidden {
|
||||
blacklistHit = true
|
||||
} else {
|
||||
t.Errorf("黑名单IP应该被拒绝,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 清理测试数据
|
||||
redis.Del("security:blacklist:ip:192.168.1.200")
|
||||
|
||||
// 记录测试结果
|
||||
duration := time.Since(startTime)
|
||||
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
||||
"blacklistHit": blacklistHit,
|
||||
"blockedIP": "192.168.1.200",
|
||||
})
|
||||
|
||||
// 更新性能指标
|
||||
if globalTestReport != nil {
|
||||
globalTestReport.Performance.TotalRequests++
|
||||
if blacklistHit {
|
||||
globalTestReport.Performance.BlacklistHits++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testAnomalyDetection(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
||||
startTime := time.Now()
|
||||
testName := "异常检测测试"
|
||||
|
||||
// 测试空User-Agent
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
// 不设置User-Agent
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
handler(w, req)
|
||||
|
||||
anomalyDetected := false
|
||||
// 异常检测不应该阻止请求,只是记录
|
||||
if w.Code == http.StatusOK {
|
||||
anomalyDetected = true
|
||||
} else {
|
||||
t.Errorf("异常检测不应该阻止请求,但得到了状态码 %d", w.Code)
|
||||
}
|
||||
|
||||
// 检查是否记录了异常
|
||||
key := "security:anomaly:ip:192.168.1.100"
|
||||
exists, _ := redis.Exists(key)
|
||||
if !exists {
|
||||
t.Log("异常检测记录可能已过期或未记录")
|
||||
}
|
||||
|
||||
// 记录测试结果
|
||||
duration := time.Since(startTime)
|
||||
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
||||
"anomalyDetected": anomalyDetected,
|
||||
"anomalyRecorded": exists,
|
||||
"testIP": "192.168.1.100",
|
||||
})
|
||||
|
||||
// 更新性能指标
|
||||
if globalTestReport != nil {
|
||||
globalTestReport.Performance.TotalRequests++
|
||||
if anomalyDetected {
|
||||
globalTestReport.Performance.AnomalyDetections++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 性能基准测试
|
||||
func BenchmarkSecurityMiddlewarePerformance(b *testing.B) {
|
||||
// 跳过基准测试,除非明确要求
|
||||
b.Skip("跳过基准测试,需要真实Redis环境")
|
||||
}
|
||||
@@ -0,0 +1,150 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"tyc-server/app/main/api/internal/config"
|
||||
)
|
||||
|
||||
// 创建测试配置
|
||||
func createTestConfig() *config.SecurityConfig {
|
||||
return &config.SecurityConfig{
|
||||
RateLimit: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
WindowSize int64 `json:"windowSize" yaml:"windowSize"`
|
||||
MaxRequests int64 `json:"maxRequests" yaml:"maxRequests"`
|
||||
TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
|
||||
TriggerWindow int64 `json:"triggerWindow" yaml:"triggerWindow"`
|
||||
}{
|
||||
Enabled: true,
|
||||
WindowSize: 60,
|
||||
MaxRequests: 5,
|
||||
TriggerThreshold: 5,
|
||||
TriggerWindow: 24,
|
||||
},
|
||||
IPBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
UserBlacklist: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
AnomalyDetection: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
}{
|
||||
Enabled: true,
|
||||
},
|
||||
BurstAttack: struct {
|
||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
||||
TimeWindow int64 `json:"timeWindow" yaml:"timeWindow"`
|
||||
MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
|
||||
}{
|
||||
Enabled: true,
|
||||
TimeWindow: 1,
|
||||
MaxConcurrent: 20,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// 测试客户端标识生成
|
||||
func TestClientIDGeneration(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
// 使用nil Redis进行测试,只测试不依赖Redis的逻辑
|
||||
middleware := NewSecurityMiddleware(config, nil)
|
||||
|
||||
// 测试IP标识
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Real-IP", "192.168.1.100")
|
||||
|
||||
clientID := middleware.getClientID(req)
|
||||
expected := "ip:192.168.1.100"
|
||||
if clientID != expected {
|
||||
t.Errorf("期望客户端标识 %s,但得到了 %s", expected, clientID)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试真实IP获取
|
||||
func TestRealIPExtraction(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
middleware := NewSecurityMiddleware(config, nil)
|
||||
|
||||
// 测试X-Forwarded-For
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Forwarded-For", "203.0.113.1, 192.168.1.1")
|
||||
|
||||
ip := middleware.getClientIP(req)
|
||||
expected := "203.0.113.1"
|
||||
if ip != expected {
|
||||
t.Errorf("期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
|
||||
// 测试X-Real-IP(创建新的请求对象)
|
||||
req2 := httptest.NewRequest("GET", "/test", nil)
|
||||
req2.Header.Set("X-Real-IP", "198.51.100.1")
|
||||
ip = middleware.getClientIP(req2)
|
||||
expected = "198.51.100.1"
|
||||
if ip != expected {
|
||||
t.Errorf("期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
|
||||
// 测试直接连接
|
||||
req3 := httptest.NewRequest("GET", "/test", nil)
|
||||
req3.RemoteAddr = "192.168.1.50:12345"
|
||||
ip = middleware.getClientIP(req3)
|
||||
expected = "192.168.1.50"
|
||||
if ip != expected {
|
||||
t.Errorf("期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
|
||||
// 测试优先级:X-Forwarded-For 应该优先于 X-Real-IP
|
||||
req4 := httptest.NewRequest("GET", "/test", nil)
|
||||
req4.Header.Set("X-Forwarded-For", "10.0.0.1, 10.0.0.2")
|
||||
req4.Header.Set("X-Real-IP", "10.0.0.3")
|
||||
ip = middleware.getClientIP(req4)
|
||||
expected = "10.0.0.1" // X-Forwarded-For 应该优先
|
||||
if ip != expected {
|
||||
t.Errorf("优先级测试失败:期望IP %s,但得到了 %s", expected, ip)
|
||||
}
|
||||
}
|
||||
|
||||
// 测试中间件创建
|
||||
func TestNewSecurityMiddleware(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
middleware := NewSecurityMiddleware(config, nil)
|
||||
|
||||
if middleware == nil {
|
||||
t.Error("中间件创建失败")
|
||||
}
|
||||
|
||||
if middleware.config != config {
|
||||
t.Error("配置设置失败")
|
||||
}
|
||||
}
|
||||
|
||||
// 测试配置验证
|
||||
func TestConfigValidation(t *testing.T) {
|
||||
config := createTestConfig()
|
||||
|
||||
if !config.RateLimit.Enabled {
|
||||
t.Error("频率限制应该启用")
|
||||
}
|
||||
|
||||
if config.RateLimit.MaxRequests != 5 {
|
||||
t.Error("最大请求数设置错误")
|
||||
}
|
||||
|
||||
if !config.IPBlacklist.Enabled {
|
||||
t.Error("IP黑名单应该启用")
|
||||
}
|
||||
|
||||
if !config.UserBlacklist.Enabled {
|
||||
t.Error("用户黑名单应该启用")
|
||||
}
|
||||
|
||||
if !config.AnomalyDetection.Enabled {
|
||||
t.Error("异常检测应该启用")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user