151 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			151 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package security | |||
|  | 
 | |||
|  | import ( | |||
|  | 	"net/http/httptest" | |||
|  | 	"testing" | |||
|  | 	"tyc-server/app/main/api/internal/config" | |||
|  | ) | |||
|  | 
 | |||
|  | // 创建测试配置 | |||
|  | func createTestConfig() *config.SecurityConfig { | |||
|  | 	return &config.SecurityConfig{ | |||
|  | 		RateLimit: struct { | |||
|  | 			Enabled          bool  `json:"enabled" yaml:"enabled"` | |||
|  | 			WindowSize       int64 `json:"windowSize" yaml:"windowSize"` | |||
|  | 			MaxRequests      int64 `json:"maxRequests" yaml:"maxRequests"` | |||
|  | 			TriggerThreshold int64 `json:"triggerThreshold" yaml:"triggerThreshold"` | |||
|  | 			TriggerWindow    int64 `json:"triggerWindow" yaml:"triggerWindow"` | |||
|  | 		}{ | |||
|  | 			Enabled:          true, | |||
|  | 			WindowSize:       60, | |||
|  | 			MaxRequests:      5, | |||
|  | 			TriggerThreshold: 5, | |||
|  | 			TriggerWindow:    24, | |||
|  | 		}, | |||
|  | 		IPBlacklist: struct { | |||
|  | 			Enabled bool `json:"enabled" yaml:"enabled"` | |||
|  | 		}{ | |||
|  | 			Enabled: true, | |||
|  | 		}, | |||
|  | 		UserBlacklist: struct { | |||
|  | 			Enabled bool `json:"enabled" yaml:"enabled"` | |||
|  | 		}{ | |||
|  | 			Enabled: true, | |||
|  | 		}, | |||
|  | 		AnomalyDetection: struct { | |||
|  | 			Enabled bool `json:"enabled" yaml:"enabled"` | |||
|  | 		}{ | |||
|  | 			Enabled: true, | |||
|  | 		}, | |||
|  | 		BurstAttack: struct { | |||
|  | 			Enabled       bool  `json:"enabled" yaml:"enabled"` | |||
|  | 			TimeWindow    int64 `json:"timeWindow" yaml:"timeWindow"` | |||
|  | 			MaxConcurrent int64 `json:"maxConcurrent" yaml:"maxConcurrent"` | |||
|  | 		}{ | |||
|  | 			Enabled:       true, | |||
|  | 			TimeWindow:    1, | |||
|  | 			MaxConcurrent: 20, | |||
|  | 		}, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // 测试客户端标识生成 | |||
|  | func TestClientIDGeneration(t *testing.T) { | |||
|  | 	config := createTestConfig() | |||
|  | 	// 使用nil Redis进行测试,只测试不依赖Redis的逻辑 | |||
|  | 	middleware := NewSecurityMiddleware(config, nil) | |||
|  | 
 | |||
|  | 	// 测试IP标识 | |||
|  | 	req := httptest.NewRequest("GET", "/test", nil) | |||
|  | 	req.Header.Set("X-Real-IP", "192.168.1.100") | |||
|  | 
 | |||
|  | 	clientID := middleware.getClientID(req) | |||
|  | 	expected := "ip:192.168.1.100" | |||
|  | 	if clientID != expected { | |||
|  | 		t.Errorf("期望客户端标识 %s,但得到了 %s", expected, clientID) | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // 测试真实IP获取 | |||
|  | func TestRealIPExtraction(t *testing.T) { | |||
|  | 	config := createTestConfig() | |||
|  | 	middleware := NewSecurityMiddleware(config, nil) | |||
|  | 
 | |||
|  | 	// 测试X-Forwarded-For | |||
|  | 	req := httptest.NewRequest("GET", "/test", nil) | |||
|  | 	req.Header.Set("X-Forwarded-For", "203.0.113.1, 192.168.1.1") | |||
|  | 
 | |||
|  | 	ip := middleware.getClientIP(req) | |||
|  | 	expected := "203.0.113.1" | |||
|  | 	if ip != expected { | |||
|  | 		t.Errorf("期望IP %s,但得到了 %s", expected, ip) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 测试X-Real-IP(创建新的请求对象) | |||
|  | 	req2 := httptest.NewRequest("GET", "/test", nil) | |||
|  | 	req2.Header.Set("X-Real-IP", "198.51.100.1") | |||
|  | 	ip = middleware.getClientIP(req2) | |||
|  | 	expected = "198.51.100.1" | |||
|  | 	if ip != expected { | |||
|  | 		t.Errorf("期望IP %s,但得到了 %s", expected, ip) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 测试直接连接 | |||
|  | 	req3 := httptest.NewRequest("GET", "/test", nil) | |||
|  | 	req3.RemoteAddr = "192.168.1.50:12345" | |||
|  | 	ip = middleware.getClientIP(req3) | |||
|  | 	expected = "192.168.1.50" | |||
|  | 	if ip != expected { | |||
|  | 		t.Errorf("期望IP %s,但得到了 %s", expected, ip) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 测试优先级:X-Forwarded-For 应该优先于 X-Real-IP | |||
|  | 	req4 := httptest.NewRequest("GET", "/test", nil) | |||
|  | 	req4.Header.Set("X-Forwarded-For", "10.0.0.1, 10.0.0.2") | |||
|  | 	req4.Header.Set("X-Real-IP", "10.0.0.3") | |||
|  | 	ip = middleware.getClientIP(req4) | |||
|  | 	expected = "10.0.0.1" // X-Forwarded-For 应该优先 | |||
|  | 	if ip != expected { | |||
|  | 		t.Errorf("优先级测试失败:期望IP %s,但得到了 %s", expected, ip) | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // 测试中间件创建 | |||
|  | func TestNewSecurityMiddleware(t *testing.T) { | |||
|  | 	config := createTestConfig() | |||
|  | 	middleware := NewSecurityMiddleware(config, nil) | |||
|  | 
 | |||
|  | 	if middleware == nil { | |||
|  | 		t.Error("中间件创建失败") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if middleware.config != config { | |||
|  | 		t.Error("配置设置失败") | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // 测试配置验证 | |||
|  | func TestConfigValidation(t *testing.T) { | |||
|  | 	config := createTestConfig() | |||
|  | 
 | |||
|  | 	if !config.RateLimit.Enabled { | |||
|  | 		t.Error("频率限制应该启用") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if config.RateLimit.MaxRequests != 5 { | |||
|  | 		t.Error("最大请求数设置错误") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if !config.IPBlacklist.Enabled { | |||
|  | 		t.Error("IP黑名单应该启用") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if !config.UserBlacklist.Enabled { | |||
|  | 		t.Error("用户黑名单应该启用") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if !config.AnomalyDetection.Enabled { | |||
|  | 		t.Error("异常检测应该启用") | |||
|  | 	} | |||
|  | } |