package security

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"sort"
	"strings"
	"testing"
	"time"

	"tyc-server/app/main/api/internal/config"

	"github.com/zeromicro/go-zero/core/stores/redis"
)

// Rate-limit settings shared by the middleware config and the RateLimit
// subtest, so the assertions cannot drift from what was configured.
const (
	// integrationRateLimitWindow is the rate-limit window, in seconds.
	integrationRateLimitWindow = 10
	// integrationRateLimitMax is the number of requests allowed per window.
	integrationRateLimitMax = 3
)

// TestReport aggregates the results of one integration-test run for the
// summary printed by printTestReport.
type TestReport struct {
	TestName    string
	StartTime   time.Time
	EndTime     time.Time
	Duration    time.Duration
	TotalTests  int
	PassedTests int
	FailedTests int
	TestResults map[string]TestResult // keyed by the subtest's report name
	Performance PerformanceMetrics
	RedisStats  RedisStats
}

// TestResult is the outcome of a single subtest.
type TestResult struct {
	Name     string
	Status   string // "PASS" | "FAIL"
	Duration time.Duration
	Error    string
	Details  map[string]interface{}
}

// PerformanceMetrics holds request counters accumulated by the subtests.
// NOTE(review): AverageResponseTime, MinResponseTime and MaxResponseTime are
// never populated in this file, so the report prints their zero values.
type PerformanceMetrics struct {
	TotalRequests       int
	AverageResponseTime time.Duration
	MinResponseTime     time.Duration
	MaxResponseTime     time.Duration
	RateLimitHits       int
	BlacklistHits       int
	AnomalyDetections   int
}

// RedisStats counts the security:* keys present in Redis after the run.
type RedisStats struct {
	TotalKeys     int
	BlacklistKeys int
	RateLimitKeys int
	AnomalyKeys   int
	MemoryUsage   string
}

// globalTestReport collects results across subtests; it is (re)initialised
// at the start of TestSecurityMiddlewareIntegration.
var globalTestReport *TestReport

// TestSecurityMiddlewareIntegration exercises SecurityMiddleware against a
// real Redis instance and prints a summary report at the end.
//
// Redis connection settings come from the environment so that no
// credentials live in the source tree:
//   - SECURITY_TEST_REDIS_HOST (default "127.0.0.1:20002")
//   - SECURITY_TEST_REDIS_PASS (default empty)
//
// The test is skipped under `go test -short`. It needs a live Redis, and the
// RateLimit subtest sleeps for longer than one rate-limit window.
func TestSecurityMiddlewareIntegration(t *testing.T) {
	if testing.Short() {
		t.Skip("跳过集成测试,需要真实Redis环境")
	}

	globalTestReport = &TestReport{
		TestName:    "SecurityMiddleware集成测试",
		StartTime:   time.Now(),
		TestResults: make(map[string]TestResult),
	}

	redisClient, err := redis.NewRedis(integrationRedisConf())
	if err != nil {
		t.Fatalf("连接Redis失败: %v", err)
	}
	// Per the original author, go-zero manages the connection itself, so no
	// explicit close is performed here.

	securityConfig := &config.SecurityConfig{
		RateLimit: struct {
			Enabled          bool  `json:"enabled" yaml:"enabled"`
			WindowSize       int64 `json:"windowSize" yaml:"windowSize"`
			MaxRequests      int64 `json:"maxRequests" yaml:"maxRequests"`
			TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
			TriggerWindow    int64 `json:"triggerWindow" yaml:"triggerWindow"`
		}{
			Enabled:          true,
			WindowSize:       integrationRateLimitWindow, // 10-second window
			MaxRequests:      integrationRateLimitMax,    // at most 3 requests per window
			TriggerThreshold: 3,                          // blacklist after 3 triggers
			TriggerWindow:    24,                         // triggers counted over 24 hours
		},
		IPBlacklist: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		UserBlacklist: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		AnomalyDetection: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		BurstAttack: struct {
			Enabled       bool  `json:"enabled" yaml:"enabled"`
			TimeWindow    int64 `json:"timeWindow" yaml:"timeWindow"`
			MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
		}{
			Enabled:       true,
			TimeWindow:    1,  // 1-second detection window
			MaxConcurrent: 15, // at most 15 concurrent requests
		},
	}

	middleware := NewSecurityMiddleware(securityConfig, redisClient)

	t.Run("RateLimit", func(t *testing.T) {
		testRateLimit(t, middleware, redisClient)
	})

	t.Run("IPBlacklist", func(t *testing.T) {
		testIPBlacklist(t, middleware, redisClient)
	})

	t.Run("AnomalyDetection", func(t *testing.T) {
		testAnomalyDetection(t, middleware, redisClient)
	})

	collectRedisStats(redisClient)
	generateTestReport(t)
}

// integrationRedisConf builds the Redis connection settings for the
// integration test from SECURITY_TEST_REDIS_HOST / SECURITY_TEST_REDIS_PASS.
func integrationRedisConf() redis.RedisConf {
	host := os.Getenv("SECURITY_TEST_REDIS_HOST")
	if host == "" {
		host = "127.0.0.1:20002"
	}
	return redis.RedisConf{
		Host: host,
		Pass: os.Getenv("SECURITY_TEST_REDIS_PASS"),
		Type: "node",
	}
}

// collectRedisStats counts the security:* keys currently in Redis and stores
// the counts in globalTestReport. It does nothing if no report is active.
// Lookup errors are deliberately ignored (the count becomes 0) because these
// stats are informational only.
func collectRedisStats(rds *redis.Redis) {
	if globalTestReport == nil {
		return
	}

	countKeys := func(pattern string) int {
		keys, _ := rds.Keys(pattern) // best-effort: a failed lookup reports 0
		return len(keys)
	}

	globalTestReport.RedisStats = RedisStats{
		BlacklistKeys: countKeys("security:blacklist:*"),
		RateLimitKeys: countKeys("security:ratelimit:*"),
		AnomalyKeys:   countKeys("security:anomaly:*"),
		TotalKeys:     countKeys("security:*"),
		MemoryUsage:   "N/A", // memory usage would need an extra Redis command
	}
}

// recordTestResult stores one result in the global report under name,
// replacing any earlier entry with the same name. It does nothing if no
// report is active.
func recordTestResult(name, status string, duration time.Duration, errorMsg string, details map[string]interface{}) {
	if globalTestReport == nil {
		return
	}

	globalTestReport.TestResults[name] = TestResult{
		Name:     name,
		Status:   status,
		Duration: duration,
		Error:    errorMsg,
		Details:  details,
	}
}

// recordSubtestResult records name with a status derived from t: "FAIL" if
// any assertion in the subtest has failed so far, otherwise "PASS". The
// duration is measured from start.
func recordSubtestResult(t *testing.T, name string, start time.Time, details map[string]interface{}) {
	status, errorMsg := "PASS", ""
	if t.Failed() {
		status, errorMsg = "FAIL", "存在断言失败,详见测试日志"
	}
	recordTestResult(name, status, time.Since(start), errorMsg, details)
}

// generateTestReport finalises the timing and pass/fail totals of the global
// report, then prints it.
func generateTestReport(t *testing.T) {
	if globalTestReport == nil {
		return
	}

	report := globalTestReport
	report.EndTime = time.Now()
	report.Duration = report.EndTime.Sub(report.StartTime)

	// Recount from scratch so that calling this more than once does not
	// double-count.
	report.TotalTests = len(report.TestResults)
	report.PassedTests, report.FailedTests = 0, 0
	for _, result := range report.TestResults {
		if result.Status == "PASS" {
			report.PassedTests++
		} else {
			report.FailedTests++
		}
	}

	printTestReport(t)
}

// printTestReport writes the global report to stdout. Per-test results are
// listed in sorted name order so the output is stable between runs.
func printTestReport(t *testing.T) {
	if globalTestReport == nil {
		return
	}
	report := globalTestReport

	fmt.Println("\n" + strings.Repeat("=", 80))
	fmt.Println("🔒 SECURITY MIDDLEWARE 集成测试报告")
	fmt.Println(strings.Repeat("=", 80))

	// General information.
	fmt.Printf("📋 测试名称: %s\n", report.TestName)
	fmt.Printf("⏰ 开始时间: %s\n", report.StartTime.Format("2006-01-02 15:04:05"))
	fmt.Printf("⏰ 结束时间: %s\n", report.EndTime.Format("2006-01-02 15:04:05"))
	fmt.Printf("⏱️ 总耗时: %v\n", report.Duration)

	// Pass/fail summary.
	fmt.Printf("\n📊 测试结果统计:\n")
	fmt.Printf(" 总测试数: %d\n", report.TotalTests)
	fmt.Printf(" 通过测试: %d ✅\n", report.PassedTests)
	fmt.Printf(" 失败测试: %d ❌\n", report.FailedTests)
	if report.TotalTests > 0 {
		passRate := float64(report.PassedTests) / float64(report.TotalTests) * 100
		fmt.Printf(" 通过率: %.1f%%\n", passRate)
	}

	// Per-test details, sorted for deterministic output.
	if len(report.TestResults) > 0 {
		fmt.Printf("\n📝 详细测试结果:\n")
		names := make([]string, 0, len(report.TestResults))
		for name := range report.TestResults {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			result := report.TestResults[name]
			statusIcon := "✅"
			if result.Status == "FAIL" {
				statusIcon = "❌"
			}
			fmt.Printf(" %s %s: %s (耗时: %v)\n", statusIcon, name, result.Status, result.Duration)
			if result.Error != "" {
				fmt.Printf(" 错误: %s\n", result.Error)
			}
		}
	}

	// Request counters.
	fmt.Printf("\n🚀 性能指标:\n")
	fmt.Printf(" 总请求数: %d\n", report.Performance.TotalRequests)
	fmt.Printf(" 平均响应时间: %v\n", report.Performance.AverageResponseTime)
	fmt.Printf(" 频率限制触发: %d\n", report.Performance.RateLimitHits)
	fmt.Printf(" 黑名单命中: %d\n", report.Performance.BlacklistHits)
	fmt.Printf(" 异常检测: %d\n", report.Performance.AnomalyDetections)

	// Redis key counts.
	fmt.Printf("\n🗄️ Redis统计:\n")
	fmt.Printf(" 总安全键数: %d\n", report.RedisStats.TotalKeys)
	fmt.Printf(" 黑名单键数: %d\n", report.RedisStats.BlacklistKeys)
	fmt.Printf(" 频率限制键数: %d\n", report.RedisStats.RateLimitKeys)
	fmt.Printf(" 异常检测键数: %d\n", report.RedisStats.AnomalyKeys)

	// Overall verdict.
	fmt.Printf("\n📈 测试总结:\n")
	if report.FailedTests == 0 {
		fmt.Printf(" 🎉 所有测试通过!安全中间件运行正常。\n")
	} else {
		fmt.Printf(" ⚠️ 有 %d 个测试失败,需要检查相关功能。\n", report.FailedTests)
	}
	fmt.Println(strings.Repeat("=", 80))
	fmt.Println()
}

// testRateLimit expects IP 192.168.1.100 to get integrationRateLimitMax
// successful requests, then a 429 on the next one, and success again after
// the window expires. It sleeps for one window plus one second.
func testRateLimit(t *testing.T, middleware *SecurityMiddleware, rds *redis.Redis) {
	startTime := time.Now()
	testName := "频率限制测试"

	// Clear counters left over from a previous run (best-effort).
	if _, err := rds.Del("security:ratelimit:ip:192.168.1.100"); err != nil {
		t.Logf("清理频率限制数据失败: %v", err)
	}

	req := httptest.NewRequest("GET", "/test", nil)
	req.Header.Set("X-Real-IP", "192.168.1.100")
	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	successCount := 0
	rateLimitHits := 0

	// Requests within the limit should succeed.
	for i := 0; i < integrationRateLimitMax; i++ {
		w := httptest.NewRecorder()
		handler(w, req)
		if w.Code == http.StatusOK {
			successCount++
		} else {
			t.Errorf("请求 %d 应该成功,但得到了状态码 %d", i+1, w.Code)
		}
	}

	// The next request exceeds the limit and should be rejected.
	w := httptest.NewRecorder()
	handler(w, req)
	if w.Code == http.StatusTooManyRequests {
		rateLimitHits++
	} else {
		t.Errorf("超过频率限制的请求应该被拒绝,但得到了状态码 %d", w.Code)
	}

	// After the window expires, requests should be accepted again.
	time.Sleep((integrationRateLimitWindow + 1) * time.Second)
	w = httptest.NewRecorder()
	handler(w, req)
	if w.Code == http.StatusOK {
		successCount++
	} else {
		t.Errorf("窗口过期后请求应该成功,但得到了状态码 %d", w.Code)
	}

	totalRequests := integrationRateLimitMax + 2
	recordSubtestResult(t, testName, startTime, map[string]interface{}{
		"successCount":  successCount,
		"rateLimitHits": rateLimitHits,
		"totalRequests": totalRequests,
	})

	if globalTestReport != nil {
		globalTestReport.Performance.TotalRequests += totalRequests
		globalTestReport.Performance.RateLimitHits += rateLimitHits
	}
}

// testIPBlacklist seeds a blacklist entry for 192.168.1.200 in Redis and
// expects the middleware to answer 403 for that IP. The entry is removed
// when the subtest returns.
func testIPBlacklist(t *testing.T, middleware *SecurityMiddleware, rds *redis.Redis) {
	startTime := time.Now()
	testName := "IP黑名单测试"
	const blockedIP = "192.168.1.200"
	blacklistKey := "security:blacklist:ip:" + blockedIP

	// Seed the entry with a one-hour TTL. Without it the assertion below
	// would be meaningless, so a write failure fails the subtest.
	if err := rds.Setex(blacklistKey, "1", 3600); err != nil {
		t.Errorf("写入黑名单失败: %v", err)
		recordSubtestResult(t, testName, startTime, map[string]interface{}{
			"blacklistHit": false,
			"blockedIP":    blockedIP,
		})
		return
	}
	defer func() {
		if _, err := rds.Del(blacklistKey); err != nil {
			t.Logf("清理黑名单数据失败: %v", err)
		}
	}()

	req := httptest.NewRequest("GET", "/test", nil)
	req.Header.Set("X-Real-IP", blockedIP)
	w := httptest.NewRecorder()
	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler(w, req)

	blacklistHit := w.Code == http.StatusForbidden
	if !blacklistHit {
		t.Errorf("黑名单IP应该被拒绝,但得到了状态码 %d", w.Code)
	}

	recordSubtestResult(t, testName, startTime, map[string]interface{}{
		"blacklistHit": blacklistHit,
		"blockedIP":    blockedIP,
	})

	if globalTestReport != nil {
		globalTestReport.Performance.TotalRequests++
		if blacklistHit {
			globalTestReport.Performance.BlacklistHits++
		}
	}
}

// testAnomalyDetection sends a request without a User-Agent header and
// expects the middleware to let it through (200). The subtest then checks
// Redis for an anomaly record for the client IP. Only an existing record
// counts towards the AnomalyDetections metric; an allowed request alone
// does not.
func testAnomalyDetection(t *testing.T, middleware *SecurityMiddleware, rds *redis.Redis) {
	startTime := time.Now()
	testName := "异常检测测试"

	// No User-Agent header is set on this request.
	req := httptest.NewRequest("GET", "/test", nil)
	req.Header.Set("X-Real-IP", "192.168.1.100")
	w := httptest.NewRecorder()
	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler(w, req)

	// Anomaly detection is expected to record, not block, the request.
	requestAllowed := w.Code == http.StatusOK
	if !requestAllowed {
		t.Errorf("异常检测不应该阻止请求,但得到了状态码 %d", w.Code)
	}

	anomalyRecorded, err := rds.Exists("security:anomaly:ip:192.168.1.100")
	if err != nil {
		t.Logf("查询异常检测记录失败: %v", err)
	} else if !anomalyRecorded {
		t.Log("异常检测记录可能已过期或未记录")
	}

	recordSubtestResult(t, testName, startTime, map[string]interface{}{
		"requestAllowed":  requestAllowed,
		"anomalyRecorded": anomalyRecorded,
		"testIP":          "192.168.1.100",
	})

	if globalTestReport != nil {
		globalTestReport.Performance.TotalRequests++
		if anomalyRecorded {
			globalTestReport.Performance.AnomalyDetections++
		}
	}
}

// BenchmarkSecurityMiddlewarePerformance is a placeholder. It always skips
// because it would need a real Redis environment.
func BenchmarkSecurityMiddlewarePerformance(b *testing.B) {
	b.Skip("跳过基准测试,需要真实Redis环境")
}