167 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			167 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package middleware
 | ||
| 
 | ||
| import (
 | ||
| 	"fmt"
 | ||
| 	"sync"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"tyapi-server/internal/config"
 | ||
| 	"tyapi-server/internal/shared/interfaces"
 | ||
| 
 | ||
| 	"github.com/gin-gonic/gin"
 | ||
| 	"golang.org/x/time/rate"
 | ||
| )
 | ||
| 
 | ||
| // RateLimitMiddleware 限流中间件
 | ||
| type RateLimitMiddleware struct {
 | ||
| 	config   *config.Config
 | ||
| 	response interfaces.ResponseBuilder
 | ||
| 	limiters map[string]*rate.Limiter
 | ||
| 	mutex    sync.RWMutex
 | ||
| }
 | ||
| 
 | ||
| // NewRateLimitMiddleware 创建限流中间件
 | ||
| func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
 | ||
| 	return &RateLimitMiddleware{
 | ||
| 		config:   cfg,
 | ||
| 		response: response,
 | ||
| 		limiters: make(map[string]*rate.Limiter),
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *RateLimitMiddleware) GetName() string {
 | ||
| 	return "ratelimit"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *RateLimitMiddleware) GetPriority() int {
 | ||
| 	return 90 // 高优先级
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 获取客户端标识(IP地址)
 | ||
| 		clientID := m.getClientID(c)
 | ||
| 
 | ||
| 		// 获取或创建限流器
 | ||
| 		limiter := m.getLimiter(clientID)
 | ||
| 
 | ||
| 		// 检查是否允许请求
 | ||
| 		if !limiter.Allow() {
 | ||
| 			// 添加限流头部信息
 | ||
| 			c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
 | ||
| 			c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
 | ||
| 			c.Header("Retry-After", "60")
 | ||
| 
 | ||
| 			// 使用统一的响应格式
 | ||
| 			m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
 | ||
| 			c.Abort()
 | ||
| 			return
 | ||
| 		}
 | ||
| 
 | ||
| 		// 添加限流头部信息
 | ||
| 		c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
 | ||
| 		c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *RateLimitMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 | ||
| 
 | ||
| // getClientID 获取客户端标识
 | ||
| func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
 | ||
| 	// 优先使用X-Forwarded-For头部
 | ||
| 	if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
 | ||
| 		return xff
 | ||
| 	}
 | ||
| 
 | ||
| 	// 使用X-Real-IP头部
 | ||
| 	if xri := c.GetHeader("X-Real-IP"); xri != "" {
 | ||
| 		return xri
 | ||
| 	}
 | ||
| 
 | ||
| 	// 使用RemoteAddr
 | ||
| 	return c.ClientIP()
 | ||
| }
 | ||
| 
 | ||
| // getLimiter 获取或创建限流器
 | ||
| func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
 | ||
| 	m.mutex.RLock()
 | ||
| 	limiter, exists := m.limiters[clientID]
 | ||
| 	m.mutex.RUnlock()
 | ||
| 
 | ||
| 	if exists {
 | ||
| 		return limiter
 | ||
| 	}
 | ||
| 
 | ||
| 	m.mutex.Lock()
 | ||
| 	defer m.mutex.Unlock()
 | ||
| 
 | ||
| 	// 双重检查
 | ||
| 	if limiter, exists := m.limiters[clientID]; exists {
 | ||
| 		return limiter
 | ||
| 	}
 | ||
| 
 | ||
| 	// 创建新的限流器
 | ||
| 	// rate.Every计算每个请求之间的间隔
 | ||
| 	rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
 | ||
| 	limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst)
 | ||
| 
 | ||
| 	m.limiters[clientID] = limiter
 | ||
| 
 | ||
| 	// 启动清理协程(仅第一次创建时)
 | ||
| 	if len(m.limiters) == 1 {
 | ||
| 		go m.cleanupRoutine()
 | ||
| 	}
 | ||
| 
 | ||
| 	return limiter
 | ||
| }
 | ||
| 
 | ||
| // cleanupRoutine 定期清理不活跃的限流器
 | ||
| func (m *RateLimitMiddleware) cleanupRoutine() {
 | ||
| 	ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次
 | ||
| 	defer ticker.Stop()
 | ||
| 
 | ||
| 	for {
 | ||
| 		select {
 | ||
| 		case <-ticker.C:
 | ||
| 			m.cleanup()
 | ||
| 		}
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // cleanup 清理不活跃的限流器
 | ||
| func (m *RateLimitMiddleware) cleanup() {
 | ||
| 	m.mutex.Lock()
 | ||
| 	defer m.mutex.Unlock()
 | ||
| 
 | ||
| 	now := time.Now()
 | ||
| 	for clientID, limiter := range m.limiters {
 | ||
| 		// 如果限流器在过去1小时内没有被使用,则删除它
 | ||
| 		if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour {
 | ||
| 			delete(m.limiters, clientID)
 | ||
| 		}
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetStats 获取限流统计
 | ||
| func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
 | ||
| 	m.mutex.RLock()
 | ||
| 	defer m.mutex.RUnlock()
 | ||
| 
 | ||
| 	return map[string]interface{}{
 | ||
| 		"active_limiters": len(m.limiters),
 | ||
| 		"rate_limit": map[string]interface{}{
 | ||
| 			"requests": m.config.RateLimit.Requests,
 | ||
| 			"window":   m.config.RateLimit.Window,
 | ||
| 			"burst":    m.config.RateLimit.Burst,
 | ||
| 		},
 | ||
| 	}
 | ||
| }
 |