342 lines
9.4 KiB
Go
342 lines
9.4 KiB
Go
|
|
package security
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"fmt"
|
|||
|
|
"net/http"
|
|||
|
|
"strconv"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
"tyc-server/app/main/api/internal/config"
|
|||
|
|
|
|||
|
|
"github.com/zeromicro/go-zero/core/logx"
|
|||
|
|
"github.com/zeromicro/go-zero/core/stores/redis"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// SecurityMiddleware 安全防护中间件
|
|||
|
|
type SecurityMiddleware struct {
|
|||
|
|
config *config.SecurityConfig
|
|||
|
|
redis *redis.Redis
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewSecurityMiddleware 创建安全中间件
|
|||
|
|
func NewSecurityMiddleware(config *config.SecurityConfig, redis *redis.Redis) *SecurityMiddleware {
|
|||
|
|
return &SecurityMiddleware{
|
|||
|
|
config: config,
|
|||
|
|
redis: redis,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Handle 处理请求
|
|||
|
|
func (m *SecurityMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
|||
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|||
|
|
ctx := r.Context()
|
|||
|
|
|
|||
|
|
// 1. 获取客户端标识
|
|||
|
|
clientID := m.getClientID(r)
|
|||
|
|
|
|||
|
|
// 2. IP黑名单检查
|
|||
|
|
if m.config.IPBlacklist.Enabled {
|
|||
|
|
if m.isIPBlacklisted(r) {
|
|||
|
|
logx.WithContext(ctx).Errorf("IP被拉黑: %s", m.getClientIP(r))
|
|||
|
|
http.Error(w, "访问被拒绝", http.StatusForbidden)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 3. 用户黑名单检查
|
|||
|
|
if m.config.UserBlacklist.Enabled {
|
|||
|
|
if m.isUserBlacklisted(ctx, r) {
|
|||
|
|
logx.WithContext(ctx).Errorf("用户被拉黑: %s", clientID)
|
|||
|
|
http.Error(w, "访问被拒绝", http.StatusForbidden)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 4. 短时并发攻击检测
|
|||
|
|
if !m.checkBurstAttack(ctx, clientID, r) {
|
|||
|
|
logx.WithContext(ctx).Errorf("检测到并发攻击: %s", clientID)
|
|||
|
|
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 5. 频率限制检查
|
|||
|
|
if m.config.RateLimit.Enabled {
|
|||
|
|
if !m.checkRateLimit(ctx, clientID, r) {
|
|||
|
|
logx.WithContext(ctx).Errorf("频率限制触发: %s", clientID)
|
|||
|
|
http.Error(w, "请求过于频繁,请稍后再试", http.StatusTooManyRequests)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 6. 异常检测
|
|||
|
|
if m.config.AnomalyDetection.Enabled {
|
|||
|
|
if m.detectAnomaly(ctx, r) {
|
|||
|
|
logx.WithContext(ctx).Errorf("检测到异常请求: %s", clientID)
|
|||
|
|
// 记录异常但不阻止请求,用于监控
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 7. 记录请求日志
|
|||
|
|
m.logRequest(ctx, r, clientID)
|
|||
|
|
|
|||
|
|
// 继续处理请求
|
|||
|
|
next(w, r)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getClientID 获取客户端唯一标识
|
|||
|
|
func (m *SecurityMiddleware) getClientID(r *http.Request) string {
|
|||
|
|
// 优先使用用户ID(如果已认证)
|
|||
|
|
if userID := m.getUserIDFromContext(r.Context()); userID != "" {
|
|||
|
|
return fmt.Sprintf("user:%s", userID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 使用IP地址作为标识
|
|||
|
|
return fmt.Sprintf("ip:%s", m.getClientIP(r))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getClientIP 获取客户端真实IP
|
|||
|
|
func (m *SecurityMiddleware) getClientIP(r *http.Request) string {
|
|||
|
|
// 检查代理头
|
|||
|
|
if ip := r.Header.Get("X-Forwarded-For"); ip != "" {
|
|||
|
|
// 取第一个IP(最原始的客户端IP)
|
|||
|
|
if commaIndex := strings.Index(ip, ","); commaIndex != -1 {
|
|||
|
|
return strings.TrimSpace(ip[:commaIndex])
|
|||
|
|
}
|
|||
|
|
return strings.TrimSpace(ip)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if ip := r.Header.Get("X-Real-IP"); ip != "" {
|
|||
|
|
return strings.TrimSpace(ip)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if ip := r.Header.Get("X-Client-IP"); ip != "" {
|
|||
|
|
return strings.TrimSpace(ip)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 直接连接
|
|||
|
|
if r.RemoteAddr != "" {
|
|||
|
|
if colonIndex := strings.LastIndex(r.RemoteAddr, ":"); colonIndex != -1 {
|
|||
|
|
return r.RemoteAddr[:colonIndex]
|
|||
|
|
}
|
|||
|
|
return r.RemoteAddr
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return "unknown"
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getUserIDFromContext 从上下文中获取用户ID
|
|||
|
|
func (m *SecurityMiddleware) getUserIDFromContext(ctx context.Context) string {
|
|||
|
|
// 这里需要根据你的JWT实现来获取用户ID
|
|||
|
|
// 示例实现
|
|||
|
|
if claims, ok := ctx.Value("claims").(map[string]interface{}); ok {
|
|||
|
|
if userID, exists := claims["userId"]; exists {
|
|||
|
|
return fmt.Sprintf("%v", userID)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
return ""
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isIPBlacklisted 检查IP是否在黑名单中
|
|||
|
|
func (m *SecurityMiddleware) isIPBlacklisted(r *http.Request) bool {
|
|||
|
|
ip := m.getClientIP(r)
|
|||
|
|
key := fmt.Sprintf("security:blacklist:ip:%s", ip)
|
|||
|
|
|
|||
|
|
exists, err := m.redis.Exists(key)
|
|||
|
|
if err != nil {
|
|||
|
|
logx.Errorf("检查IP黑名单失败: %v", err)
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return exists
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isUserBlacklisted 检查用户是否在黑名单中
|
|||
|
|
func (m *SecurityMiddleware) isUserBlacklisted(ctx context.Context, r *http.Request) bool {
|
|||
|
|
userID := m.getUserIDFromContext(ctx)
|
|||
|
|
if userID == "" {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
key := fmt.Sprintf("security:blacklist:user:%s", userID)
|
|||
|
|
exists, err := m.redis.Exists(key)
|
|||
|
|
if err != nil {
|
|||
|
|
logx.Errorf("检查用户黑名单失败: %v", err)
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return exists
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// checkRateLimit 检查频率限制
|
|||
|
|
func (m *SecurityMiddleware) checkRateLimit(ctx context.Context, clientID string, r *http.Request) bool {
|
|||
|
|
key := fmt.Sprintf("security:ratelimit:%s", clientID)
|
|||
|
|
|
|||
|
|
// 获取当前计数
|
|||
|
|
current, err := m.redis.Get(key)
|
|||
|
|
if err != nil && err != redis.Nil {
|
|||
|
|
logx.Errorf("获取频率限制计数失败: %v", err)
|
|||
|
|
return true // 出错时允许请求
|
|||
|
|
}
|
|||
|
|
logx.Infof("current: %s", current)
|
|||
|
|
var count int64
|
|||
|
|
if current != "" {
|
|||
|
|
count, _ = strconv.ParseInt(current, 10, 64)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查是否超过限制
|
|||
|
|
if count >= m.config.RateLimit.MaxRequests {
|
|||
|
|
// 频率限制触发,记录触发次数
|
|||
|
|
m.recordRateLimitTrigger(clientID)
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 增加计数
|
|||
|
|
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
|||
|
|
pipe.Incr(ctx, key)
|
|||
|
|
pipe.Expire(ctx, key, time.Duration(m.config.RateLimit.WindowSize)*time.Second)
|
|||
|
|
return nil
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
logx.Errorf("更新频率限制计数失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// recordRateLimitTrigger 记录频率限制触发次数
|
|||
|
|
func (m *SecurityMiddleware) recordRateLimitTrigger(clientID string) {
|
|||
|
|
// 记录IP触发频率限制的次数
|
|||
|
|
if strings.HasPrefix(clientID, "ip:") {
|
|||
|
|
ip := strings.TrimPrefix(clientID, "ip:")
|
|||
|
|
triggerKey := fmt.Sprintf("security:ratelimit_trigger:ip:%s", ip)
|
|||
|
|
|
|||
|
|
// 增加触发次数
|
|||
|
|
err := m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
|||
|
|
pipe.Incr(context.Background(), triggerKey)
|
|||
|
|
pipe.Expire(context.Background(), triggerKey, time.Duration(m.config.RateLimit.TriggerWindow)*time.Hour) // 使用配置的时间窗口
|
|||
|
|
return nil
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
logx.Errorf("记录频率限制触发次数失败: %v", err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查是否达到黑名单阈值
|
|||
|
|
triggerCount, err := m.redis.Get(triggerKey)
|
|||
|
|
if err == nil && triggerCount != "" {
|
|||
|
|
if count, _ := strconv.ParseInt(triggerCount, 10, 64); count >= m.config.RateLimit.TriggerThreshold { // 使用配置的阈值
|
|||
|
|
logx.Infof("IP %s 触发频率限制次数过多(%d次/%d小时),自动加入黑名单", ip, count, m.config.RateLimit.TriggerWindow)
|
|||
|
|
m.addToBlacklist(clientID)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// checkBurstAttack 检查短时并发攻击
|
|||
|
|
func (m *SecurityMiddleware) checkBurstAttack(ctx context.Context, clientID string, r *http.Request) bool {
|
|||
|
|
// 检查是否启用短时并发攻击检测
|
|||
|
|
if !m.config.BurstAttack.Enabled {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 只对IP进行检查,用户级别的并发检测在业务层处理
|
|||
|
|
if !strings.HasPrefix(clientID, "ip:") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
ip := strings.TrimPrefix(clientID, "ip:")
|
|||
|
|
burstKey := fmt.Sprintf("security:burst:%s", ip)
|
|||
|
|
|
|||
|
|
// 使用Redis的原子操作检查短时并发
|
|||
|
|
// 使用配置的时间窗口
|
|||
|
|
current, err := m.redis.Get(burstKey)
|
|||
|
|
if err != nil && err != redis.Nil {
|
|||
|
|
logx.Errorf("获取短时并发计数失败: %v", err)
|
|||
|
|
return false // 出错时阻止请求
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var count int64
|
|||
|
|
if current != "" {
|
|||
|
|
count, _ = strconv.ParseInt(current, 10, 64)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 如果指定时间内并发请求超过阈值,认为是爆破攻击
|
|||
|
|
if count >= m.config.BurstAttack.MaxConcurrent { // 使用配置的并发阈值
|
|||
|
|
logx.Errorf("检测到IP %s 的爆破攻击(%d个请求/%d秒),自动加入黑名单", ip, count, m.config.BurstAttack.TimeWindow)
|
|||
|
|
m.addToBlacklist(clientID)
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 增加并发计数并设置过期时间
|
|||
|
|
err = m.redis.Pipelined(func(pipe redis.Pipeliner) error {
|
|||
|
|
pipe.Incr(ctx, burstKey)
|
|||
|
|
pipe.Expire(ctx, burstKey, time.Duration(m.config.BurstAttack.TimeWindow)*time.Second) // 使用配置的时间窗口
|
|||
|
|
return nil
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
logx.Errorf("更新短时并发计数失败: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// detectAnomaly 异常检测
|
|||
|
|
func (m *SecurityMiddleware) detectAnomaly(ctx context.Context, r *http.Request) bool {
|
|||
|
|
// 检测可疑的请求特征
|
|||
|
|
suspicious := false
|
|||
|
|
|
|||
|
|
// 1. 检查User-Agent
|
|||
|
|
userAgent := r.Header.Get("User-Agent")
|
|||
|
|
if userAgent == "" || strings.Contains(strings.ToLower(userAgent), "bot") {
|
|||
|
|
suspicious = true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 2. 检查请求频率异常
|
|||
|
|
clientID := m.getClientID(r)
|
|||
|
|
key := fmt.Sprintf("security:anomaly:%s", clientID)
|
|||
|
|
|
|||
|
|
if suspicious {
|
|||
|
|
// 记录异常
|
|||
|
|
m.redis.Incr(key)
|
|||
|
|
m.redis.Expire(key, 3600) // 1小时过期
|
|||
|
|
|
|||
|
|
// 如果异常次数过多,加入黑名单
|
|||
|
|
count, _ := m.redis.Get(key)
|
|||
|
|
if count != "" {
|
|||
|
|
if countInt, _ := strconv.ParseInt(count, 10, 64); countInt > 10 {
|
|||
|
|
m.addToBlacklist(clientID)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return suspicious
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// addToBlacklist 添加到黑名单
|
|||
|
|
func (m *SecurityMiddleware) addToBlacklist(clientID string) {
|
|||
|
|
var key string
|
|||
|
|
var expireTime time.Duration
|
|||
|
|
|
|||
|
|
if strings.HasPrefix(clientID, "user:") {
|
|||
|
|
key = fmt.Sprintf("security:blacklist:%s", clientID)
|
|||
|
|
expireTime = 24 * time.Hour // 用户黑名单24小时
|
|||
|
|
} else {
|
|||
|
|
key = fmt.Sprintf("security:blacklist:%s", clientID)
|
|||
|
|
expireTime = 1 * time.Hour // IP黑名单1小时
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
m.redis.Setex(key, "1", int(expireTime.Seconds()))
|
|||
|
|
logx.Infof("已将 %s 加入黑名单", clientID)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// logRequest 记录请求日志
|
|||
|
|
func (m *SecurityMiddleware) logRequest(ctx context.Context, r *http.Request, clientID string) {
|
|||
|
|
logx.WithContext(ctx).Infof("安全中间件 - 客户端: %s, 方法: %s, 路径: %s, IP: %s",
|
|||
|
|
clientID, r.Method, r.URL.Path, m.getClientIP(r))
|
|||
|
|
}
|