442 lines
12 KiB
Go
442 lines
12 KiB
Go
|
|
package security
|
|||
|
|
|
|||
|
|
import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"os"
	"sort"
	"strings"
	"testing"
	"time"

	"tyc-server/app/main/api/internal/config"

	"github.com/zeromicro/go-zero/core/stores/redis"
)
|
|||
|
|
|
|||
|
|
// 测试报告结构体
|
|||
|
|
type TestReport struct {
|
|||
|
|
TestName string
|
|||
|
|
StartTime time.Time
|
|||
|
|
EndTime time.Time
|
|||
|
|
Duration time.Duration
|
|||
|
|
TotalTests int
|
|||
|
|
PassedTests int
|
|||
|
|
FailedTests int
|
|||
|
|
TestResults map[string]TestResult
|
|||
|
|
Performance PerformanceMetrics
|
|||
|
|
RedisStats RedisStats
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 单个测试结果
|
|||
|
|
type TestResult struct {
|
|||
|
|
Name string
|
|||
|
|
Status string // "PASS" | "FAIL"
|
|||
|
|
Duration time.Duration
|
|||
|
|
Error string
|
|||
|
|
Details map[string]interface{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 性能指标
|
|||
|
|
type PerformanceMetrics struct {
|
|||
|
|
TotalRequests int
|
|||
|
|
AverageResponseTime time.Duration
|
|||
|
|
MinResponseTime time.Duration
|
|||
|
|
MaxResponseTime time.Duration
|
|||
|
|
RateLimitHits int
|
|||
|
|
BlacklistHits int
|
|||
|
|
AnomalyDetections int
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RedisStats is a snapshot of the security-related keys present in Redis.
type RedisStats struct {
	// Key counts: all "security:*" keys, then per category.
	TotalKeys, BlacklistKeys, RateLimitKeys, AnomalyKeys int

	// MemoryUsage is a human-readable figure (collectRedisStats stores "N/A").
	MemoryUsage string
}
|
|||
|
|
|
|||
|
|
// 全局测试报告
|
|||
|
|
var globalTestReport *TestReport
|
|||
|
|
|
|||
|
|
// TestSecurityMiddlewareIntegration exercises SecurityMiddleware against a
// real Redis instance (rate limiting, IP blacklist, anomaly detection) and
// prints a summary report at the end.
//
// This is an opt-in integration test. It is skipped in -short mode (the
// rate-limit case sleeps past its window) and whenever
// SECURITY_TEST_REDIS_ADDR (host:port) is unset. The Redis password is read
// from SECURITY_TEST_REDIS_PASS so that no credentials live in the source tree.
func TestSecurityMiddlewareIntegration(t *testing.T) {
	if testing.Short() {
		t.Skip("跳过集成测试(-short 模式),需要真实Redis环境")
	}
	redisAddr := os.Getenv("SECURITY_TEST_REDIS_ADDR")
	if redisAddr == "" {
		t.Skip("跳过集成测试,需要真实Redis环境: 请设置 SECURITY_TEST_REDIS_ADDR (以及可选的 SECURITY_TEST_REDIS_PASS)")
	}

	// Fresh report for this run; the helpers below write into it.
	globalTestReport = &TestReport{
		TestName:    "SecurityMiddleware集成测试",
		StartTime:   time.Now(),
		TestResults: make(map[string]TestResult),
	}

	redisClient, err := redis.NewRedis(redis.RedisConf{
		Host: redisAddr,
		Pass: os.Getenv("SECURITY_TEST_REDIS_PASS"),
		Type: "node",
	})
	if err != nil {
		t.Fatalf("连接Redis失败: %v", err)
	}
	// The original code relies on go-zero to manage the connection, so it is
	// not closed explicitly here.

	// Deliberately small limits so the sub-tests can trip them quickly.
	// testRateLimit assumes MaxRequests=3 and sleeps longer than WindowSize.
	// Named cfg so it does not shadow the imported config package.
	cfg := &config.SecurityConfig{
		RateLimit: struct {
			Enabled          bool  `json:"enabled" yaml:"enabled"`
			WindowSize       int64 `json:"windowSize" yaml:"windowSize"`
			MaxRequests      int64 `json:"maxRequests" yaml:"maxRequests"`
			TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"`
			TriggerWindow    int64 `json:"triggerWindow" yaml:"triggerWindow"`
		}{
			Enabled:          true,
			WindowSize:       10, // 10-second window
			MaxRequests:      3,  // at most 3 requests per window
			TriggerThreshold: 3,  // blacklist after 3 rate-limit triggers
			TriggerWindow:    24, // triggers counted within 24 hours
		},
		IPBlacklist: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		UserBlacklist: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		AnomalyDetection: struct {
			Enabled bool `json:"enabled" yaml:"enabled"`
		}{
			Enabled: true,
		},
		BurstAttack: struct {
			Enabled       bool  `json:"enabled" yaml:"enabled"`
			TimeWindow    int64 `json:"timeWindow" yaml:"timeWindow"`
			MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"`
		}{
			Enabled:       true,
			TimeWindow:    1,  // 1-second detection window
			MaxConcurrent: 15, // at most 15 concurrent requests
		},
	}

	middleware := NewSecurityMiddleware(cfg, redisClient)

	t.Run("RateLimit", func(t *testing.T) {
		testRateLimit(t, middleware, redisClient)
	})

	t.Run("IPBlacklist", func(t *testing.T) {
		testIPBlacklist(t, middleware, redisClient)
	})

	t.Run("AnomalyDetection", func(t *testing.T) {
		testAnomalyDetection(t, middleware, redisClient)
	})

	collectRedisStats(redisClient)

	generateTestReport(t)
}
|
|||
|
|
|
|||
|
|
// collectRedisStats counts the "security:*" keys currently in Redis, overall
// and per category (blacklist, ratelimit, anomaly), and stores the result in
// globalTestReport.RedisStats. It does nothing when no report is active.
//
// The statistics are informational only, so a failed lookup is reported on
// stdout and counted as zero instead of aborting. KEYS walks the whole
// keyspace, which is acceptable against a dedicated test instance.
func collectRedisStats(rds *redis.Redis) {
	if globalTestReport == nil {
		return
	}

	countKeys := func(pattern string) int {
		keys, err := rds.Keys(pattern)
		if err != nil {
			fmt.Printf("⚠️ 统计Redis键失败 (%s): %v\n", pattern, err)
			return 0
		}
		return len(keys)
	}

	blacklistKeys := countKeys("security:blacklist:*")
	rateLimitKeys := countKeys("security:ratelimit:*")
	anomalyKeys := countKeys("security:anomaly:*")
	totalKeys := countKeys("security:*")

	globalTestReport.RedisStats = RedisStats{
		TotalKeys:     totalKeys,
		BlacklistKeys: blacklistKeys,
		RateLimitKeys: rateLimitKeys,
		AnomalyKeys:   anomalyKeys,
		MemoryUsage:   "N/A", // memory usage would require an additional Redis command
	}
}
|
|||
|
|
|
|||
|
|
// recordTestResult stores the outcome of one test case in globalTestReport,
// keyed by name; a later call with the same name replaces the earlier entry.
// It does nothing when no report is active.
func recordTestResult(name, status string, duration time.Duration, errorMsg string, details map[string]interface{}) {
	report := globalTestReport
	if report != nil {
		report.TestResults[name] = TestResult{
			Name:     name,
			Status:   status,
			Duration: duration,
			Error:    errorMsg,
			Details:  details,
		}
	}
}
|
|||
|
|
|
|||
|
|
// generateTestReport finalizes globalTestReport and prints it via
// printTestReport. Finalizing sets the end time and total duration, and
// derives the total/passed/failed counters from TestResults. A result counts
// as passed only when its Status is "PASS". It does nothing when no report is
// active.
//
// The counters are recomputed from scratch on every call, so calling this
// more than once does not double-count.
func generateTestReport(t *testing.T) {
	report := globalTestReport
	if report == nil {
		return
	}

	report.EndTime = time.Now()
	report.Duration = report.EndTime.Sub(report.StartTime)

	passed, failed := 0, 0
	for _, result := range report.TestResults {
		if result.Status == "PASS" {
			passed++
		} else {
			failed++
		}
	}
	report.TotalTests = len(report.TestResults)
	report.PassedTests = passed
	report.FailedTests = failed

	printTestReport(t)
}
|
|||
|
|
|
|||
|
|
// printTestReport 打印详细的测试报告
|
|||
|
|
func printTestReport(t *testing.T) {
|
|||
|
|
if globalTestReport == nil {
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 使用fmt包来格式化输出
|
|||
|
|
fmt.Println("\n" + strings.Repeat("=", 80))
|
|||
|
|
fmt.Println("🔒 SECURITY MIDDLEWARE 集成测试报告")
|
|||
|
|
fmt.Println(strings.Repeat("=", 80))
|
|||
|
|
|
|||
|
|
// 基本信息
|
|||
|
|
fmt.Printf("📋 测试名称: %s\n", globalTestReport.TestName)
|
|||
|
|
fmt.Printf("⏰ 开始时间: %s\n", globalTestReport.StartTime.Format("2006-01-02 15:04:05"))
|
|||
|
|
fmt.Printf("⏰ 结束时间: %s\n", globalTestReport.EndTime.Format("2006-01-02 15:04:05"))
|
|||
|
|
fmt.Printf("⏱️ 总耗时: %v\n", globalTestReport.Duration)
|
|||
|
|
|
|||
|
|
// 测试结果统计
|
|||
|
|
fmt.Printf("\n📊 测试结果统计:\n")
|
|||
|
|
fmt.Printf(" 总测试数: %d\n", globalTestReport.TotalTests)
|
|||
|
|
fmt.Printf(" 通过测试: %d ✅\n", globalTestReport.PassedTests)
|
|||
|
|
fmt.Printf(" 失败测试: %d ❌\n", globalTestReport.FailedTests)
|
|||
|
|
|
|||
|
|
if globalTestReport.TotalTests > 0 {
|
|||
|
|
passRate := float64(globalTestReport.PassedTests) / float64(globalTestReport.TotalTests) * 100
|
|||
|
|
fmt.Printf(" 通过率: %.1f%%\n", passRate)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 详细测试结果
|
|||
|
|
if len(globalTestReport.TestResults) > 0 {
|
|||
|
|
fmt.Printf("\n📝 详细测试结果:\n")
|
|||
|
|
for name, result := range globalTestReport.TestResults {
|
|||
|
|
statusIcon := "✅"
|
|||
|
|
if result.Status == "FAIL" {
|
|||
|
|
statusIcon = "❌"
|
|||
|
|
}
|
|||
|
|
fmt.Printf(" %s %s: %s (耗时: %v)\n", statusIcon, name, result.Status, result.Duration)
|
|||
|
|
if result.Error != "" {
|
|||
|
|
fmt.Printf(" 错误: %s\n", result.Error)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 性能指标
|
|||
|
|
fmt.Printf("\n🚀 性能指标:\n")
|
|||
|
|
fmt.Printf(" 总请求数: %d\n", globalTestReport.Performance.TotalRequests)
|
|||
|
|
fmt.Printf(" 平均响应时间: %v\n", globalTestReport.Performance.AverageResponseTime)
|
|||
|
|
fmt.Printf(" 频率限制触发: %d\n", globalTestReport.Performance.RateLimitHits)
|
|||
|
|
fmt.Printf(" 黑名单命中: %d\n", globalTestReport.Performance.BlacklistHits)
|
|||
|
|
fmt.Printf(" 异常检测: %d\n", globalTestReport.Performance.AnomalyDetections)
|
|||
|
|
|
|||
|
|
// Redis统计
|
|||
|
|
fmt.Printf("\n🗄️ Redis统计:\n")
|
|||
|
|
fmt.Printf(" 总安全键数: %d\n", globalTestReport.RedisStats.TotalKeys)
|
|||
|
|
fmt.Printf(" 黑名单键数: %d\n", globalTestReport.RedisStats.BlacklistKeys)
|
|||
|
|
fmt.Printf(" 频率限制键数: %d\n", globalTestReport.RedisStats.RateLimitKeys)
|
|||
|
|
fmt.Printf(" 异常检测键数: %d\n", globalTestReport.RedisStats.AnomalyKeys)
|
|||
|
|
|
|||
|
|
// 测试总结
|
|||
|
|
fmt.Printf("\n📈 测试总结:\n")
|
|||
|
|
if globalTestReport.FailedTests == 0 {
|
|||
|
|
fmt.Printf(" 🎉 所有测试通过!安全中间件运行正常。\n")
|
|||
|
|
} else {
|
|||
|
|
fmt.Printf(" ⚠️ 有 %d 个测试失败,需要检查相关功能。\n", globalTestReport.FailedTests)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
fmt.Println(strings.Repeat("=", 80))
|
|||
|
|
fmt.Println()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func testRateLimit(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
|||
|
|
startTime := time.Now()
|
|||
|
|
testName := "频率限制测试"
|
|||
|
|
|
|||
|
|
// 清理之前的测试数据
|
|||
|
|
redis.Del("security:ratelimit:ip:192.168.1.100")
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|||
|
|
req.Header.Set("X-Real-IP", "192.168.1.100")
|
|||
|
|
|
|||
|
|
successCount := 0
|
|||
|
|
rateLimitHits := 0
|
|||
|
|
|
|||
|
|
// 前3次请求应该成功
|
|||
|
|
for i := 0; i < 3; i++ {
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
w.WriteHeader(http.StatusOK)
|
|||
|
|
})
|
|||
|
|
handler(w, req)
|
|||
|
|
|
|||
|
|
if w.Code == http.StatusOK {
|
|||
|
|
successCount++
|
|||
|
|
} else {
|
|||
|
|
t.Errorf("请求 %d 应该成功,但得到了状态码 %d", i+1, w.Code)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 第4次请求应该被拒绝
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
w.WriteHeader(http.StatusOK)
|
|||
|
|
})
|
|||
|
|
handler(w, req)
|
|||
|
|
|
|||
|
|
if w.Code == http.StatusTooManyRequests {
|
|||
|
|
rateLimitHits++
|
|||
|
|
} else {
|
|||
|
|
t.Errorf("超过频率限制的请求应该被拒绝,但得到了状态码 %d", w.Code)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 等待窗口过期后再次测试
|
|||
|
|
time.Sleep(11 * time.Second)
|
|||
|
|
w = httptest.NewRecorder()
|
|||
|
|
handler(w, req)
|
|||
|
|
|
|||
|
|
if w.Code == http.StatusOK {
|
|||
|
|
successCount++
|
|||
|
|
} else {
|
|||
|
|
t.Errorf("窗口过期后请求应该成功,但得到了状态码 %d", w.Code)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 记录测试结果
|
|||
|
|
duration := time.Since(startTime)
|
|||
|
|
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
|||
|
|
"successCount": successCount,
|
|||
|
|
"rateLimitHits": rateLimitHits,
|
|||
|
|
"totalRequests": 5,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
// 更新性能指标
|
|||
|
|
if globalTestReport != nil {
|
|||
|
|
globalTestReport.Performance.TotalRequests += 5
|
|||
|
|
globalTestReport.Performance.RateLimitHits += rateLimitHits
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func testIPBlacklist(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
|||
|
|
startTime := time.Now()
|
|||
|
|
testName := "IP黑名单测试"
|
|||
|
|
|
|||
|
|
// 添加IP到黑名单
|
|||
|
|
redis.Setex("security:blacklist:ip:192.168.1.200", "1", 3600)
|
|||
|
|
|
|||
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|||
|
|
req.Header.Set("X-Real-IP", "192.168.1.200")
|
|||
|
|
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
w.WriteHeader(http.StatusOK)
|
|||
|
|
})
|
|||
|
|
handler(w, req)
|
|||
|
|
|
|||
|
|
blacklistHit := false
|
|||
|
|
if w.Code == http.StatusForbidden {
|
|||
|
|
blacklistHit = true
|
|||
|
|
} else {
|
|||
|
|
t.Errorf("黑名单IP应该被拒绝,但得到了状态码 %d", w.Code)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 清理测试数据
|
|||
|
|
redis.Del("security:blacklist:ip:192.168.1.200")
|
|||
|
|
|
|||
|
|
// 记录测试结果
|
|||
|
|
duration := time.Since(startTime)
|
|||
|
|
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
|||
|
|
"blacklistHit": blacklistHit,
|
|||
|
|
"blockedIP": "192.168.1.200",
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
// 更新性能指标
|
|||
|
|
if globalTestReport != nil {
|
|||
|
|
globalTestReport.Performance.TotalRequests++
|
|||
|
|
if blacklistHit {
|
|||
|
|
globalTestReport.Performance.BlacklistHits++
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func testAnomalyDetection(t *testing.T, middleware *SecurityMiddleware, redis *redis.Redis) {
|
|||
|
|
startTime := time.Now()
|
|||
|
|
testName := "异常检测测试"
|
|||
|
|
|
|||
|
|
// 测试空User-Agent
|
|||
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|||
|
|
req.Header.Set("X-Real-IP", "192.168.1.100")
|
|||
|
|
// 不设置User-Agent
|
|||
|
|
|
|||
|
|
w := httptest.NewRecorder()
|
|||
|
|
handler := middleware.Handle(func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
w.WriteHeader(http.StatusOK)
|
|||
|
|
})
|
|||
|
|
handler(w, req)
|
|||
|
|
|
|||
|
|
anomalyDetected := false
|
|||
|
|
// 异常检测不应该阻止请求,只是记录
|
|||
|
|
if w.Code == http.StatusOK {
|
|||
|
|
anomalyDetected = true
|
|||
|
|
} else {
|
|||
|
|
t.Errorf("异常检测不应该阻止请求,但得到了状态码 %d", w.Code)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查是否记录了异常
|
|||
|
|
key := "security:anomaly:ip:192.168.1.100"
|
|||
|
|
exists, _ := redis.Exists(key)
|
|||
|
|
if !exists {
|
|||
|
|
t.Log("异常检测记录可能已过期或未记录")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 记录测试结果
|
|||
|
|
duration := time.Since(startTime)
|
|||
|
|
recordTestResult(testName, "PASS", duration, "", map[string]interface{}{
|
|||
|
|
"anomalyDetected": anomalyDetected,
|
|||
|
|
"anomalyRecorded": exists,
|
|||
|
|
"testIP": "192.168.1.100",
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
// 更新性能指标
|
|||
|
|
if globalTestReport != nil {
|
|||
|
|
globalTestReport.Performance.TotalRequests++
|
|||
|
|
if anomalyDetected {
|
|||
|
|
globalTestReport.Performance.AnomalyDetections++
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 性能基准测试
|
|||
|
|
func BenchmarkSecurityMiddlewarePerformance(b *testing.B) {
|
|||
|
|
// 跳过基准测试,除非明确要求
|
|||
|
|
b.Skip("跳过基准测试,需要真实Redis环境")
|
|||
|
|
}
|