| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | package middleware | 
					
						
							|  |  |  |  | 
 | 
					
						
							import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	"tyapi-server/internal/config"
	"tyapi-server/internal/shared/interfaces"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)
					
						
							|  |  |  |  | 
 | 
					
						
							// DailyRateLimitConfig holds the settings of the daily rate-limit middleware.
// All fields are populated through their mapstructure tags.
type DailyRateLimitConfig struct {
	// MaxRequestsPerDay is the daily request budget shared by all clients.
	// NewDailyRateLimitMiddleware replaces values <= 0 with 200.
	MaxRequestsPerDay int `mapstructure:"max_requests_per_day"`

	// MaxRequestsPerIP is the daily request budget of a single client IP.
	// NewDailyRateLimitMiddleware replaces values <= 0 with 10.
	MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"`

	// KeyPrefix is prepended to every Redis key built by the middleware.
	// An empty prefix is replaced with "daily_limit".
	KeyPrefix string `mapstructure:"key_prefix"`

	// TTL is the expiry set on the daily counter keys each time they are
	// incremented. An unset (zero) TTL is replaced with 24 hours.
	TTL time.Duration `mapstructure:"ttl"`

	// Security-related settings.

	// EnableIPWhitelist turns on the whitelist check in checkIPAccess.
	EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"`

	// IPWhitelist lists the IP patterns (see isIPMatch) that are allowed
	// when the whitelist is enabled.
	IPWhitelist []string `mapstructure:"ip_whitelist"`

	// EnableIPBlacklist turns on the blacklist check in checkIPAccess.
	EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"`

	// IPBlacklist lists the IP patterns (see isIPMatch) that are rejected
	// when the blacklist is enabled.
	IPBlacklist []string `mapstructure:"ip_blacklist"`

	// EnableUserAgent turns on the User-Agent check.
	EnableUserAgent bool `mapstructure:"enable_user_agent"`

	// BlockedUserAgents lists substrings that, when found in the
	// User-Agent header (case-insensitively), cause a rejection.
	BlockedUserAgents []string `mapstructure:"blocked_user_agents"`

	// EnableReferer turns on the Referer check.
	EnableReferer bool `mapstructure:"enable_referer"`

	// AllowedReferers lists substrings of which at least one must occur in
	// the Referer header; an empty list accepts any non-empty Referer.
	AllowedReferers []string `mapstructure:"allowed_referers"`

	// EnableGeoBlock is the switch for geographic blocking.
	// NOTE(review): not referenced by the code visible in this part of the file.
	EnableGeoBlock bool `mapstructure:"enable_geo_block"`

	// BlockedCountries lists the blocked countries/regions.
	// NOTE(review): not referenced by the code visible in this part of the file.
	BlockedCountries []string `mapstructure:"blocked_countries"`

	// EnableProxyCheck makes getClientIP consult the extended list of proxy
	// headers (CF-Connecting-IP, X-Client-IP, ...) when resolving the client IP.
	EnableProxyCheck bool `mapstructure:"enable_proxy_check"`

	// MaxConcurrent is the per-IP limit enforced by checkConcurrentLimit.
	// NewDailyRateLimitMiddleware replaces values <= 0 with 5.
	MaxConcurrent int `mapstructure:"max_concurrent"`

	// Path exclusion settings.

	// ExcludePaths lists request-path patterns that skip the middleware.
	ExcludePaths []string `mapstructure:"exclude_paths"`

	// Domain exclusion settings.

	// ExcludeDomains lists Host patterns that skip the middleware.
	ExcludeDomains []string `mapstructure:"exclude_domains"`
}
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // DailyRateLimitMiddleware 每日请求限制中间件 | 
					
						
							|  |  |  |  | type DailyRateLimitMiddleware struct { | 
					
						
							|  |  |  |  | 	config   *config.Config | 
					
						
							|  |  |  |  | 	redis    *redis.Client | 
					
						
							|  |  |  |  | 	response interfaces.ResponseBuilder | 
					
						
							|  |  |  |  | 	logger   *zap.Logger | 
					
						
							|  |  |  |  | 	limitConfig DailyRateLimitConfig | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // NewDailyRateLimitMiddleware 创建每日请求限制中间件 | 
					
						
							|  |  |  |  | func NewDailyRateLimitMiddleware( | 
					
						
							|  |  |  |  | 	cfg *config.Config,  | 
					
						
							|  |  |  |  | 	redis *redis.Client,  | 
					
						
							|  |  |  |  | 	response interfaces.ResponseBuilder, | 
					
						
							|  |  |  |  | 	logger *zap.Logger, | 
					
						
							|  |  |  |  | 	limitConfig DailyRateLimitConfig, | 
					
						
							|  |  |  |  | ) *DailyRateLimitMiddleware { | 
					
						
							|  |  |  |  | 	// 设置默认值 | 
					
						
							|  |  |  |  | 	if limitConfig.MaxRequestsPerDay <= 0 { | 
					
						
							|  |  |  |  | 		limitConfig.MaxRequestsPerDay = 200 // 默认每日200次 | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	if limitConfig.MaxRequestsPerIP <= 0 { | 
					
						
							|  |  |  |  | 		limitConfig.MaxRequestsPerIP = 10 // 默认每个IP每日10次 | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	if limitConfig.KeyPrefix == "" { | 
					
						
							|  |  |  |  | 		limitConfig.KeyPrefix = "daily_limit" | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	if limitConfig.TTL == 0 { | 
					
						
							|  |  |  |  | 		limitConfig.TTL = 24 * time.Hour // 默认24小时过期 | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	if limitConfig.MaxConcurrent <= 0 { | 
					
						
							|  |  |  |  | 		limitConfig.MaxConcurrent = 5 // 默认最大并发5个 | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	return &DailyRateLimitMiddleware{ | 
					
						
							|  |  |  |  | 		config:      cfg, | 
					
						
							|  |  |  |  | 		redis:       redis, | 
					
						
							|  |  |  |  | 		response:    response, | 
					
						
							|  |  |  |  | 		logger:      logger, | 
					
						
							|  |  |  |  | 		limitConfig: limitConfig, | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetName 返回中间件名称 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) GetName() string { | 
					
						
							|  |  |  |  | 	return "daily_rate_limit" | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetPriority 返回中间件优先级 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) GetPriority() int { | 
					
						
							|  |  |  |  | 	return 85 // 在认证之后,业务处理之前 | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // Handle 返回中间件处理函数 | 
					
						
							| 
									
										
										
										
											2025-08-10 15:19:10 +08:00
										 |  |  |  | func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc { | 
					
						
							| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  |  | 		ctx := c.Request.Context() | 
					
						
							|  |  |  |  | 		 | 
					
						
							| 
									
										
										
										
											2025-09-12 01:15:09 +08:00
										 |  |  |  | 		// 检查是否在排除路径中 | 
					
						
							|  |  |  |  | 		if m.isExcludedPath(c.Request.URL.Path) { | 
					
						
							|  |  |  |  | 			c.Next() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 检查是否在排除域名中 | 
					
						
							|  |  |  |  | 		host := c.Request.Host | 
					
						
							|  |  |  |  | 		if m.isExcludedDomain(host) { | 
					
						
							|  |  |  |  | 			c.Next() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | 		// 获取客户端标识 | 
					
						
							|  |  |  |  | 		clientIP := m.getClientIP(c) | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 1. 检查IP白名单/黑名单 | 
					
						
							|  |  |  |  | 		if err := m.checkIPAccess(clientIP); err != nil { | 
					
						
							|  |  |  |  | 			m.logger.Warn("IP访问被拒绝", | 
					
						
							|  |  |  |  | 				zap.String("ip", clientIP), | 
					
						
							|  |  |  |  | 				zap.String("request_id", c.GetString("request_id")), | 
					
						
							|  |  |  |  | 				zap.Error(err)) | 
					
						
							|  |  |  |  | 			m.response.Forbidden(c, "访问被拒绝") | 
					
						
							|  |  |  |  | 			c.Abort() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 2. 检查User-Agent | 
					
						
							|  |  |  |  | 		if err := m.checkUserAgent(c); err != nil { | 
					
						
							|  |  |  |  | 			m.logger.Warn("User-Agent被阻止", | 
					
						
							|  |  |  |  | 				zap.String("ip", clientIP), | 
					
						
							|  |  |  |  | 				zap.String("user_agent", c.GetHeader("User-Agent")), | 
					
						
							|  |  |  |  | 				zap.String("request_id", c.GetString("request_id")), | 
					
						
							|  |  |  |  | 				zap.Error(err)) | 
					
						
							|  |  |  |  | 			m.response.Forbidden(c, "访问被拒绝") | 
					
						
							|  |  |  |  | 			c.Abort() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 3. 检查Referer | 
					
						
							|  |  |  |  | 		if err := m.checkReferer(c); err != nil { | 
					
						
							|  |  |  |  | 			m.logger.Warn("Referer检查失败", | 
					
						
							|  |  |  |  | 				zap.String("ip", clientIP), | 
					
						
							|  |  |  |  | 				zap.String("referer", c.GetHeader("Referer")), | 
					
						
							|  |  |  |  | 				zap.String("request_id", c.GetString("request_id")), | 
					
						
							|  |  |  |  | 				zap.Error(err)) | 
					
						
							|  |  |  |  | 			m.response.Forbidden(c, "访问被拒绝") | 
					
						
							|  |  |  |  | 			c.Abort() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 4. 检查并发限制 | 
					
						
							|  |  |  |  | 		if err := m.checkConcurrentLimit(ctx, clientIP); err != nil { | 
					
						
							|  |  |  |  | 			m.logger.Warn("并发请求超限", | 
					
						
							|  |  |  |  | 				zap.String("ip", clientIP), | 
					
						
							|  |  |  |  | 				zap.String("request_id", c.GetString("request_id")), | 
					
						
							|  |  |  |  | 				zap.Error(err)) | 
					
						
							|  |  |  |  | 			m.response.TooManyRequests(c, "系统繁忙,请稍后再试") | 
					
						
							|  |  |  |  | 			c.Abort() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 5. 检查接口总请求次数限制 | 
					
						
							|  |  |  |  | 		if err := m.checkTotalLimit(ctx); err != nil { | 
					
						
							|  |  |  |  | 			m.logger.Warn("接口总请求次数超限", | 
					
						
							|  |  |  |  | 				zap.String("ip", clientIP), | 
					
						
							|  |  |  |  | 				zap.String("request_id", c.GetString("request_id")), | 
					
						
							|  |  |  |  | 				zap.Error(err)) | 
					
						
							|  |  |  |  | 			// 隐藏限制信息,返回通用错误 | 
					
						
							| 
									
										
										
										
											2025-08-10 15:19:10 +08:00
										 |  |  |  | 			m.response.InternalError(c, "系统繁忙,请稍后再试") | 
					
						
							| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | 			c.Abort() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 6. 检查IP限制 | 
					
						
							|  |  |  |  | 		if err := m.checkIPLimit(ctx, clientIP); err != nil { | 
					
						
							|  |  |  |  | 			m.logger.Warn("IP请求次数超限", | 
					
						
							|  |  |  |  | 				zap.String("ip", clientIP), | 
					
						
							|  |  |  |  | 				zap.String("request_id", c.GetString("request_id")), | 
					
						
							|  |  |  |  | 				zap.Error(err)) | 
					
						
							|  |  |  |  | 			// 隐藏限制信息,返回通用错误 | 
					
						
							| 
									
										
										
										
											2025-08-10 15:19:10 +08:00
										 |  |  |  | 			m.response.InternalError(c, "系统繁忙,请稍后再试") | 
					
						
							| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | 			c.Abort() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 7. 增加计数 | 
					
						
							|  |  |  |  | 		m.incrementCounters(ctx, clientIP) | 
					
						
							|  |  |  |  | 		 | 
					
						
							| 
									
										
										
										
											2025-08-10 15:19:10 +08:00
										 |  |  |  | 		// 8. 添加隐藏的响应头(仅用于内部监控) | 
					
						
							| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | 		m.addHiddenHeaders(c, clientIP) | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		c.Next() | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-09-12 01:15:09 +08:00
										 |  |  |  | // isExcludedDomain 检查域名是否在排除列表中 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) isExcludedDomain(host string) bool { | 
					
						
							|  |  |  |  | 	for _, excludeDomain := range m.limitConfig.ExcludeDomains { | 
					
						
							|  |  |  |  | 		// 支持通配符匹配 | 
					
						
							|  |  |  |  | 		if strings.HasPrefix(excludeDomain, "*") { | 
					
						
							|  |  |  |  | 			// 后缀匹配,如 "*.api.example.com" 匹配 "api.example.com" | 
					
						
							|  |  |  |  | 			if strings.HasSuffix(host, excludeDomain[1:]) { | 
					
						
							|  |  |  |  | 				return true | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} else if strings.HasSuffix(excludeDomain, "*") { | 
					
						
							|  |  |  |  | 			// 前缀匹配,如 "api.*" 匹配 "api.example.com" | 
					
						
							|  |  |  |  | 			if strings.HasPrefix(host, excludeDomain[:len(excludeDomain)-1]) { | 
					
						
							|  |  |  |  | 				return true | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} else { | 
					
						
							|  |  |  |  | 			// 精确匹配 | 
					
						
							|  |  |  |  | 			if host == excludeDomain { | 
					
						
							|  |  |  |  | 				return true | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	return false | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // isExcludedPath 检查路径是否在排除列表中 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) isExcludedPath(path string) bool { | 
					
						
							|  |  |  |  | 	for _, excludePath := range m.limitConfig.ExcludePaths { | 
					
						
							|  |  |  |  | 		// 支持多种匹配模式 | 
					
						
							|  |  |  |  | 		if strings.HasPrefix(excludePath, "*") { | 
					
						
							|  |  |  |  | 			// 前缀匹配,如 "*api_name" 匹配 "/api/v1/any_api_name" | 
					
						
							|  |  |  |  | 			if strings.Contains(path, excludePath[1:]) { | 
					
						
							|  |  |  |  | 				return true | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} else if strings.HasSuffix(excludePath, "*") { | 
					
						
							|  |  |  |  | 			// 后缀匹配,如 "/api/v1/*" 匹配 "/api/v1/any_api_name" | 
					
						
							|  |  |  |  | 			if strings.HasPrefix(path, excludePath[:len(excludePath)-1]) { | 
					
						
							|  |  |  |  | 				return true | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} else if strings.Contains(excludePath, "*") { | 
					
						
							|  |  |  |  | 			// 中间通配符匹配,如 "/api/v1/*api_name" 匹配 "/api/v1/any_api_name" | 
					
						
							|  |  |  |  | 			parts := strings.Split(excludePath, "*") | 
					
						
							|  |  |  |  | 			if len(parts) == 2 { | 
					
						
							|  |  |  |  | 				prefix := parts[0] | 
					
						
							|  |  |  |  | 				suffix := parts[1] | 
					
						
							|  |  |  |  | 				if strings.HasPrefix(path, prefix) && strings.HasSuffix(path, suffix) { | 
					
						
							|  |  |  |  | 					return true | 
					
						
							|  |  |  |  | 				} | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} else { | 
					
						
							|  |  |  |  | 			// 精确匹配 | 
					
						
							|  |  |  |  | 			if path == excludePath { | 
					
						
							|  |  |  |  | 				return true | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	return false | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | // IsGlobal 是否为全局中间件 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) IsGlobal() bool { | 
					
						
							|  |  |  |  | 	return false // 不是全局中间件,需要手动应用到特定路由 | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // checkIPAccess 检查IP访问权限 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error { | 
					
						
							|  |  |  |  | 	// 检查黑名单 | 
					
						
							|  |  |  |  | 	if m.limitConfig.EnableIPBlacklist { | 
					
						
							|  |  |  |  | 		for _, blockedIP := range m.limitConfig.IPBlacklist { | 
					
						
							|  |  |  |  | 			if m.isIPMatch(clientIP, blockedIP) { | 
					
						
							|  |  |  |  | 				return fmt.Errorf("IP %s 在黑名单中", clientIP) | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 检查白名单(如果启用) | 
					
						
							|  |  |  |  | 	if m.limitConfig.EnableIPWhitelist { | 
					
						
							|  |  |  |  | 		allowed := false | 
					
						
							|  |  |  |  | 		for _, allowedIP := range m.limitConfig.IPWhitelist { | 
					
						
							|  |  |  |  | 			if m.isIPMatch(clientIP, allowedIP) { | 
					
						
							|  |  |  |  | 				allowed = true | 
					
						
							|  |  |  |  | 				break | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		if !allowed { | 
					
						
							|  |  |  |  | 			return fmt.Errorf("IP %s 不在白名单中", clientIP) | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // isIPMatch 检查IP是否匹配(支持CIDR和通配符) | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool { | 
					
						
							|  |  |  |  | 	// 简单的通配符匹配 | 
					
						
							|  |  |  |  | 	if strings.Contains(pattern, "*") { | 
					
						
							|  |  |  |  | 		parts := strings.Split(pattern, ".") | 
					
						
							|  |  |  |  | 		clientParts := strings.Split(clientIP, ".") | 
					
						
							|  |  |  |  | 		if len(parts) != len(clientParts) { | 
					
						
							|  |  |  |  | 			return false | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		for i, part := range parts { | 
					
						
							|  |  |  |  | 			if part != "*" && part != clientParts[i] { | 
					
						
							|  |  |  |  | 				return false | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return true | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 精确匹配 | 
					
						
							|  |  |  |  | 	return clientIP == pattern | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // checkUserAgent 检查User-Agent | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error { | 
					
						
							|  |  |  |  | 	if !m.limitConfig.EnableUserAgent { | 
					
						
							|  |  |  |  | 		return nil | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	userAgent := c.GetHeader("User-Agent") | 
					
						
							|  |  |  |  | 	if userAgent == "" { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("缺少User-Agent") | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 检查被阻止的User-Agent | 
					
						
							|  |  |  |  | 	for _, blocked := range m.limitConfig.BlockedUserAgents { | 
					
						
							|  |  |  |  | 		if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) { | 
					
						
							|  |  |  |  | 			return fmt.Errorf("User-Agent被阻止: %s", blocked) | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // checkReferer 检查Referer | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error { | 
					
						
							|  |  |  |  | 	if !m.limitConfig.EnableReferer { | 
					
						
							|  |  |  |  | 		return nil | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	referer := c.GetHeader("Referer") | 
					
						
							|  |  |  |  | 	if referer == "" { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("缺少Referer") | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 检查允许的Referer | 
					
						
							|  |  |  |  | 	if len(m.limitConfig.AllowedReferers) > 0 { | 
					
						
							|  |  |  |  | 		allowed := false | 
					
						
							|  |  |  |  | 		for _, allowedRef := range m.limitConfig.AllowedReferers { | 
					
						
							|  |  |  |  | 			if strings.Contains(referer, allowedRef) { | 
					
						
							|  |  |  |  | 				allowed = true | 
					
						
							|  |  |  |  | 				break | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		if !allowed { | 
					
						
							|  |  |  |  | 			return fmt.Errorf("Referer不被允许: %s", referer) | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // checkConcurrentLimit 检查并发限制 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error { | 
					
						
							|  |  |  |  | 	key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP) | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 获取当前并发数 | 
					
						
							|  |  |  |  | 	current, err := m.redis.Get(ctx, key).Result() | 
					
						
							|  |  |  |  | 	if err != nil && err != redis.Nil { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("获取并发计数失败: %w", err) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	currentCount := 0 | 
					
						
							|  |  |  |  | 	if current != "" { | 
					
						
							|  |  |  |  | 		if count, err := strconv.Atoi(current); err == nil { | 
					
						
							|  |  |  |  | 			currentCount = count | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	if currentCount >= m.limitConfig.MaxConcurrent { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("并发请求超限: %d", currentCount) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 增加并发计数 | 
					
						
							|  |  |  |  | 	pipe := m.redis.Pipeline() | 
					
						
							|  |  |  |  | 	pipe.Incr(ctx, key) | 
					
						
							|  |  |  |  | 	pipe.Expire(ctx, key, 30*time.Second) // 30秒过期 | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	_, err = pipe.Exec(ctx) | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err)) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getClientIP 获取客户端IP地址(增强版) | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string { | 
					
						
							|  |  |  |  | 	// 检查是否为代理IP | 
					
						
							|  |  |  |  | 	if m.limitConfig.EnableProxyCheck { | 
					
						
							|  |  |  |  | 		// 检查常见的代理头部 | 
					
						
							|  |  |  |  | 		proxyHeaders := []string{ | 
					
						
							|  |  |  |  | 			"CF-Connecting-IP",     // Cloudflare | 
					
						
							|  |  |  |  | 			"X-Forwarded-For",      // 标准代理头 | 
					
						
							|  |  |  |  | 			"X-Real-IP",            // Nginx | 
					
						
							|  |  |  |  | 			"X-Client-IP",          // Apache | 
					
						
							|  |  |  |  | 			"X-Forwarded",          // 其他代理 | 
					
						
							|  |  |  |  | 			"Forwarded-For",        // RFC 7239 | 
					
						
							|  |  |  |  | 			"Forwarded",            // RFC 7239 | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		for _, header := range proxyHeaders { | 
					
						
							|  |  |  |  | 			if ip := c.GetHeader(header); ip != "" { | 
					
						
							|  |  |  |  | 				// 如果X-Forwarded-For包含多个IP,取第一个 | 
					
						
							|  |  |  |  | 				if header == "X-Forwarded-For" && strings.Contains(ip, ",") { | 
					
						
							|  |  |  |  | 					ip = strings.TrimSpace(strings.Split(ip, ",")[0]) | 
					
						
							|  |  |  |  | 				} | 
					
						
							|  |  |  |  | 				return ip | 
					
						
							|  |  |  |  | 			} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 回退到标准方法 | 
					
						
							|  |  |  |  | 	if xff := c.GetHeader("X-Forwarded-For"); xff != "" { | 
					
						
							|  |  |  |  | 		if strings.Contains(xff, ",") { | 
					
						
							|  |  |  |  | 			return strings.TrimSpace(strings.Split(xff, ",")[0]) | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return xff | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	if xri := c.GetHeader("X-Real-IP"); xri != "" { | 
					
						
							|  |  |  |  | 		return xri | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	return c.ClientIP() | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // checkTotalLimit 检查接口总请求次数限制 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error { | 
					
						
							|  |  |  |  | 	key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	count, err := m.getCounter(ctx, key) | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("获取总请求计数失败: %w", err) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	if count >= m.limitConfig.MaxRequestsPerDay { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // checkIPLimit 检查IP限制 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error { | 
					
						
							|  |  |  |  | 	key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	count, err := m.getCounter(ctx, key) | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("获取IP计数失败: %w", err) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	if count >= m.limitConfig.MaxRequestsPerIP { | 
					
						
							|  |  |  |  | 		return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // incrementCounters 增加计数器 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) { | 
					
						
							|  |  |  |  | 	// 增加总请求计数 | 
					
						
							|  |  |  |  | 	totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) | 
					
						
							|  |  |  |  | 	m.incrementCounter(ctx, totalKey) | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 增加IP计数 | 
					
						
							|  |  |  |  | 	ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) | 
					
						
							|  |  |  |  | 	m.incrementCounter(ctx, ipKey) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getCounter 获取计数器值 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) { | 
					
						
							|  |  |  |  | 	val, err := m.redis.Get(ctx, key).Result() | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		if err == redis.Nil { | 
					
						
							|  |  |  |  | 			return 0, nil // 键不存在,计数为0 | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return 0, err | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	count, err := strconv.Atoi(val) | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		return 0, fmt.Errorf("解析计数失败: %w", err) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	 | 
					
						
							| 
									
										
										
										
											2025-08-10 15:19:10 +08:00
										 |  |  |  | 	return count, nil | 
					
						
							| 
									
										
										
										
											2025-08-10 14:40:02 +08:00
										 |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // incrementCounter 增加计数器 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) { | 
					
						
							|  |  |  |  | 	// 使用Redis的INCR命令增加计数 | 
					
						
							|  |  |  |  | 	pipe := m.redis.Pipeline() | 
					
						
							|  |  |  |  | 	pipe.Incr(ctx, key) | 
					
						
							|  |  |  |  | 	pipe.Expire(ctx, key, m.limitConfig.TTL) | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	_, err := pipe.Exec(ctx) | 
					
						
							|  |  |  |  | 	if err != nil { | 
					
						
							|  |  |  |  | 		m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err)) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getDateKey 获取日期键(格式:2024-01-01) | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) getDateKey() string { | 
					
						
							|  |  |  |  | 	return time.Now().Format("2006-01-02") | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // addHiddenHeaders 添加隐藏的响应头(仅用于内部监控) | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) { | 
					
						
							|  |  |  |  | 	ctx := c.Request.Context() | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 添加隐藏的监控头(客户端看不到) | 
					
						
							|  |  |  |  | 	totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) | 
					
						
							|  |  |  |  | 	totalCount, _ := m.getCounter(ctx, totalKey) | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) | 
					
						
							|  |  |  |  | 	ipCount, _ := m.getCounter(ctx, ipKey) | 
					
						
							|  |  |  |  | 	 | 
					
						
							|  |  |  |  | 	// 使用非标准的头部名称,避免被客户端识别 | 
					
						
							|  |  |  |  | 	c.Header("X-System-Status", "normal") | 
					
						
							|  |  |  |  | 	c.Header("X-Total-Count", strconv.Itoa(totalCount)) | 
					
						
							|  |  |  |  | 	c.Header("X-IP-Count", strconv.Itoa(ipCount)) | 
					
						
							|  |  |  |  | 	c.Header("X-Reset-Time", m.getResetTime().Format(time.RFC3339)) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getResetTime 获取重置时间(明天0点) | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) getResetTime() time.Time { | 
					
						
							|  |  |  |  | 	now := time.Now() | 
					
						
							|  |  |  |  | 	tomorrow := now.Add(24 * time.Hour) | 
					
						
							|  |  |  |  | 	return time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, tomorrow.Location()) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetStats 获取限流统计 | 
					
						
							|  |  |  |  | func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} { | 
					
						
							|  |  |  |  | 	return map[string]interface{}{ | 
					
						
							|  |  |  |  | 		"max_requests_per_day": m.limitConfig.MaxRequestsPerDay, | 
					
						
							|  |  |  |  | 		"max_requests_per_ip":  m.limitConfig.MaxRequestsPerIP, | 
					
						
							|  |  |  |  | 		"max_concurrent":       m.limitConfig.MaxConcurrent, | 
					
						
							|  |  |  |  | 		"key_prefix":           m.limitConfig.KeyPrefix, | 
					
						
							|  |  |  |  | 		"ttl":                  m.limitConfig.TTL.String(), | 
					
						
							|  |  |  |  | 		"security_features": map[string]interface{}{ | 
					
						
							|  |  |  |  | 			"ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist, | 
					
						
							|  |  |  |  | 			"ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist, | 
					
						
							|  |  |  |  | 			"user_agent_check":     m.limitConfig.EnableUserAgent, | 
					
						
							|  |  |  |  | 			"referer_check":        m.limitConfig.EnableReferer, | 
					
						
							|  |  |  |  | 			"proxy_check":          m.limitConfig.EnableProxyCheck, | 
					
						
							|  |  |  |  | 		}, | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } |