Files
tyc-server/app/main/api/internal/middleware/security/securityMiddleware.go
2025-08-31 14:18:31 +08:00

342 lines
9.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package security
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"tyc-server/app/main/api/internal/config"
"github.com/zeromicro/go-zero/core/logx"
"github.com/zeromicro/go-zero/core/stores/redis"
)
// SecurityMiddleware 安全防护中间件
type SecurityMiddleware struct {
config *config.SecurityConfig
redis *redis.Redis
}
// NewSecurityMiddleware 创建安全中间件
func NewSecurityMiddleware(config *config.SecurityConfig, redis *redis.Redis) *SecurityMiddleware {
return &SecurityMiddleware{
config: config,
redis: redis,
}
}
// Handle 处理请求
func (m *SecurityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// 1. 获取客户端标识
clientID := m.getClientID(r)
// 2. IP黑名单检查
if m.config.IPBlacklist.Enabled {
if m.isIPBlacklisted(r) {
logx.WithContext(ctx).Errorf("IP被拉黑: %s", m.getClientIP(r))
http.Error(w, "访问被拒绝", http.StatusForbidden)
return
}
}
// 3. 用户黑名单检查
if m.config.UserBlacklist.Enabled {
if m.isUserBlacklisted(ctx, r) {
logx.WithContext(ctx).Errorf("用户被拉黑: %s", clientID)
http.Error(w, "访问被拒绝", http.StatusForbidden)
return
}
}
// 4. 短时并发攻击检测
if !m.checkBurstAttack(ctx, clientID, r) {
logx.WithContext(ctx).Errorf("检测到并发攻击: %s", clientID)
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
return
}
// 5. 频率限制检查
if m.config.RateLimit.Enabled {
if !m.checkRateLimit(ctx, clientID, r) {
logx.WithContext(ctx).Errorf("频率限制触发: %s", clientID)
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
return
}
}
// 6. 异常检测
if m.config.AnomalyDetection.Enabled {
if m.detectAnomaly(ctx, r) {
logx.WithContext(ctx).Errorf("检测到异常请求: %s", clientID)
// 记录异常但不阻止请求,用于监控
}
}
// 7. 记录请求日志
m.logRequest(ctx, r, clientID)
// 继续处理请求
next(w, r)
}
}
// getClientID 获取客户端唯一标识
func (m *SecurityMiddleware) getClientID(r *http.Request) string {
// 优先使用用户ID如果已认证
if userID := m.getUserIDFromContext(r.Context()); userID != "" {
return fmt.Sprintf("user:%s", userID)
}
// 使用IP地址作为标识
return fmt.Sprintf("ip:%s", m.getClientIP(r))
}
// getClientIP 获取客户端真实IP
func (m *SecurityMiddleware) getClientIP(r *http.Request) string {
// 检查代理头
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
// 取第一个IP最原始的客户端IP
if commaIndex := strings.Index(ip, ","); commaIndex != -1 {
return strings.TrimSpace(ip[:commaIndex])
}
return strings.TrimSpace(ip)
}
if ip := r.Header.Get("X-Real-IP"); ip != "" {
return strings.TrimSpace(ip)
}
if ip := r.Header.Get("X-Client-IP"); ip != "" {
return strings.TrimSpace(ip)
}
// 直接连接
if r.RemoteAddr != "" {
if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 {
return r.RemoteAddr[:colonIndex]
}
return r.RemoteAddr
}
return "unknown"
}
// getUserIDFromContext 从上下文中获取用户ID
func (m *SecurityMiddleware) getUserIDFromContext(ctx context.Context) string {
// 这里需要根据你的JWT实现来获取用户ID
// 示例实现
if claims, ok := ctx.Value("claims").(map[string]interface{}); ok {
if userID, exists := claims["userId"]; exists {
return fmt.Sprintf("%v", userID)
}
}
return ""
}
// isIPBlacklisted 检查IP是否在黑名单中
func (m *SecurityMiddleware) isIPBlacklisted(r *http.Request) bool {
ip := m.getClientIP(r)
key := fmt.Sprintf("security:blacklist:ip:%s", ip)
exists, err := m.redis.Exists(key)
if err != nil {
logx.Errorf("检查IP黑名单失败: %v", err)
return false
}
return exists
}
// isUserBlacklisted 检查用户是否在黑名单中
func (m *SecurityMiddleware) isUserBlacklisted(ctx context.Context, r *http.Request) bool {
userID := m.getUserIDFromContext(ctx)
if userID == "" {
return false
}
key := fmt.Sprintf("security:blacklist:user:%s", userID)
exists, err := m.redis.Exists(key)
if err != nil {
logx.Errorf("检查用户黑名单失败: %v", err)
return false
}
return exists
}
// checkRateLimit 检查频率限制
func (m *SecurityMiddleware) checkRateLimit(ctx context.Context, clientID string, r *http.Request) bool {
key := fmt.Sprintf("security:ratelimit:%s", clientID)
// 获取当前计数
current, err := m.redis.Get(key)
if err != nil && err != redis.Nil {
logx.Errorf("获取频率限制计数失败: %v", err)
return true // 出错时允许请求
}
logx.Infof("current: %s", current)
var count int64
if current != "" {
count, _ = strconv.ParseInt(current, 10, 64)
}
// 检查是否超过限制
if count >= m.config.RateLimit.MaxRequests {
// 频率限制触发,记录触发次数
m.recordRateLimitTrigger(clientID)
return false
}
// 增加计数
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, time.Duration(m.config.RateLimit.WindowSize)*time.Second)
return nil
})
if err != nil {
logx.Errorf("更新频率限制计数失败: %v", err)
}
return true
}
// recordRateLimitTrigger 记录频率限制触发次数
func (m *SecurityMiddleware) recordRateLimitTrigger(clientID string) {
// 记录IP触发频率限制的次数
if strings.HasPrefix(clientID, "ip:") {
ip := strings.TrimPrefix(clientID, "ip:")
triggerKey := fmt.Sprintf("security:ratelimit_trigger:ip:%s", ip)
// 增加触发次数
err := m.redis.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Incr(context.Background(), triggerKey)
pipe.Expire(context.Background(), triggerKey, time.Duration(m.config.RateLimit.TriggerWindow)*time.Hour) // 使用配置的时间窗口
return nil
})
if err != nil {
logx.Errorf("记录频率限制触发次数失败: %v", err)
return
}
// 检查是否达到黑名单阈值
triggerCount, err := m.redis.Get(triggerKey)
if err == nil && triggerCount != "" {
if count, _ := strconv.ParseInt(triggerCount, 10, 64); count >= m.config.RateLimit.TriggerThreshold { // 使用配置的阈值
logx.Infof("IP %s 触发频率限制次数过多(%d次/%d小时),自动加入黑名单", ip, count, m.config.RateLimit.TriggerWindow)
m.addToBlacklist(clientID)
}
}
}
}
// checkBurstAttack 检查短时并发攻击
func (m *SecurityMiddleware) checkBurstAttack(ctx context.Context, clientID string, r *http.Request) bool {
// 检查是否启用短时并发攻击检测
if !m.config.BurstAttack.Enabled {
return true
}
// 只对IP进行检查用户级别的并发检测在业务层处理
if !strings.HasPrefix(clientID, "ip:") {
return true
}
ip := strings.TrimPrefix(clientID, "ip:")
burstKey := fmt.Sprintf("security:burst:%s", ip)
// 使用Redis的原子操作检查短时并发
// 使用配置的时间窗口
current, err := m.redis.Get(burstKey)
if err != nil && err != redis.Nil {
logx.Errorf("获取短时并发计数失败: %v", err)
return false // 出错时阻止请求
}
var count int64
if current != "" {
count, _ = strconv.ParseInt(current, 10, 64)
}
// 如果指定时间内并发请求超过阈值,认为是爆破攻击
if count >= m.config.BurstAttack.MaxConcurrent { // 使用配置的并发阈值
logx.Errorf("检测到IP %s 的爆破攻击(%d个请求/%d秒),自动加入黑名单", ip, count, m.config.BurstAttack.TimeWindow)
m.addToBlacklist(clientID)
return false
}
// 增加并发计数并设置过期时间
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
pipe.Incr(ctx, burstKey)
pipe.Expire(ctx, burstKey, time.Duration(m.config.BurstAttack.TimeWindow)*time.Second) // 使用配置的时间窗口
return nil
})
if err != nil {
logx.Errorf("更新短时并发计数失败: %v", err)
}
return true
}
// detectAnomaly 异常检测
func (m *SecurityMiddleware) detectAnomaly(ctx context.Context, r *http.Request) bool {
// 检测可疑的请求特征
suspicious := false
// 1. 检查User-Agent
userAgent := r.Header.Get("User-Agent")
if userAgent == "" || strings.Contains(strings.ToLower(userAgent), "bot") {
suspicious = true
}
// 2. 检查请求频率异常
clientID := m.getClientID(r)
key := fmt.Sprintf("security:anomaly:%s", clientID)
if suspicious {
// 记录异常
m.redis.Incr(key)
m.redis.Expire(key, 3600) // 1小时过期
// 如果异常次数过多,加入黑名单
count, _ := m.redis.Get(key)
if count != "" {
if countInt, _ := strconv.ParseInt(count, 10, 64); countInt > 10 {
m.addToBlacklist(clientID)
}
}
}
return suspicious
}
// addToBlacklist 添加到黑名单
func (m *SecurityMiddleware) addToBlacklist(clientID string) {
var key string
var expireTime time.Duration
if strings.HasPrefix(clientID, "user:") {
key = fmt.Sprintf("security:blacklist:%s", clientID)
expireTime = 24 * time.Hour // 用户黑名单24小时
} else {
key = fmt.Sprintf("security:blacklist:%s", clientID)
expireTime = 1 * time.Hour // IP黑名单1小时
}
m.redis.Setex(key, "1", int(expireTime.Seconds()))
logx.Infof("已将 %s 加入黑名单", clientID)
}
// logRequest 记录请求日志
func (m *SecurityMiddleware) logRequest(ctx context.Context, r *http.Request, clientID string) {
logx.WithContext(ctx).Infof("安全中间件 - 客户端: %s, 方法: %s, 路径: %s, IP: %s",
clientID, r.Method, r.URL.Path, m.getClientIP(r))
}