package middleware import ( "fmt" "strings" "sync" "time" "tyapi-server/internal/config" "tyapi-server/internal/shared/interfaces" "github.com/gin-gonic/gin" "golang.org/x/time/rate" ) // RateLimitMiddleware 限流中间件 type RateLimitMiddleware struct { config *config.Config response interfaces.ResponseBuilder limiters map[string]*rate.Limiter mutex sync.RWMutex } // NewRateLimitMiddleware 创建限流中间件 func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware { return &RateLimitMiddleware{ config: cfg, response: response, limiters: make(map[string]*rate.Limiter), } } // GetName 返回中间件名称 func (m *RateLimitMiddleware) GetName() string { return "ratelimit" } // GetPriority 返回中间件优先级 func (m *RateLimitMiddleware) GetPriority() int { return 90 // 高优先级 } // Handle 返回中间件处理函数 func (m *RateLimitMiddleware) Handle() gin.HandlerFunc { return func(c *gin.Context) { // 检查是否在排除域名中 host := c.Request.Host // 移除端口部分 if idx := strings.Index(host, ":"); idx != -1 { host = host[:idx] } if m.isExcludedDomain(host) { c.Next() return } // 检查是否在排除路径中 if m.isExcludedPath(c.Request.URL.Path) { c.Next() return } // 获取客户端标识(IP地址) clientID := m.getClientID(c) // 获取或创建限流器 limiter := m.getLimiter(clientID) // 检查是否允许请求 if !limiter.Allow() { // 添加限流头部信息 c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests)) c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String()) c.Header("Retry-After", "60") // 使用统一的响应格式 m.response.TooManyRequests(c, "请求过于频繁,请稍后再试") c.Abort() return } // 添加限流头部信息 c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests)) c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String()) c.Next() } } // IsGlobal 是否为全局中间件 func (m *RateLimitMiddleware) IsGlobal() bool { return true } // getClientID 获取客户端标识 func (m *RateLimitMiddleware) getClientID(c *gin.Context) string { // 优先使用X-Forwarded-For头部 if xff := c.GetHeader("X-Forwarded-For"); xff != "" { return xff } // 使用X-Real-IP头部 if xri := c.GetHeader("X-Real-IP"); xri != "" { return xri } // 使用RemoteAddr return c.ClientIP() } // getLimiter 获取或创建限流器 func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter { m.mutex.RLock() limiter, exists := m.limiters[clientID] m.mutex.RUnlock() if exists { return limiter } m.mutex.Lock() defer m.mutex.Unlock() // 双重检查 if limiter, exists := m.limiters[clientID]; exists { return limiter } // 创建新的限流器 // rate.Every计算每个请求之间的间隔 rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests)) limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst) m.limiters[clientID] = limiter // 启动清理协程(仅第一次创建时) if len(m.limiters) == 1 { go m.cleanupRoutine() } return limiter } // cleanupRoutine 定期清理不活跃的限流器 func (m *RateLimitMiddleware) cleanupRoutine() { ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次 defer ticker.Stop() for { select { case <-ticker.C: m.cleanup() } } } // cleanup 清理不活跃的限流器 func (m *RateLimitMiddleware) cleanup() { m.mutex.Lock() defer m.mutex.Unlock() now := time.Now() for clientID, limiter := range m.limiters { // 如果限流器在过去1小时内没有被使用,则删除它 if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour { delete(m.limiters, clientID) } } } // isExcludedDomain 检查域名是否在排除列表中 func (m *RateLimitMiddleware) isExcludedDomain(host string) bool { for _, excludeDomain := range m.config.RateLimit.ExcludeDomains { // 支持通配符匹配 if strings.HasPrefix(excludeDomain, "*") { // 后缀匹配,如 "*.api.example.com" 匹配 "api.example.com" if strings.HasSuffix(host, excludeDomain[1:]) { return true } } else if strings.HasSuffix(excludeDomain, "*") { // 前缀匹配,如 "api.*" 匹配 "api.example.com" if strings.HasPrefix(host, excludeDomain[:len(excludeDomain)-1]) { return true } } else { // 精确匹配 if host == excludeDomain { return true } } } return false } // isExcludedPath 检查路径是否在排除列表中 func (m *RateLimitMiddleware) isExcludedPath(path string) bool { for _, excludePath := range m.config.RateLimit.ExcludePaths { // 支持多种匹配模式 if strings.HasPrefix(excludePath, "*") { // 前缀匹配,如 "*api_name" 匹配 "/api/v1/any_api_name" if strings.Contains(path, excludePath[1:]) { return true } } else if strings.HasSuffix(excludePath, "*") { // 后缀匹配,如 "/api/v1/*" 匹配 "/api/v1/any_api_name" if strings.HasPrefix(path, excludePath[:len(excludePath)-1]) { return true } } else if strings.Contains(excludePath, "*") { // 中间通配符匹配,如 "/api/v1/*api_name" 匹配 "/api/v1/any_api_name" parts := strings.Split(excludePath, "*") if len(parts) == 2 { prefix := parts[0] suffix := parts[1] if strings.HasPrefix(path, prefix) && strings.HasSuffix(path, suffix) { return true } } } else { // 精确匹配 if path == excludePath { return true } } } return false } // GetStats 获取限流统计 func (m *RateLimitMiddleware) GetStats() map[string]interface{} { m.mutex.RLock() defer m.mutex.RUnlock() return map[string]interface{}{ "active_limiters": len(m.limiters), "rate_limit": map[string]interface{}{ "requests": m.config.RateLimit.Requests, "window": m.config.RateLimit.Window, "burst": m.config.RateLimit.Burst, }, } }