package middleware import ( "context" "fmt" "strconv" "strings" "time" "tyapi-server/internal/config" "tyapi-server/internal/shared/interfaces" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" "go.uber.org/zap" ) // DailyRateLimitConfig 每日限流配置 type DailyRateLimitConfig struct { MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数 MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数 KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀 TTL time.Duration `mapstructure:"ttl"` // 键过期时间 // 新增安全配置 EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单 IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单 EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单 IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单 EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止 BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区 EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理 MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数 } // DailyRateLimitMiddleware 每日请求限制中间件 type DailyRateLimitMiddleware struct { config *config.Config redis *redis.Client response interfaces.ResponseBuilder logger *zap.Logger limitConfig DailyRateLimitConfig } // NewDailyRateLimitMiddleware 创建每日请求限制中间件 func NewDailyRateLimitMiddleware( cfg *config.Config, redis *redis.Client, response interfaces.ResponseBuilder, logger *zap.Logger, limitConfig DailyRateLimitConfig, ) *DailyRateLimitMiddleware { // 设置默认值 if limitConfig.MaxRequestsPerDay <= 0 { limitConfig.MaxRequestsPerDay = 200 // 默认每日200次 } if limitConfig.MaxRequestsPerIP <= 0 { limitConfig.MaxRequestsPerIP = 10 // 默认每个IP每日10次 } if limitConfig.KeyPrefix == "" { limitConfig.KeyPrefix = "daily_limit" } if limitConfig.TTL == 0 { limitConfig.TTL = 24 * time.Hour // 默认24小时过期 } if limitConfig.MaxConcurrent <= 0 { limitConfig.MaxConcurrent = 5 // 默认最大并发5个 } return &DailyRateLimitMiddleware{ config: cfg, redis: redis, response: response, logger: logger, limitConfig: limitConfig, } } // GetName 返回中间件名称 func (m *DailyRateLimitMiddleware) GetName() string { return "daily_rate_limit" } // GetPriority 返回中间件优先级 func (m *DailyRateLimitMiddleware) GetPriority() int { return 85 // 在认证之后,业务处理之前 } // Handle 返回中间件处理函数 func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc { return func(c *gin.Context) { ctx := c.Request.Context() // 获取客户端标识 clientIP := m.getClientIP(c) // 1. 检查IP白名单/黑名单 if err := m.checkIPAccess(clientIP); err != nil { m.logger.Warn("IP访问被拒绝", zap.String("ip", clientIP), zap.String("request_id", c.GetString("request_id")), zap.Error(err)) m.response.Forbidden(c, "访问被拒绝") c.Abort() return } // 2. 检查User-Agent if err := m.checkUserAgent(c); err != nil { m.logger.Warn("User-Agent被阻止", zap.String("ip", clientIP), zap.String("user_agent", c.GetHeader("User-Agent")), zap.String("request_id", c.GetString("request_id")), zap.Error(err)) m.response.Forbidden(c, "访问被拒绝") c.Abort() return } // 3. 检查Referer if err := m.checkReferer(c); err != nil { m.logger.Warn("Referer检查失败", zap.String("ip", clientIP), zap.String("referer", c.GetHeader("Referer")), zap.String("request_id", c.GetString("request_id")), zap.Error(err)) m.response.Forbidden(c, "访问被拒绝") c.Abort() return } // 4. 检查并发限制 if err := m.checkConcurrentLimit(ctx, clientIP); err != nil { m.logger.Warn("并发请求超限", zap.String("ip", clientIP), zap.String("request_id", c.GetString("request_id")), zap.Error(err)) m.response.TooManyRequests(c, "系统繁忙,请稍后再试") c.Abort() return } // 5. 检查接口总请求次数限制 if err := m.checkTotalLimit(ctx); err != nil { m.logger.Warn("接口总请求次数超限", zap.String("ip", clientIP), zap.String("request_id", c.GetString("request_id")), zap.Error(err)) // 隐藏限制信息,返回通用错误 m.response.InternalError(c, "系统繁忙,请稍后再试") c.Abort() return } // 6. 检查IP限制 if err := m.checkIPLimit(ctx, clientIP); err != nil { m.logger.Warn("IP请求次数超限", zap.String("ip", clientIP), zap.String("request_id", c.GetString("request_id")), zap.Error(err)) // 隐藏限制信息,返回通用错误 m.response.InternalError(c, "系统繁忙,请稍后再试") c.Abort() return } // 7. 增加计数 m.incrementCounters(ctx, clientIP) // 8. 添加隐藏的响应头(仅用于内部监控) m.addHiddenHeaders(c, clientIP) c.Next() } } // IsGlobal 是否为全局中间件 func (m *DailyRateLimitMiddleware) IsGlobal() bool { return false // 不是全局中间件,需要手动应用到特定路由 } // checkIPAccess 检查IP访问权限 func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error { // 检查黑名单 if m.limitConfig.EnableIPBlacklist { for _, blockedIP := range m.limitConfig.IPBlacklist { if m.isIPMatch(clientIP, blockedIP) { return fmt.Errorf("IP %s 在黑名单中", clientIP) } } } // 检查白名单(如果启用) if m.limitConfig.EnableIPWhitelist { allowed := false for _, allowedIP := range m.limitConfig.IPWhitelist { if m.isIPMatch(clientIP, allowedIP) { allowed = true break } } if !allowed { return fmt.Errorf("IP %s 不在白名单中", clientIP) } } return nil } // isIPMatch 检查IP是否匹配(支持CIDR和通配符) func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool { // 简单的通配符匹配 if strings.Contains(pattern, "*") { parts := strings.Split(pattern, ".") clientParts := strings.Split(clientIP, ".") if len(parts) != len(clientParts) { return false } for i, part := range parts { if part != "*" && part != clientParts[i] { return false } } return true } // 精确匹配 return clientIP == pattern } // checkUserAgent 检查User-Agent func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error { if !m.limitConfig.EnableUserAgent { return nil } userAgent := c.GetHeader("User-Agent") if userAgent == "" { return fmt.Errorf("缺少User-Agent") } // 检查被阻止的User-Agent for _, blocked := range m.limitConfig.BlockedUserAgents { if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) { return fmt.Errorf("User-Agent被阻止: %s", blocked) } } return nil } // checkReferer 检查Referer func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error { if !m.limitConfig.EnableReferer { return nil } referer := c.GetHeader("Referer") if referer == "" { return fmt.Errorf("缺少Referer") } // 检查允许的Referer if len(m.limitConfig.AllowedReferers) > 0 { allowed := false for _, allowedRef := range m.limitConfig.AllowedReferers { if strings.Contains(referer, allowedRef) { allowed = true break } } if !allowed { return fmt.Errorf("Referer不被允许: %s", referer) } } return nil } // checkConcurrentLimit 检查并发限制 func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error { key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP) // 获取当前并发数 current, err := m.redis.Get(ctx, key).Result() if err != nil && err != redis.Nil { return fmt.Errorf("获取并发计数失败: %w", err) } currentCount := 0 if current != "" { if count, err := strconv.Atoi(current); err == nil { currentCount = count } } if currentCount >= m.limitConfig.MaxConcurrent { return fmt.Errorf("并发请求超限: %d", currentCount) } // 增加并发计数 pipe := m.redis.Pipeline() pipe.Incr(ctx, key) pipe.Expire(ctx, key, 30*time.Second) // 30秒过期 _, err = pipe.Exec(ctx) if err != nil { m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err)) } return nil } // getClientIP 获取客户端IP地址(增强版) func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string { // 检查是否为代理IP if m.limitConfig.EnableProxyCheck { // 检查常见的代理头部 proxyHeaders := []string{ "CF-Connecting-IP", // Cloudflare "X-Forwarded-For", // 标准代理头 "X-Real-IP", // Nginx "X-Client-IP", // Apache "X-Forwarded", // 其他代理 "Forwarded-For", // RFC 7239 "Forwarded", // RFC 7239 } for _, header := range proxyHeaders { if ip := c.GetHeader(header); ip != "" { // 如果X-Forwarded-For包含多个IP,取第一个 if header == "X-Forwarded-For" && strings.Contains(ip, ",") { ip = strings.TrimSpace(strings.Split(ip, ",")[0]) } return ip } } } // 回退到标准方法 if xff := c.GetHeader("X-Forwarded-For"); xff != "" { if strings.Contains(xff, ",") { return strings.TrimSpace(strings.Split(xff, ",")[0]) } return xff } if xri := c.GetHeader("X-Real-IP"); xri != "" { return xri } return c.ClientIP() } // checkTotalLimit 检查接口总请求次数限制 func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error { key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) count, err := m.getCounter(ctx, key) if err != nil { return fmt.Errorf("获取总请求计数失败: %w", err) } if count >= m.limitConfig.MaxRequestsPerDay { return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay) } return nil } // checkIPLimit 检查IP限制 func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error { key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) count, err := m.getCounter(ctx, key) if err != nil { return fmt.Errorf("获取IP计数失败: %w", err) } if count >= m.limitConfig.MaxRequestsPerIP { return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP) } return nil } // incrementCounters 增加计数器 func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) { // 增加总请求计数 totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) m.incrementCounter(ctx, totalKey) // 增加IP计数 ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) m.incrementCounter(ctx, ipKey) } // getCounter 获取计数器值 func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) { val, err := m.redis.Get(ctx, key).Result() if err != nil { if err == redis.Nil { return 0, nil // 键不存在,计数为0 } return 0, err } count, err := strconv.Atoi(val) if err != nil { return 0, fmt.Errorf("解析计数失败: %w", err) } return count, nil } // incrementCounter 增加计数器 func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) { // 使用Redis的INCR命令增加计数 pipe := m.redis.Pipeline() pipe.Incr(ctx, key) pipe.Expire(ctx, key, m.limitConfig.TTL) _, err := pipe.Exec(ctx) if err != nil { m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err)) } } // getDateKey 获取日期键(格式:2024-01-01) func (m *DailyRateLimitMiddleware) getDateKey() string { return time.Now().Format("2006-01-02") } // addHiddenHeaders 添加隐藏的响应头(仅用于内部监控) func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) { ctx := c.Request.Context() // 添加隐藏的监控头(客户端看不到) totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) totalCount, _ := m.getCounter(ctx, totalKey) ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) ipCount, _ := m.getCounter(ctx, ipKey) // 使用非标准的头部名称,避免被客户端识别 c.Header("X-System-Status", "normal") c.Header("X-Total-Count", strconv.Itoa(totalCount)) c.Header("X-IP-Count", strconv.Itoa(ipCount)) c.Header("X-Reset-Time", m.getResetTime().Format(time.RFC3339)) } // getResetTime 获取重置时间(明天0点) func (m *DailyRateLimitMiddleware) getResetTime() time.Time { now := time.Now() tomorrow := now.Add(24 * time.Hour) return time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, tomorrow.Location()) } // GetStats 获取限流统计 func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} { return map[string]interface{}{ "max_requests_per_day": m.limitConfig.MaxRequestsPerDay, "max_requests_per_ip": m.limitConfig.MaxRequestsPerIP, "max_concurrent": m.limitConfig.MaxConcurrent, "key_prefix": m.limitConfig.KeyPrefix, "ttl": m.limitConfig.TTL.String(), "security_features": map[string]interface{}{ "ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist, "ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist, "user_agent_check": m.limitConfig.EnableUserAgent, "referer_check": m.limitConfig.EnableReferer, "proxy_check": m.limitConfig.EnableProxyCheck, }, } }