554 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			554 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package middleware
 | ||
| 
 | ||
import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	"tyapi-server/internal/config"
	"tyapi-server/internal/shared/interfaces"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)
 | ||
| 
 | ||
// DailyRateLimitConfig is the configuration of DailyRateLimitMiddleware,
// decoded through the mapstructure tags below. Unset (zero) values of the
// limit fields are replaced with defaults by NewDailyRateLimitMiddleware.
type DailyRateLimitConfig struct {
	// MaxRequestsPerDay is the daily request quota shared by all clients.
	MaxRequestsPerDay int `mapstructure:"max_requests_per_day"`
	// MaxRequestsPerIP is the daily request quota of a single client IP.
	MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"`
	// KeyPrefix is prepended to every Redis key the middleware uses.
	KeyPrefix string `mapstructure:"key_prefix"`
	// TTL is the expiry applied to the daily counter keys.
	TTL time.Duration `mapstructure:"ttl"`

	// --- Access-control settings ---

	// EnableIPWhitelist switches on IPWhitelist; when on, unlisted IPs are rejected.
	EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"`
	// IPWhitelist lists the IP patterns admitted when EnableIPWhitelist is set.
	IPWhitelist []string `mapstructure:"ip_whitelist"`
	// EnableIPBlacklist switches on IPBlacklist.
	EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"`
	// IPBlacklist lists the IP patterns rejected when EnableIPBlacklist is set.
	IPBlacklist []string `mapstructure:"ip_blacklist"`
	// EnableUserAgent switches on the User-Agent filter.
	EnableUserAgent bool `mapstructure:"enable_user_agent"`
	// BlockedUserAgents are case-insensitive User-Agent substrings to reject.
	BlockedUserAgents []string `mapstructure:"blocked_user_agents"`
	// EnableReferer switches on the Referer filter.
	EnableReferer bool `mapstructure:"enable_referer"`
	// AllowedReferers are substrings of which a Referer must contain at least one.
	AllowedReferers []string `mapstructure:"allowed_referers"`
	// EnableGeoBlock is meant to enable geographic blocking.
	// NOTE(review): neither this flag nor BlockedCountries is read by the
	// middleware code in this file.
	EnableGeoBlock bool `mapstructure:"enable_geo_block"`
	// BlockedCountries lists the countries/regions to block (see EnableGeoBlock).
	BlockedCountries []string `mapstructure:"blocked_countries"`
	// EnableProxyCheck makes client-IP resolution trust proxy/CDN headers.
	EnableProxyCheck bool `mapstructure:"enable_proxy_check"`
	// MaxConcurrent caps the per-IP concurrency counter.
	MaxConcurrent int `mapstructure:"max_concurrent"`

	// --- Exclusions ---

	// ExcludePaths are request paths (exact, or with '*' wildcards) that bypass
	// every check of the middleware.
	ExcludePaths []string `mapstructure:"exclude_paths"`
	// ExcludeDomains are request hosts (exact, "*.suffix" or "prefix.*") that
	// bypass every check of the middleware.
	ExcludeDomains []string `mapstructure:"exclude_domains"`
}
 | ||
| 
 | ||
| // DailyRateLimitMiddleware 每日请求限制中间件
 | ||
| type DailyRateLimitMiddleware struct {
 | ||
| 	config   *config.Config
 | ||
| 	redis    *redis.Client
 | ||
| 	response interfaces.ResponseBuilder
 | ||
| 	logger   *zap.Logger
 | ||
| 	limitConfig DailyRateLimitConfig
 | ||
| }
 | ||
| 
 | ||
| // NewDailyRateLimitMiddleware 创建每日请求限制中间件
 | ||
| func NewDailyRateLimitMiddleware(
 | ||
| 	cfg *config.Config, 
 | ||
| 	redis *redis.Client, 
 | ||
| 	response interfaces.ResponseBuilder,
 | ||
| 	logger *zap.Logger,
 | ||
| 	limitConfig DailyRateLimitConfig,
 | ||
| ) *DailyRateLimitMiddleware {
 | ||
| 	// 设置默认值
 | ||
| 	if limitConfig.MaxRequestsPerDay <= 0 {
 | ||
| 		limitConfig.MaxRequestsPerDay = 200 // 默认每日200次
 | ||
| 	}
 | ||
| 	if limitConfig.MaxRequestsPerIP <= 0 {
 | ||
| 		limitConfig.MaxRequestsPerIP = 10 // 默认每个IP每日10次
 | ||
| 	}
 | ||
| 	if limitConfig.KeyPrefix == "" {
 | ||
| 		limitConfig.KeyPrefix = "daily_limit"
 | ||
| 	}
 | ||
| 	if limitConfig.TTL == 0 {
 | ||
| 		limitConfig.TTL = 24 * time.Hour // 默认24小时过期
 | ||
| 	}
 | ||
| 	if limitConfig.MaxConcurrent <= 0 {
 | ||
| 		limitConfig.MaxConcurrent = 5 // 默认最大并发5个
 | ||
| 	}
 | ||
| 
 | ||
| 	return &DailyRateLimitMiddleware{
 | ||
| 		config:      cfg,
 | ||
| 		redis:       redis,
 | ||
| 		response:    response,
 | ||
| 		logger:      logger,
 | ||
| 		limitConfig: limitConfig,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *DailyRateLimitMiddleware) GetName() string {
 | ||
| 	return "daily_rate_limit"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *DailyRateLimitMiddleware) GetPriority() int {
 | ||
| 	return 85 // 在认证之后,业务处理之前
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		ctx := c.Request.Context()
 | ||
| 		
 | ||
| 		// 检查是否在排除路径中
 | ||
| 		if m.isExcludedPath(c.Request.URL.Path) {
 | ||
| 			c.Next()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 检查是否在排除域名中
 | ||
| 		host := c.Request.Host
 | ||
| 		if m.isExcludedDomain(host) {
 | ||
| 			c.Next()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 获取客户端标识
 | ||
| 		clientIP := m.getClientIP(c)
 | ||
| 		
 | ||
| 		// 1. 检查IP白名单/黑名单
 | ||
| 		if err := m.checkIPAccess(clientIP); err != nil {
 | ||
| 			m.logger.Warn("IP访问被拒绝",
 | ||
| 				zap.String("ip", clientIP),
 | ||
| 				zap.String("request_id", c.GetString("request_id")),
 | ||
| 				zap.Error(err))
 | ||
| 			m.response.Forbidden(c, "访问被拒绝")
 | ||
| 			c.Abort()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 2. 检查User-Agent
 | ||
| 		if err := m.checkUserAgent(c); err != nil {
 | ||
| 			m.logger.Warn("User-Agent被阻止",
 | ||
| 				zap.String("ip", clientIP),
 | ||
| 				zap.String("user_agent", c.GetHeader("User-Agent")),
 | ||
| 				zap.String("request_id", c.GetString("request_id")),
 | ||
| 				zap.Error(err))
 | ||
| 			m.response.Forbidden(c, "访问被拒绝")
 | ||
| 			c.Abort()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 3. 检查Referer
 | ||
| 		if err := m.checkReferer(c); err != nil {
 | ||
| 			m.logger.Warn("Referer检查失败",
 | ||
| 				zap.String("ip", clientIP),
 | ||
| 				zap.String("referer", c.GetHeader("Referer")),
 | ||
| 				zap.String("request_id", c.GetString("request_id")),
 | ||
| 				zap.Error(err))
 | ||
| 			m.response.Forbidden(c, "访问被拒绝")
 | ||
| 			c.Abort()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 4. 检查并发限制
 | ||
| 		if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
 | ||
| 			m.logger.Warn("并发请求超限",
 | ||
| 				zap.String("ip", clientIP),
 | ||
| 				zap.String("request_id", c.GetString("request_id")),
 | ||
| 				zap.Error(err))
 | ||
| 			m.response.TooManyRequests(c, "系统繁忙,请稍后再试")
 | ||
| 			c.Abort()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 5. 检查接口总请求次数限制
 | ||
| 		if err := m.checkTotalLimit(ctx); err != nil {
 | ||
| 			m.logger.Warn("接口总请求次数超限",
 | ||
| 				zap.String("ip", clientIP),
 | ||
| 				zap.String("request_id", c.GetString("request_id")),
 | ||
| 				zap.Error(err))
 | ||
| 			// 隐藏限制信息,返回通用错误
 | ||
| 			m.response.InternalError(c, "系统繁忙,请稍后再试")
 | ||
| 			c.Abort()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 6. 检查IP限制
 | ||
| 		if err := m.checkIPLimit(ctx, clientIP); err != nil {
 | ||
| 			m.logger.Warn("IP请求次数超限",
 | ||
| 				zap.String("ip", clientIP),
 | ||
| 				zap.String("request_id", c.GetString("request_id")),
 | ||
| 				zap.Error(err))
 | ||
| 			// 隐藏限制信息,返回通用错误
 | ||
| 			m.response.InternalError(c, "系统繁忙,请稍后再试")
 | ||
| 			c.Abort()
 | ||
| 			return
 | ||
| 		}
 | ||
| 		
 | ||
| 		// 7. 增加计数
 | ||
| 		m.incrementCounters(ctx, clientIP)
 | ||
| 		
 | ||
| 		// 8. 添加隐藏的响应头(仅用于内部监控)
 | ||
| 		m.addHiddenHeaders(c, clientIP)
 | ||
| 		
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // isExcludedDomain 检查域名是否在排除列表中
 | ||
| func (m *DailyRateLimitMiddleware) isExcludedDomain(host string) bool {
 | ||
| 	for _, excludeDomain := range m.limitConfig.ExcludeDomains {
 | ||
| 		// 支持通配符匹配
 | ||
| 		if strings.HasPrefix(excludeDomain, "*") {
 | ||
| 			// 后缀匹配,如 "*.api.example.com" 匹配 "api.example.com"
 | ||
| 			if strings.HasSuffix(host, excludeDomain[1:]) {
 | ||
| 				return true
 | ||
| 			}
 | ||
| 		} else if strings.HasSuffix(excludeDomain, "*") {
 | ||
| 			// 前缀匹配,如 "api.*" 匹配 "api.example.com"
 | ||
| 			if strings.HasPrefix(host, excludeDomain[:len(excludeDomain)-1]) {
 | ||
| 				return true
 | ||
| 			}
 | ||
| 		} else {
 | ||
| 			// 精确匹配
 | ||
| 			if host == excludeDomain {
 | ||
| 				return true
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| 	return false
 | ||
| }
 | ||
| 
 | ||
| // isExcludedPath 检查路径是否在排除列表中
 | ||
| func (m *DailyRateLimitMiddleware) isExcludedPath(path string) bool {
 | ||
| 	for _, excludePath := range m.limitConfig.ExcludePaths {
 | ||
| 		// 支持多种匹配模式
 | ||
| 		if strings.HasPrefix(excludePath, "*") {
 | ||
| 			// 前缀匹配,如 "*api_name" 匹配 "/api/v1/any_api_name"
 | ||
| 			if strings.Contains(path, excludePath[1:]) {
 | ||
| 				return true
 | ||
| 			}
 | ||
| 		} else if strings.HasSuffix(excludePath, "*") {
 | ||
| 			// 后缀匹配,如 "/api/v1/*" 匹配 "/api/v1/any_api_name"
 | ||
| 			if strings.HasPrefix(path, excludePath[:len(excludePath)-1]) {
 | ||
| 				return true
 | ||
| 			}
 | ||
| 		} else if strings.Contains(excludePath, "*") {
 | ||
| 			// 中间通配符匹配,如 "/api/v1/*api_name" 匹配 "/api/v1/any_api_name"
 | ||
| 			parts := strings.Split(excludePath, "*")
 | ||
| 			if len(parts) == 2 {
 | ||
| 				prefix := parts[0]
 | ||
| 				suffix := parts[1]
 | ||
| 				if strings.HasPrefix(path, prefix) && strings.HasSuffix(path, suffix) {
 | ||
| 					return true
 | ||
| 				}
 | ||
| 			}
 | ||
| 		} else {
 | ||
| 			// 精确匹配
 | ||
| 			if path == excludePath {
 | ||
| 				return true
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| 	return false
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *DailyRateLimitMiddleware) IsGlobal() bool {
 | ||
| 	return false // 不是全局中间件,需要手动应用到特定路由
 | ||
| }
 | ||
| 
 | ||
| // checkIPAccess 检查IP访问权限
 | ||
| func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
 | ||
| 	// 检查黑名单
 | ||
| 	if m.limitConfig.EnableIPBlacklist {
 | ||
| 		for _, blockedIP := range m.limitConfig.IPBlacklist {
 | ||
| 			if m.isIPMatch(clientIP, blockedIP) {
 | ||
| 				return fmt.Errorf("IP %s 在黑名单中", clientIP)
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 检查白名单(如果启用)
 | ||
| 	if m.limitConfig.EnableIPWhitelist {
 | ||
| 		allowed := false
 | ||
| 		for _, allowedIP := range m.limitConfig.IPWhitelist {
 | ||
| 			if m.isIPMatch(clientIP, allowedIP) {
 | ||
| 				allowed = true
 | ||
| 				break
 | ||
| 			}
 | ||
| 		}
 | ||
| 		if !allowed {
 | ||
| 			return fmt.Errorf("IP %s 不在白名单中", clientIP)
 | ||
| 		}
 | ||
| 	}
 | ||
| 	
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // isIPMatch 检查IP是否匹配(支持CIDR和通配符)
 | ||
| func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
 | ||
| 	// 简单的通配符匹配
 | ||
| 	if strings.Contains(pattern, "*") {
 | ||
| 		parts := strings.Split(pattern, ".")
 | ||
| 		clientParts := strings.Split(clientIP, ".")
 | ||
| 		if len(parts) != len(clientParts) {
 | ||
| 			return false
 | ||
| 		}
 | ||
| 		for i, part := range parts {
 | ||
| 			if part != "*" && part != clientParts[i] {
 | ||
| 				return false
 | ||
| 			}
 | ||
| 		}
 | ||
| 		return true
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 精确匹配
 | ||
| 	return clientIP == pattern
 | ||
| }
 | ||
| 
 | ||
| // checkUserAgent 检查User-Agent
 | ||
| func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
 | ||
| 	if !m.limitConfig.EnableUserAgent {
 | ||
| 		return nil
 | ||
| 	}
 | ||
| 	
 | ||
| 	userAgent := c.GetHeader("User-Agent")
 | ||
| 	if userAgent == "" {
 | ||
| 		return fmt.Errorf("缺少User-Agent")
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 检查被阻止的User-Agent
 | ||
| 	for _, blocked := range m.limitConfig.BlockedUserAgents {
 | ||
| 		if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) {
 | ||
| 			return fmt.Errorf("User-Agent被阻止: %s", blocked)
 | ||
| 		}
 | ||
| 	}
 | ||
| 	
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // checkReferer 检查Referer
 | ||
| func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
 | ||
| 	if !m.limitConfig.EnableReferer {
 | ||
| 		return nil
 | ||
| 	}
 | ||
| 	
 | ||
| 	referer := c.GetHeader("Referer")
 | ||
| 	if referer == "" {
 | ||
| 		return fmt.Errorf("缺少Referer")
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 检查允许的Referer
 | ||
| 	if len(m.limitConfig.AllowedReferers) > 0 {
 | ||
| 		allowed := false
 | ||
| 		for _, allowedRef := range m.limitConfig.AllowedReferers {
 | ||
| 			if strings.Contains(referer, allowedRef) {
 | ||
| 				allowed = true
 | ||
| 				break
 | ||
| 			}
 | ||
| 		}
 | ||
| 		if !allowed {
 | ||
| 			return fmt.Errorf("Referer不被允许: %s", referer)
 | ||
| 		}
 | ||
| 	}
 | ||
| 	
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // checkConcurrentLimit 检查并发限制
 | ||
| func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
 | ||
| 	key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
 | ||
| 	
 | ||
| 	// 获取当前并发数
 | ||
| 	current, err := m.redis.Get(ctx, key).Result()
 | ||
| 	if err != nil && err != redis.Nil {
 | ||
| 		return fmt.Errorf("获取并发计数失败: %w", err)
 | ||
| 	}
 | ||
| 	
 | ||
| 	currentCount := 0
 | ||
| 	if current != "" {
 | ||
| 		if count, err := strconv.Atoi(current); err == nil {
 | ||
| 			currentCount = count
 | ||
| 		}
 | ||
| 	}
 | ||
| 	
 | ||
| 	if currentCount >= m.limitConfig.MaxConcurrent {
 | ||
| 		return fmt.Errorf("并发请求超限: %d", currentCount)
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 增加并发计数
 | ||
| 	pipe := m.redis.Pipeline()
 | ||
| 	pipe.Incr(ctx, key)
 | ||
| 	pipe.Expire(ctx, key, 30*time.Second) // 30秒过期
 | ||
| 	
 | ||
| 	_, err = pipe.Exec(ctx)
 | ||
| 	if err != nil {
 | ||
| 		m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
 | ||
| 	}
 | ||
| 	
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // getClientIP 获取客户端IP地址(增强版)
 | ||
| func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
 | ||
| 	// 检查是否为代理IP
 | ||
| 	if m.limitConfig.EnableProxyCheck {
 | ||
| 		// 检查常见的代理头部
 | ||
| 		proxyHeaders := []string{
 | ||
| 			"CF-Connecting-IP",     // Cloudflare
 | ||
| 			"X-Forwarded-For",      // 标准代理头
 | ||
| 			"X-Real-IP",            // Nginx
 | ||
| 			"X-Client-IP",          // Apache
 | ||
| 			"X-Forwarded",          // 其他代理
 | ||
| 			"Forwarded-For",        // RFC 7239
 | ||
| 			"Forwarded",            // RFC 7239
 | ||
| 		}
 | ||
| 		
 | ||
| 		for _, header := range proxyHeaders {
 | ||
| 			if ip := c.GetHeader(header); ip != "" {
 | ||
| 				// 如果X-Forwarded-For包含多个IP,取第一个
 | ||
| 				if header == "X-Forwarded-For" && strings.Contains(ip, ",") {
 | ||
| 					ip = strings.TrimSpace(strings.Split(ip, ",")[0])
 | ||
| 				}
 | ||
| 				return ip
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 回退到标准方法
 | ||
| 	if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
 | ||
| 		if strings.Contains(xff, ",") {
 | ||
| 			return strings.TrimSpace(strings.Split(xff, ",")[0])
 | ||
| 		}
 | ||
| 		return xff
 | ||
| 	}
 | ||
| 	
 | ||
| 	if xri := c.GetHeader("X-Real-IP"); xri != "" {
 | ||
| 		return xri
 | ||
| 	}
 | ||
| 	
 | ||
| 	return c.ClientIP()
 | ||
| }
 | ||
| 
 | ||
| // checkTotalLimit 检查接口总请求次数限制
 | ||
| func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
 | ||
| 	key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
 | ||
| 	
 | ||
| 	count, err := m.getCounter(ctx, key)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("获取总请求计数失败: %w", err)
 | ||
| 	}
 | ||
| 	
 | ||
| 	if count >= m.limitConfig.MaxRequestsPerDay {
 | ||
| 		return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
 | ||
| 	}
 | ||
| 	
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // checkIPLimit 检查IP限制
 | ||
| func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
 | ||
| 	key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
 | ||
| 	
 | ||
| 	count, err := m.getCounter(ctx, key)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("获取IP计数失败: %w", err)
 | ||
| 	}
 | ||
| 	
 | ||
| 	if count >= m.limitConfig.MaxRequestsPerIP {
 | ||
| 		return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
 | ||
| 	}
 | ||
| 	
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // incrementCounters 增加计数器
 | ||
| func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) {
 | ||
| 	// 增加总请求计数
 | ||
| 	totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
 | ||
| 	m.incrementCounter(ctx, totalKey)
 | ||
| 	
 | ||
| 	// 增加IP计数
 | ||
| 	ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
 | ||
| 	m.incrementCounter(ctx, ipKey)
 | ||
| }
 | ||
| 
 | ||
| // getCounter 获取计数器值
 | ||
| func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) {
 | ||
| 	val, err := m.redis.Get(ctx, key).Result()
 | ||
| 	if err != nil {
 | ||
| 		if err == redis.Nil {
 | ||
| 			return 0, nil // 键不存在,计数为0
 | ||
| 		}
 | ||
| 		return 0, err
 | ||
| 	}
 | ||
| 	
 | ||
| 	count, err := strconv.Atoi(val)
 | ||
| 	if err != nil {
 | ||
| 		return 0, fmt.Errorf("解析计数失败: %w", err)
 | ||
| 	}
 | ||
| 	
 | ||
| 	return count, nil
 | ||
| }
 | ||
| 
 | ||
| // incrementCounter 增加计数器
 | ||
| func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) {
 | ||
| 	// 使用Redis的INCR命令增加计数
 | ||
| 	pipe := m.redis.Pipeline()
 | ||
| 	pipe.Incr(ctx, key)
 | ||
| 	pipe.Expire(ctx, key, m.limitConfig.TTL)
 | ||
| 	
 | ||
| 	_, err := pipe.Exec(ctx)
 | ||
| 	if err != nil {
 | ||
| 		m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err))
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // getDateKey 获取日期键(格式:2024-01-01)
 | ||
| func (m *DailyRateLimitMiddleware) getDateKey() string {
 | ||
| 	return time.Now().Format("2006-01-02")
 | ||
| }
 | ||
| 
 | ||
| // addHiddenHeaders 添加隐藏的响应头(仅用于内部监控)
 | ||
| func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
 | ||
| 	ctx := c.Request.Context()
 | ||
| 	
 | ||
| 	// 添加隐藏的监控头(客户端看不到)
 | ||
| 	totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
 | ||
| 	totalCount, _ := m.getCounter(ctx, totalKey)
 | ||
| 	
 | ||
| 	ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
 | ||
| 	ipCount, _ := m.getCounter(ctx, ipKey)
 | ||
| 	
 | ||
| 	// 使用非标准的头部名称,避免被客户端识别
 | ||
| 	c.Header("X-System-Status", "normal")
 | ||
| 	c.Header("X-Total-Count", strconv.Itoa(totalCount))
 | ||
| 	c.Header("X-IP-Count", strconv.Itoa(ipCount))
 | ||
| 	c.Header("X-Reset-Time", m.getResetTime().Format(time.RFC3339))
 | ||
| }
 | ||
| 
 | ||
| // getResetTime 获取重置时间(明天0点)
 | ||
| func (m *DailyRateLimitMiddleware) getResetTime() time.Time {
 | ||
| 	now := time.Now()
 | ||
| 	tomorrow := now.Add(24 * time.Hour)
 | ||
| 	return time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, tomorrow.Location())
 | ||
| }
 | ||
| 
 | ||
| // GetStats 获取限流统计
 | ||
| func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} {
 | ||
| 	return map[string]interface{}{
 | ||
| 		"max_requests_per_day": m.limitConfig.MaxRequestsPerDay,
 | ||
| 		"max_requests_per_ip":  m.limitConfig.MaxRequestsPerIP,
 | ||
| 		"max_concurrent":       m.limitConfig.MaxConcurrent,
 | ||
| 		"key_prefix":           m.limitConfig.KeyPrefix,
 | ||
| 		"ttl":                  m.limitConfig.TTL.String(),
 | ||
| 		"security_features": map[string]interface{}{
 | ||
| 			"ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist,
 | ||
| 			"ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist,
 | ||
| 			"user_agent_check":     m.limitConfig.EnableUserAgent,
 | ||
| 			"referer_check":        m.limitConfig.EnableReferer,
 | ||
| 			"proxy_check":          m.limitConfig.EnableProxyCheck,
 | ||
| 		},
 | ||
| 	}
 | ||
| }
 |