This commit is contained in:
Mrx
2026-02-27 14:49:29 +08:00
parent f17e22f4c8
commit d12529307b
16 changed files with 633 additions and 95 deletions

View File

@@ -20,39 +20,39 @@ type DailyRateLimitConfig struct {
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
// 新增安全配置
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
// 路径排除配置
ExcludePaths []string `mapstructure:"exclude_paths"` // 排除频率限制的路径
ExcludePaths []string `mapstructure:"exclude_paths"` // 排除频率限制的路径
// 域名排除配置
ExcludeDomains []string `mapstructure:"exclude_domains"` // 排除频率限制的域名
ExcludeDomains []string `mapstructure:"exclude_domains"` // 排除频率限制的域名
}
// DailyRateLimitMiddleware 每日请求限制中间件
type DailyRateLimitMiddleware struct {
config *config.Config
redis *redis.Client
response interfaces.ResponseBuilder
logger *zap.Logger
config *config.Config
redis *redis.Client
response interfaces.ResponseBuilder
logger *zap.Logger
limitConfig DailyRateLimitConfig
}
// NewDailyRateLimitMiddleware 创建每日请求限制中间件
func NewDailyRateLimitMiddleware(
cfg *config.Config,
redis *redis.Client,
cfg *config.Config,
redis *redis.Client,
response interfaces.ResponseBuilder,
logger *zap.Logger,
limitConfig DailyRateLimitConfig,
@@ -97,23 +97,23 @@ func (m *DailyRateLimitMiddleware) GetPriority() int {
func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()
// 检查是否在排除路径中
if m.isExcludedPath(c.Request.URL.Path) {
c.Next()
return
}
// 检查是否在排除域名中
host := c.Request.Host
if m.isExcludedDomain(host) {
c.Next()
return
}
// 获取客户端标识
clientIP := m.getClientIP(c)
// 1. 检查IP白名单/黑名单
if err := m.checkIPAccess(clientIP); err != nil {
m.logger.Warn("IP访问被拒绝",
@@ -124,7 +124,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 2. 检查User-Agent
if err := m.checkUserAgent(c); err != nil {
m.logger.Warn("User-Agent被阻止",
@@ -136,7 +136,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 3. 检查Referer
if err := m.checkReferer(c); err != nil {
m.logger.Warn("Referer检查失败",
@@ -148,7 +148,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 4. 检查并发限制
if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
m.logger.Warn("并发请求超限",
@@ -159,7 +159,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 5. 检查接口总请求次数限制
if err := m.checkTotalLimit(ctx); err != nil {
m.logger.Warn("接口总请求次数超限",
@@ -171,7 +171,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 6. 检查IP限制
if err := m.checkIPLimit(ctx, clientIP); err != nil {
m.logger.Warn("IP请求次数超限",
@@ -183,13 +183,13 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 7. 增加计数
m.incrementCounters(ctx, clientIP)
// 8. 添加隐藏的响应头(仅用于内部监控)
m.addHiddenHeaders(c, clientIP)
c.Next()
}
}
@@ -267,7 +267,7 @@ func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
}
}
}
// 检查白名单(如果启用)
if m.limitConfig.EnableIPWhitelist {
allowed := false
@@ -281,7 +281,7 @@ func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
return fmt.Errorf("IP %s 不在白名单中", clientIP)
}
}
return nil
}
@@ -301,7 +301,7 @@ func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
}
return true
}
// 精确匹配
return clientIP == pattern
}
@@ -311,19 +311,19 @@ func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
if !m.limitConfig.EnableUserAgent {
return nil
}
userAgent := c.GetHeader("User-Agent")
if userAgent == "" {
return fmt.Errorf("缺少User-Agent")
}
// 检查被阻止的User-Agent
for _, blocked := range m.limitConfig.BlockedUserAgents {
if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) {
return fmt.Errorf("User-Agent被阻止: %s", blocked)
}
}
return nil
}
@@ -332,12 +332,12 @@ func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
if !m.limitConfig.EnableReferer {
return nil
}
referer := c.GetHeader("Referer")
if referer == "" {
return fmt.Errorf("缺少Referer")
}
// 检查允许的Referer
if len(m.limitConfig.AllowedReferers) > 0 {
allowed := false
@@ -351,41 +351,41 @@ func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
return fmt.Errorf("Referer不被允许: %s", referer)
}
}
return nil
}
// checkConcurrentLimit 检查并发限制
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
// 获取当前并发数
current, err := m.redis.Get(ctx, key).Result()
if err != nil && err != redis.Nil {
return fmt.Errorf("获取并发计数失败: %w", err)
}
currentCount := 0
if current != "" {
if count, err := strconv.Atoi(current); err == nil {
currentCount = count
}
}
if currentCount >= m.limitConfig.MaxConcurrent {
return fmt.Errorf("并发请求超限: %d", currentCount)
}
// 增加并发计数
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, 30*time.Second) // 30秒过期
_, err = pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
}
return nil
}
@@ -395,15 +395,15 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
if m.limitConfig.EnableProxyCheck {
// 检查常见的代理头部
proxyHeaders := []string{
"CF-Connecting-IP", // Cloudflare
"X-Forwarded-For", // 标准代理头
"X-Real-IP", // Nginx
"X-Client-IP", // Apache
"X-Forwarded", // 其他代理
"Forwarded-For", // RFC 7239
"Forwarded", // RFC 7239
"CF-Connecting-IP", // Cloudflare
"X-Forwarded-For", // 标准代理头
"X-Real-IP", // Nginx
"X-Client-IP", // Apache
"X-Forwarded", // 其他代理
"Forwarded-For", // RFC 7239
"Forwarded", // RFC 7239
}
for _, header := range proxyHeaders {
if ip := c.GetHeader(header); ip != "" {
// 如果X-Forwarded-For包含多个IP取第一个
@@ -414,7 +414,7 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
}
}
}
// 回退到标准方法
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
if strings.Contains(xff, ",") {
@@ -422,43 +422,43 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
}
return xff
}
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// checkTotalLimit 检查接口总请求次数限制
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取总请求计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerDay {
return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
}
return nil
}
// checkIPLimit 检查IP限制
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取IP计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerIP {
return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
}
return nil
}
@@ -467,7 +467,7 @@ func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, client
// 增加总请求计数
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
m.incrementCounter(ctx, totalKey)
// 增加IP计数
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
m.incrementCounter(ctx, ipKey)
@@ -482,12 +482,12 @@ func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (
}
return 0, err
}
count, err := strconv.Atoi(val)
if err != nil {
return 0, fmt.Errorf("解析计数失败: %w", err)
}
return count, nil
}
@@ -497,7 +497,7 @@ func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key str
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, m.limitConfig.TTL)
_, err := pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err))
@@ -512,14 +512,14 @@ func (m *DailyRateLimitMiddleware) getDateKey() string {
// addHiddenHeaders 添加隐藏的响应头(仅用于内部监控)
func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
ctx := c.Request.Context()
// 添加隐藏的监控头(客户端看不到)
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
totalCount, _ := m.getCounter(ctx, totalKey)
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
ipCount, _ := m.getCounter(ctx, ipKey)
// 使用非标准的头部名称,避免被客户端识别
c.Header("X-System-Status", "normal")
c.Header("X-Total-Count", strconv.Itoa(totalCount))