Files
tyc-server/app/main/api/internal/middleware/security/securityMiddleware_integration_test.go
2025-08-31 14:18:31 +08:00

442 lines
12 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package security
import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"sort"
	"strings"
	"testing"
	"time"

	"github.com/zeromicro/go-zero/core/stores/redis"

	"tyc-server/app/main/api/internal/config"
)
// TestReport aggregates the outcome of one integration-test run: timing,
// per-subtest results, request counters and a snapshot of security-related
// Redis key counts.
type TestReport struct {
	TestName string // name printed in the report header

	StartTime time.Time     // when the run began
	EndTime   time.Time     // set when the report is generated
	Duration  time.Duration // EndTime minus StartTime

	TotalTests  int // number of entries in TestResults
	PassedTests int // entries whose Status is "PASS"
	FailedTests int // entries with any other Status

	TestResults map[string]TestResult // keyed by test name
	Performance PerformanceMetrics    // request and hit counters
	RedisStats  RedisStats            // key counts collected after the subtests
}
// TestResult is the outcome of a single subtest as stored in a TestReport.
type TestResult struct {
	Name     string                 // subtest name; also its key in TestReport.TestResults
	Status   string                 // "PASS" or "FAIL"
	Duration time.Duration          // wall-clock time the subtest took
	Error    string                 // failure description; empty when there is none
	Details  map[string]interface{} // free-form per-test data
}
// PerformanceMetrics holds request counters accumulated by the subtests of the
// integration test.
type PerformanceMetrics struct {
	TotalRequests int // requests sent through the middleware by the subtests

	// Response-time fields; no code in this file assigns them.
	AverageResponseTime time.Duration
	MinResponseTime     time.Duration
	MaxResponseTime     time.Duration

	RateLimitHits     int // requests rejected with 429 Too Many Requests
	BlacklistHits     int // blacklisted-IP requests rejected with 403 Forbidden
	AnomalyDetections int // anomalies counted by the anomaly-detection subtest
}
// RedisStats is a snapshot of security-related key counts in Redis, taken
// after the integration subtests have run.
type RedisStats struct {
	TotalKeys     int    // keys matching "security:*"
	BlacklistKeys int    // keys matching "security:blacklist:*"
	RateLimitKeys int    // keys matching "security:ratelimit:*"
	AnomalyKeys   int    // keys matching "security:anomaly:*"
	MemoryUsage   string // not collected; reported as "N/A"
}
var (
	// globalTestReport collects results across the subtests of the integration
	// test, which creates it at the start of each run. The reporting helpers
	// treat a nil value as "reporting disabled" and return early. It is not
	// guarded by a mutex, so the subtests that write to it must not run in
	// parallel.
	globalTestReport *TestReport
)
// TestSecurityMiddlewareIntegration runs SecurityMiddleware against a real Redis
// instance and checks rate limiting, IP blacklisting and anomaly detection, then
// prints a summary report.
//
// The test is opt-in: it needs a live Redis and sleeps through a full
// rate-limit window. Connection settings come from the environment so that no
// credentials are committed with the test:
//
//	TEST_REDIS_HOST=127.0.0.1:20002 TEST_REDIS_PASS=<password> \
//		go test -run TestSecurityMiddlewareIntegration ./...
//
// It is skipped when TEST_REDIS_HOST is unset or when running with -short.
func TestSecurityMiddlewareIntegration(t *testing.T) {
	if testing.Short() {
		t.Skip("跳过集成测试: -short 模式")
	}
	redisHost := os.Getenv("TEST_REDIS_HOST")
	if redisHost == "" {
		t.Skip("跳过集成测试: 未设置 TEST_REDIS_HOST (需要真实Redis环境)")
	}

	// Start a fresh report for this run.
	globalTestReport = &TestReport{
		TestName:    "SecurityMiddleware集成测试",
		StartTime:   time.Now(),
		TestResults: make(map[string]TestResult),
	}

	redisClient, err := redis.NewRedis(redis.RedisConf{
		Host: redisHost,
		Pass: os.Getenv("TEST_REDIS_PASS"),
		Type: "node",
	})
	if err != nil {
		t.Fatalf("连接Redis失败: %v", err)
	}
	// go-zero manages the client's connections; no manual close is needed.

	// The subtests assert against these thresholds (e.g. testRateLimit expects
	// 3 requests per 10-second window). Named securityConfig so it does not
	// shadow the imported config package.
	securityConfig := &config.SecurityConfig{
		RateLimit: struct {
			Enabled          bool  `json:"enabled" yaml:"enabled"`
			WindowSize       int64 `json:"windowSize" yaml:"windowSize"`
			MaxRequests      int64 `json:"maxRequests" yaml:"maxRequests"`
			TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
			TriggerWindow    int64 `json:"triggerWindow" yaml:"triggerWindow"`
		}{
			Enabled:          true,
			WindowSize:       10, // 10-second window
			MaxRequests:      3,  // at most 3 requests per window
			TriggerThreshold: 3,  // blacklist after 3 rate-limit triggers
			TriggerWindow:    24, // triggers counted over 24 hours
		},
		IPBlacklist: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		UserBlacklist: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		AnomalyDetection: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		BurstAttack: struct {
			Enabled       bool  `json:"enabled" yaml:"enabled"`
			TimeWindow    int64 `json:"timeWindow" yaml:"timeWindow"`
			MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
		}{
			Enabled:       true,
			TimeWindow:    1,  // 1-second detection window
			MaxConcurrent: 15, // at most 15 concurrent requests
		},
	}
	middleware := NewSecurityMiddleware(securityConfig, redisClient)

	// Subtests run sequentially; they share globalTestReport.
	t.Run("RateLimit", func(t *testing.T) {
		testRateLimit(t, middleware, redisClient)
	})
	t.Run("IPBlacklist", func(t *testing.T) {
		testIPBlacklist(t, middleware, redisClient)
	})
	t.Run("AnomalyDetection", func(t *testing.T) {
		testAnomalyDetection(t, middleware, redisClient)
	})

	collectRedisStats(redisClient)
	generateTestReport(t)
}
// collectRedisStats counts the keys under each security-related prefix in
// Redis and stores the counts in the global report. The lookups are
// best-effort: a failed Keys call counts as zero keys. MemoryUsage is not
// collected and is always reported as "N/A". It does nothing when no report
// has been initialised.
func collectRedisStats(rdb *redis.Redis) {
	report := globalTestReport
	if report == nil {
		return
	}

	// countKeys returns how many keys match pattern. Errors are ignored on
	// purpose because these figures are informational only.
	countKeys := func(pattern string) int {
		matched, _ := rdb.Keys(pattern)
		return len(matched)
	}

	// Calls in a composite literal run in lexical order, so Redis is queried
	// blacklist -> ratelimit -> anomaly -> all.
	report.RedisStats = RedisStats{
		BlacklistKeys: countKeys("security:blacklist:*"),
		RateLimitKeys: countKeys("security:ratelimit:*"),
		AnomalyKeys:   countKeys("security:anomaly:*"),
		TotalKeys:     countKeys("security:*"),
		MemoryUsage:   "N/A", // would need an extra Redis command
	}
}
// recordTestResult stores one subtest outcome in the global report under name,
// replacing any earlier entry with the same name. It does nothing when no
// report has been initialised.
func recordTestResult(name, status string, duration time.Duration, errorMsg string, details map[string]interface{}) {
	if report := globalTestReport; report != nil {
		report.TestResults[name] = TestResult{
			Name:     name,
			Status:   status,
			Duration: duration,
			Error:    errorMsg,
			Details:  details,
		}
	}
}
// generateTestReport finalizes the global report and prints it. It sets the
// end time and total duration, then counts recorded results: "PASS" counts as
// passed and any other status as failed. The counts are recomputed from
// TestResults on every call, so calling it more than once does not
// double-count. It does nothing when no report has been initialised.
func generateTestReport(t *testing.T) {
	report := globalTestReport
	if report == nil {
		return
	}

	report.EndTime = time.Now()
	report.Duration = report.EndTime.Sub(report.StartTime)

	// Reset before tallying so repeated calls do not add to earlier counts
	// (which would otherwise report more passes than tests).
	report.TotalTests = len(report.TestResults)
	report.PassedTests, report.FailedTests = 0, 0
	for _, result := range report.TestResults {
		if result.Status == "PASS" {
			report.PassedTests++
		} else {
			report.FailedTests++
		}
	}

	printTestReport(t)
}
// printTestReport writes the global report to stdout. It prints run metadata,
// pass/fail totals and pass rate, per-test results, request metrics and Redis
// key counts. Per-test results are listed in sorted name order so the output
// is stable between runs. It prints nothing when no report has been
// initialised.
func printTestReport(t *testing.T) {
	report := globalTestReport
	if report == nil {
		return
	}

	// Output goes straight to stdout via fmt.
	fmt.Println("\n" + strings.Repeat("=", 80))
	fmt.Println("🔒 SECURITY MIDDLEWARE 集成测试报告")
	fmt.Println(strings.Repeat("=", 80))

	// Run metadata.
	fmt.Printf("📋 测试名称: %s\n", report.TestName)
	fmt.Printf("⏰ 开始时间: %s\n", report.StartTime.Format("2006-01-02 15:04:05"))
	fmt.Printf("⏰ 结束时间: %s\n", report.EndTime.Format("2006-01-02 15:04:05"))
	fmt.Printf("⏱️ 总耗时: %v\n", report.Duration)

	// Pass/fail totals.
	fmt.Printf("\n📊 测试结果统计:\n")
	fmt.Printf(" 总测试数: %d\n", report.TotalTests)
	fmt.Printf(" 通过测试: %d ✅\n", report.PassedTests)
	fmt.Printf(" 失败测试: %d ❌\n", report.FailedTests)
	if report.TotalTests > 0 {
		passRate := float64(report.PassedTests) / float64(report.TotalTests) * 100
		fmt.Printf(" 通过率: %.1f%%\n", passRate)
	}

	// Per-test results. Map iteration order is random, so sort the names to
	// keep the report deterministic.
	if len(report.TestResults) > 0 {
		fmt.Printf("\n📝 详细测试结果:\n")
		names := make([]string, 0, len(report.TestResults))
		for name := range report.TestResults {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			result := report.TestResults[name]
			statusIcon := "✅"
			if result.Status == "FAIL" {
				statusIcon = "❌"
			}
			fmt.Printf(" %s %s: %s (耗时: %v)\n", statusIcon, name, result.Status, result.Duration)
			if result.Error != "" {
				fmt.Printf(" 错误: %s\n", result.Error)
			}
		}
	}

	// Request metrics.
	fmt.Printf("\n🚀 性能指标:\n")
	fmt.Printf(" 总请求数: %d\n", report.Performance.TotalRequests)
	fmt.Printf(" 平均响应时间: %v\n", report.Performance.AverageResponseTime)
	fmt.Printf(" 频率限制触发: %d\n", report.Performance.RateLimitHits)
	fmt.Printf(" 黑名单命中: %d\n", report.Performance.BlacklistHits)
	fmt.Printf(" 异常检测: %d\n", report.Performance.AnomalyDetections)

	// Redis key counts.
	fmt.Printf("\n🗄 Redis统计:\n")
	fmt.Printf(" 总安全键数: %d\n", report.RedisStats.TotalKeys)
	fmt.Printf(" 黑名单键数: %d\n", report.RedisStats.BlacklistKeys)
	fmt.Printf(" 频率限制键数: %d\n", report.RedisStats.RateLimitKeys)
	fmt.Printf(" 异常检测键数: %d\n", report.RedisStats.AnomalyKeys)

	// Summary.
	fmt.Printf("\n📈 测试总结:\n")
	if report.FailedTests == 0 {
		fmt.Printf(" 🎉 所有测试通过!安全中间件运行正常。\n")
	} else {
		fmt.Printf(" ⚠️ 有 %d 个测试失败,需要检查相关功能。\n", report.FailedTests)
	}
	fmt.Println(strings.Repeat("=", 80))
	fmt.Println()
}
// testRateLimit checks the per-IP rate limiter configured by the integration
// test (3 requests per 10-second window). The expected responses are:
//   - requests 1-3 return 200 OK;
//   - request 4 returns 429 Too Many Requests;
//   - a 5th request, sent after the window has expired, returns 200 again.
//
// The result is recorded in the global report as "FAIL" if any assertion in
// this subtest failed, otherwise "PASS".
func testRateLimit(t *testing.T, middleware *SecurityMiddleware, rdb *redis.Redis) {
	startTime := time.Now()
	testName := "频率限制测试"
	const testIP = "192.168.1.100"

	// Clear state left by earlier runs. Besides the rate-limit counter, repeated
	// runs within the trigger window may leave other per-IP security keys (e.g.
	// trigger counts or an automatic blacklist entry) that would turn the
	// expected 200s into rejections. This assumes such keys live under
	// "security:" and contain the IP — TODO confirm against the middleware's
	// key layout.
	staleKeys := []string{"security:ratelimit:ip:" + testIP}
	if matched, err := rdb.Keys("security:*" + testIP + "*"); err != nil {
		t.Logf("查询旧测试数据失败: %v", err)
	} else {
		staleKeys = append(staleKeys, matched...)
	}
	if _, err := rdb.Del(staleKeys...); err != nil {
		t.Errorf("清理旧测试数据失败: %v", err)
	}

	// Build the wrapped handler once and reuse it for every request.
	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	totalRequests := 0
	// send issues one GET /test from testIP and returns the response status.
	send := func() int {
		req := httptest.NewRequest("GET", "/test", nil)
		req.Header.Set("X-Real-IP", testIP)
		rec := httptest.NewRecorder()
		handler(rec, req)
		totalRequests++
		return rec.Code
	}

	successCount := 0
	rateLimitHits := 0

	// The first 3 requests fit inside the window and must succeed.
	for i := 0; i < 3; i++ {
		if code := send(); code == http.StatusOK {
			successCount++
		} else {
			t.Errorf("请求 %d 应该成功,但得到了状态码 %d", i+1, code)
		}
	}

	// The 4th request exceeds MaxRequests and must be rejected.
	if code := send(); code == http.StatusTooManyRequests {
		rateLimitHits++
	} else {
		t.Errorf("超过频率限制的请求应该被拒绝,但得到了状态码 %d", code)
	}

	// Wait out the 10-second window (plus a 1s margin), then try again.
	time.Sleep(11 * time.Second)
	if code := send(); code == http.StatusOK {
		successCount++
	} else {
		t.Errorf("窗口过期后请求应该成功,但得到了状态码 %d", code)
	}

	// Derive the status from the subtest's own outcome instead of assuming PASS.
	status, errMsg := "PASS", ""
	if t.Failed() {
		status, errMsg = "FAIL", "存在失败的断言,详见测试日志"
	}
	recordTestResult(testName, status, time.Since(startTime), errMsg, map[string]interface{}{
		"successCount":  successCount,
		"rateLimitHits": rateLimitHits,
		"totalRequests": totalRequests,
	})

	if globalTestReport != nil {
		globalTestReport.Performance.TotalRequests += totalRequests
		globalTestReport.Performance.RateLimitHits += rateLimitHits
	}
}
// testIPBlacklist writes a blacklist entry for a test IP directly into Redis
// and checks that the middleware rejects that IP's request with 403 Forbidden.
// The entry is removed when the function returns. The result is recorded in
// the global report as "FAIL" if any assertion in this subtest failed.
func testIPBlacklist(t *testing.T, middleware *SecurityMiddleware, rdb *redis.Redis) {
	startTime := time.Now()
	testName := "IP黑名单测试"
	const blockedIP = "192.168.1.200"
	blacklistKey := "security:blacklist:ip:" + blockedIP

	// Blacklist the IP for one hour. The deferred Del removes the entry however
	// this function exits.
	if err := rdb.Setex(blacklistKey, "1", 3600); err != nil {
		t.Errorf("写入黑名单测试数据失败: %v", err)
	}
	defer func() {
		if _, err := rdb.Del(blacklistKey); err != nil {
			t.Logf("清理黑名单测试数据失败: %v", err)
		}
	}()

	req := httptest.NewRequest("GET", "/test", nil)
	req.Header.Set("X-Real-IP", blockedIP)
	w := httptest.NewRecorder()
	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler(w, req)

	blacklistHit := w.Code == http.StatusForbidden
	if !blacklistHit {
		t.Errorf("黑名单IP应该被拒绝但得到了状态码 %d", w.Code)
	}

	// Derive the status from the subtest's own outcome instead of assuming PASS.
	status, errMsg := "PASS", ""
	if t.Failed() {
		status, errMsg = "FAIL", "存在失败的断言,详见测试日志"
	}
	recordTestResult(testName, status, time.Since(startTime), errMsg, map[string]interface{}{
		"blacklistHit": blacklistHit,
		"blockedIP":    blockedIP,
	})

	if globalTestReport != nil {
		globalTestReport.Performance.TotalRequests++
		if blacklistHit {
			globalTestReport.Performance.BlacklistHits++
		}
	}
}
// testAnomalyDetection sends a request without a User-Agent header and checks
// two things:
//   - anomaly detection does not block the request (200 OK);
//   - Redis holds an anomaly record for the client IP. A missing record is
//     only logged, not treated as a failure.
//
// Only a record actually found in Redis counts toward the report's
// AnomalyDetections metric; a merely allowed request does not.
func testAnomalyDetection(t *testing.T, middleware *SecurityMiddleware, rdb *redis.Redis) {
	startTime := time.Now()
	testName := "异常检测测试"
	const testIP = "192.168.1.100"

	// No User-Agent header is set on this request; that is the anomaly under test.
	req := httptest.NewRequest("GET", "/test", nil)
	req.Header.Set("X-Real-IP", testIP)
	w := httptest.NewRecorder()
	handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	handler(w, req)

	// Anomaly detection should only record, never block.
	requestAllowed := w.Code == http.StatusOK
	if !requestAllowed {
		t.Errorf("异常检测不应该阻止请求,但得到了状态码 %d", w.Code)
	}

	key := "security:anomaly:ip:" + testIP
	anomalyRecorded, err := rdb.Exists(key)
	if err != nil {
		t.Logf("查询异常检测记录失败: %v", err)
	} else if !anomalyRecorded {
		t.Log("异常检测记录可能已过期或未记录")
	}

	// Derive the status from the subtest's own outcome instead of assuming PASS.
	status, errMsg := "PASS", ""
	if t.Failed() {
		status, errMsg = "FAIL", "存在失败的断言,详见测试日志"
	}
	recordTestResult(testName, status, time.Since(startTime), errMsg, map[string]interface{}{
		"requestAllowed":  requestAllowed,
		"anomalyRecorded": anomalyRecorded,
		"testIP":          testIP,
	})

	if globalTestReport != nil {
		globalTestReport.Performance.TotalRequests++
		if anomalyRecorded {
			globalTestReport.Performance.AnomalyDetections++
		}
	}
}
// BenchmarkSecurityMiddlewarePerformance is a placeholder benchmark. It skips
// unconditionally because a meaningful run needs a real Redis environment.
func BenchmarkSecurityMiddlewarePerformance(b *testing.B) {
	const skipReason = "跳过基准测试需要真实Redis环境"
	b.Skip(skipReason)
}