package security import ( "context" "fmt" "net/http" "strconv" "strings" "time" "tyc-server/app/main/api/internal/config" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stores/redis" ) // SecurityMiddleware 安全防护中间件 type SecurityMiddleware struct { config *config.SecurityConfig redis *redis.Redis } // NewSecurityMiddleware 创建安全中间件 func NewSecurityMiddleware(config *config.SecurityConfig, redis *redis.Redis) *SecurityMiddleware { return &SecurityMiddleware{ config: config, redis: redis, } } // Handle 处理请求 func (m *SecurityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() // 1. 获取客户端标识 clientID := m.getClientID(r) // 2. IP黑名单检查 if m.config.IPBlacklist.Enabled { if m.isIPBlacklisted(r) { logx.WithContext(ctx).Errorf("IP被拉黑: %s", m.getClientIP(r)) http.Error(w, "访问被拒绝", http.StatusForbidden) return } } // 3. 用户黑名单检查 if m.config.UserBlacklist.Enabled { if m.isUserBlacklisted(ctx, r) { logx.WithContext(ctx).Errorf("用户被拉黑: %s", clientID) http.Error(w, "访问被拒绝", http.StatusForbidden) return } } // 4. 短时并发攻击检测 if !m.checkBurstAttack(ctx, clientID, r) { logx.WithContext(ctx).Errorf("检测到并发攻击: %s", clientID) http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests) return } // 5. 频率限制检查 if m.config.RateLimit.Enabled { if !m.checkRateLimit(ctx, clientID, r) { logx.WithContext(ctx).Errorf("频率限制触发: %s", clientID) http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests) return } } // 6. 异常检测 if m.config.AnomalyDetection.Enabled { if m.detectAnomaly(ctx, r) { logx.WithContext(ctx).Errorf("检测到异常请求: %s", clientID) // 记录异常但不阻止请求,用于监控 } } // 7. 记录请求日志 m.logRequest(ctx, r, clientID) // 继续处理请求 next(w, r) } } // getClientID 获取客户端唯一标识 func (m *SecurityMiddleware) getClientID(r *http.Request) string { // 优先使用用户ID(如果已认证) if userID := m.getUserIDFromContext(r.Context()); userID != "" { return fmt.Sprintf("user:%s", userID) } // 使用IP地址作为标识 return fmt.Sprintf("ip:%s", m.getClientIP(r)) } // getClientIP 获取客户端真实IP func (m *SecurityMiddleware) getClientIP(r *http.Request) string { // 检查代理头 if ip := r.Header.Get("X-Forwarded-For"); ip != "" { // 取第一个IP(最原始的客户端IP) if commaIndex := strings.Index(ip, ","); commaIndex != -1 { return strings.TrimSpace(ip[:commaIndex]) } return strings.TrimSpace(ip) } if ip := r.Header.Get("X-Real-IP"); ip != "" { return strings.TrimSpace(ip) } if ip := r.Header.Get("X-Client-IP"); ip != "" { return strings.TrimSpace(ip) } // 直接连接 if r.RemoteAddr != "" { if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 { return r.RemoteAddr[:colonIndex] } return r.RemoteAddr } return "unknown" } // getUserIDFromContext 从上下文中获取用户ID func (m *SecurityMiddleware) getUserIDFromContext(ctx context.Context) string { // 这里需要根据你的JWT实现来获取用户ID // 示例实现 if claims, ok := ctx.Value("claims").(map[string]interface{}); ok { if userID, exists := claims["userId"]; exists { return fmt.Sprintf("%v", userID) } } return "" } // isIPBlacklisted 检查IP是否在黑名单中 func (m *SecurityMiddleware) isIPBlacklisted(r *http.Request) bool { ip := m.getClientIP(r) key := fmt.Sprintf("security:blacklist:ip:%s", ip) exists, err := m.redis.Exists(key) if err != nil { logx.Errorf("检查IP黑名单失败: %v", err) return false } return exists } // isUserBlacklisted 检查用户是否在黑名单中 func (m *SecurityMiddleware) isUserBlacklisted(ctx context.Context, r *http.Request) bool { userID := m.getUserIDFromContext(ctx) if userID == "" { return false } key := fmt.Sprintf("security:blacklist:user:%s", userID) exists, err := m.redis.Exists(key) if err != nil { logx.Errorf("检查用户黑名单失败: %v", err) return false } return exists } // checkRateLimit 检查频率限制 func (m *SecurityMiddleware) checkRateLimit(ctx context.Context, clientID string, r *http.Request) bool { key := fmt.Sprintf("security:ratelimit:%s", clientID) // 获取当前计数 current, err := m.redis.Get(key) if err != nil && err != redis.Nil { logx.Errorf("获取频率限制计数失败: %v", err) return true // 出错时允许请求 } logx.Infof("current: %s", current) var count int64 if current != "" { count, _ = strconv.ParseInt(current, 10, 64) } // 检查是否超过限制 if count >= m.config.RateLimit.MaxRequests { // 频率限制触发,记录触发次数 m.recordRateLimitTrigger(clientID) return false } // 增加计数 err = m.redis.Pipelined(func(pipe redis.Pipeliner) error { pipe.Incr(ctx, key) pipe.Expire(ctx, key, time.Duration(m.config.RateLimit.WindowSize)*time.Second) return nil }) if err != nil { logx.Errorf("更新频率限制计数失败: %v", err) } return true } // recordRateLimitTrigger 记录频率限制触发次数 func (m *SecurityMiddleware) recordRateLimitTrigger(clientID string) { // 记录IP触发频率限制的次数 if strings.HasPrefix(clientID, "ip:") { ip := strings.TrimPrefix(clientID, "ip:") triggerKey := fmt.Sprintf("security:ratelimit_trigger:ip:%s", ip) // 增加触发次数 err := m.redis.Pipelined(func(pipe redis.Pipeliner) error { pipe.Incr(context.Background(), triggerKey) pipe.Expire(context.Background(), triggerKey, time.Duration(m.config.RateLimit.TriggerWindow)*time.Hour) // 使用配置的时间窗口 return nil }) if err != nil { logx.Errorf("记录频率限制触发次数失败: %v", err) return } // 检查是否达到黑名单阈值 triggerCount, err := m.redis.Get(triggerKey) if err == nil && triggerCount != "" { if count, _ := strconv.ParseInt(triggerCount, 10, 64); count >= m.config.RateLimit.TriggerThreshold { // 使用配置的阈值 logx.Infof("IP %s 触发频率限制次数过多(%d次/%d小时),自动加入黑名单", ip, count, m.config.RateLimit.TriggerWindow) m.addToBlacklist(clientID) } } } } // checkBurstAttack 检查短时并发攻击 func (m *SecurityMiddleware) checkBurstAttack(ctx context.Context, clientID string, r *http.Request) bool { // 检查是否启用短时并发攻击检测 if !m.config.BurstAttack.Enabled { return true } // 只对IP进行检查,用户级别的并发检测在业务层处理 if !strings.HasPrefix(clientID, "ip:") { return true } ip := strings.TrimPrefix(clientID, "ip:") burstKey := fmt.Sprintf("security:burst:%s", ip) // 使用Redis的原子操作检查短时并发 // 使用配置的时间窗口 current, err := m.redis.Get(burstKey) if err != nil && err != redis.Nil { logx.Errorf("获取短时并发计数失败: %v", err) return false // 出错时阻止请求 } var count int64 if current != "" { count, _ = strconv.ParseInt(current, 10, 64) } // 如果指定时间内并发请求超过阈值,认为是爆破攻击 if count >= m.config.BurstAttack.MaxConcurrent { // 使用配置的并发阈值 logx.Errorf("检测到IP %s 的爆破攻击(%d个请求/%d秒),自动加入黑名单", ip, count, m.config.BurstAttack.TimeWindow) m.addToBlacklist(clientID) return false } // 增加并发计数并设置过期时间 err = m.redis.Pipelined(func(pipe redis.Pipeliner) error { pipe.Incr(ctx, burstKey) pipe.Expire(ctx, burstKey, time.Duration(m.config.BurstAttack.TimeWindow)*time.Second) // 使用配置的时间窗口 return nil }) if err != nil { logx.Errorf("更新短时并发计数失败: %v", err) } return true } // detectAnomaly 异常检测 func (m *SecurityMiddleware) detectAnomaly(ctx context.Context, r *http.Request) bool { // 检测可疑的请求特征 suspicious := false // 1. 检查User-Agent userAgent := r.Header.Get("User-Agent") if userAgent == "" || strings.Contains(strings.ToLower(userAgent), "bot") { suspicious = true } // 2. 检查请求频率异常 clientID := m.getClientID(r) key := fmt.Sprintf("security:anomaly:%s", clientID) if suspicious { // 记录异常 m.redis.Incr(key) m.redis.Expire(key, 3600) // 1小时过期 // 如果异常次数过多,加入黑名单 count, _ := m.redis.Get(key) if count != "" { if countInt, _ := strconv.ParseInt(count, 10, 64); countInt > 10 { m.addToBlacklist(clientID) } } } return suspicious } // addToBlacklist 添加到黑名单 func (m *SecurityMiddleware) addToBlacklist(clientID string) { var key string var expireTime time.Duration if strings.HasPrefix(clientID, "user:") { key = fmt.Sprintf("security:blacklist:%s", clientID) expireTime = 24 * time.Hour // 用户黑名单24小时 } else { key = fmt.Sprintf("security:blacklist:%s", clientID) expireTime = 1 * time.Hour // IP黑名单1小时 } m.redis.Setex(key, "1", int(expireTime.Seconds())) logx.Infof("已将 %s 加入黑名单", clientID) } // logRequest 记录请求日志 func (m *SecurityMiddleware) logRequest(ctx context.Context, r *http.Request, clientID string) { logx.WithContext(ctx).Infof("安全中间件 - 客户端: %s, 方法: %s, 路径: %s, IP: %s", clientID, r.Method, r.URL.Path, m.getClientIP(r)) }