package middleware import ( "fmt" "sync" "time" "tyapi-server/internal/config" "tyapi-server/internal/shared/interfaces" "github.com/gin-gonic/gin" "golang.org/x/time/rate" ) // RateLimitMiddleware 限流中间件 type RateLimitMiddleware struct { config *config.Config response interfaces.ResponseBuilder limiters map[string]*rate.Limiter mutex sync.RWMutex } // NewRateLimitMiddleware 创建限流中间件 func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware { return &RateLimitMiddleware{ config: cfg, response: response, limiters: make(map[string]*rate.Limiter), } } // GetName 返回中间件名称 func (m *RateLimitMiddleware) GetName() string { return "ratelimit" } // GetPriority 返回中间件优先级 func (m *RateLimitMiddleware) GetPriority() int { return 90 // 高优先级 } // Handle 返回中间件处理函数 func (m *RateLimitMiddleware) Handle() gin.HandlerFunc { return func(c *gin.Context) { // 获取客户端标识(IP地址) clientID := m.getClientID(c) // 获取或创建限流器 limiter := m.getLimiter(clientID) // 检查是否允许请求 if !limiter.Allow() { // 添加限流头部信息 c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests)) c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String()) c.Header("Retry-After", "60") // 使用统一的响应格式 m.response.TooManyRequests(c, "请求过于频繁,请稍后再试") c.Abort() return } // 添加限流头部信息 c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests)) c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String()) c.Next() } } // IsGlobal 是否为全局中间件 func (m *RateLimitMiddleware) IsGlobal() bool { return true } // getClientID 获取客户端标识 func (m *RateLimitMiddleware) getClientID(c *gin.Context) string { // 优先使用X-Forwarded-For头部 if xff := c.GetHeader("X-Forwarded-For"); xff != "" { return xff } // 使用X-Real-IP头部 if xri := c.GetHeader("X-Real-IP"); xri != "" { return xri } // 使用RemoteAddr return c.ClientIP() } // getLimiter 获取或创建限流器 func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter { m.mutex.RLock() limiter, exists := m.limiters[clientID] m.mutex.RUnlock() if exists { return limiter } m.mutex.Lock() defer m.mutex.Unlock() // 双重检查 if limiter, exists := m.limiters[clientID]; exists { return limiter } // 创建新的限流器 // rate.Every计算每个请求之间的间隔 rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests)) limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst) m.limiters[clientID] = limiter // 启动清理协程(仅第一次创建时) if len(m.limiters) == 1 { go m.cleanupRoutine() } return limiter } // cleanupRoutine 定期清理不活跃的限流器 func (m *RateLimitMiddleware) cleanupRoutine() { ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次 defer ticker.Stop() for { select { case <-ticker.C: m.cleanup() } } } // cleanup 清理不活跃的限流器 func (m *RateLimitMiddleware) cleanup() { m.mutex.Lock() defer m.mutex.Unlock() now := time.Now() for clientID, limiter := range m.limiters { // 如果限流器在过去1小时内没有被使用,则删除它 if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour { delete(m.limiters, clientID) } } } // GetStats 获取限流统计 func (m *RateLimitMiddleware) GetStats() map[string]interface{} { m.mutex.RLock() defer m.mutex.RUnlock() return map[string]interface{}{ "active_limiters": len(m.limiters), "rate_limit": map[string]interface{}{ "requests": m.config.RateLimit.Requests, "window": m.config.RateLimit.Window, "burst": m.config.RateLimit.Burst, }, } }