| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | package middleware | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import ( | 
					
						
							|  |  |  |  | 	"fmt" | 
					
						
							|  |  |  |  | 	"sync" | 
					
						
							|  |  |  |  | 	"time" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	"tyapi-server/internal/config" | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 	"tyapi-server/internal/shared/interfaces" | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 	"github.com/gin-gonic/gin" | 
					
						
							|  |  |  |  | 	"golang.org/x/time/rate" | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							// RateLimitMiddleware is a gin middleware that keeps one rate.Limiter
// (x/time/rate token bucket) per client identifier and rejects requests
// that exceed the configured quota.
type RateLimitMiddleware struct {
	// config supplies RateLimit.Requests, RateLimit.Window and RateLimit.Burst.
	config *config.Config
	// response builds the rejection response (TooManyRequests).
	response interfaces.ResponseBuilder
	// limiters maps a client ID to its limiter; every access goes through mutex.
	limiters map[string]*rate.Limiter
	// mutex guards limiters.
	mutex sync.RWMutex
}
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // NewRateLimitMiddleware 创建限流中间件 | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware { | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 	return &RateLimitMiddleware{ | 
					
						
							|  |  |  |  | 		config:   cfg, | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 		response: response, | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 		limiters: make(map[string]*rate.Limiter), | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetName 返回中间件名称 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) GetName() string { | 
					
						
							|  |  |  |  | 	return "ratelimit" | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetPriority 返回中间件优先级 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) GetPriority() int { | 
					
						
							|  |  |  |  | 	return 90 // 高优先级 | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // Handle 返回中间件处理函数 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) Handle() gin.HandlerFunc { | 
					
						
							|  |  |  |  | 	return func(c *gin.Context) { | 
					
						
							|  |  |  |  | 		// 获取客户端标识(IP地址) | 
					
						
							|  |  |  |  | 		clientID := m.getClientID(c) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 获取或创建限流器 | 
					
						
							|  |  |  |  | 		limiter := m.getLimiter(clientID) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 检查是否允许请求 | 
					
						
							|  |  |  |  | 		if !limiter.Allow() { | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 			// 添加限流头部信息 | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 			c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests)) | 
					
						
							|  |  |  |  | 			c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String()) | 
					
						
							|  |  |  |  | 			c.Header("Retry-After", "60") | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-02 16:17:59 +08:00
										 |  |  |  | 			// 使用统一的响应格式 | 
					
						
							|  |  |  |  | 			m.response.TooManyRequests(c, "请求过于频繁,请稍后再试") | 
					
						
							| 
									
										
										
										
											2025-06-30 19:21:56 +08:00
										 |  |  |  | 			c.Abort() | 
					
						
							|  |  |  |  | 			return | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		// 添加限流头部信息 | 
					
						
							|  |  |  |  | 		c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests)) | 
					
						
							|  |  |  |  | 		c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String()) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		c.Next() | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // IsGlobal 是否为全局中间件 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) IsGlobal() bool { | 
					
						
							|  |  |  |  | 	return true | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getClientID 获取客户端标识 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) getClientID(c *gin.Context) string { | 
					
						
							|  |  |  |  | 	// 优先使用X-Forwarded-For头部 | 
					
						
							|  |  |  |  | 	if xff := c.GetHeader("X-Forwarded-For"); xff != "" { | 
					
						
							|  |  |  |  | 		return xff | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 使用X-Real-IP头部 | 
					
						
							|  |  |  |  | 	if xri := c.GetHeader("X-Real-IP"); xri != "" { | 
					
						
							|  |  |  |  | 		return xri | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 使用RemoteAddr | 
					
						
							|  |  |  |  | 	return c.ClientIP() | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getLimiter 获取或创建限流器 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter { | 
					
						
							|  |  |  |  | 	m.mutex.RLock() | 
					
						
							|  |  |  |  | 	limiter, exists := m.limiters[clientID] | 
					
						
							|  |  |  |  | 	m.mutex.RUnlock() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	if exists { | 
					
						
							|  |  |  |  | 		return limiter | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	m.mutex.Lock() | 
					
						
							|  |  |  |  | 	defer m.mutex.Unlock() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 双重检查 | 
					
						
							|  |  |  |  | 	if limiter, exists := m.limiters[clientID]; exists { | 
					
						
							|  |  |  |  | 		return limiter | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 创建新的限流器 | 
					
						
							|  |  |  |  | 	// rate.Every计算每个请求之间的间隔 | 
					
						
							|  |  |  |  | 	rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests)) | 
					
						
							|  |  |  |  | 	limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	m.limiters[clientID] = limiter | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 启动清理协程(仅第一次创建时) | 
					
						
							|  |  |  |  | 	if len(m.limiters) == 1 { | 
					
						
							|  |  |  |  | 		go m.cleanupRoutine() | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	return limiter | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // cleanupRoutine 定期清理不活跃的限流器 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) cleanupRoutine() { | 
					
						
							|  |  |  |  | 	ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次 | 
					
						
							|  |  |  |  | 	defer ticker.Stop() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	for { | 
					
						
							|  |  |  |  | 		select { | 
					
						
							|  |  |  |  | 		case <-ticker.C: | 
					
						
							|  |  |  |  | 			m.cleanup() | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // cleanup 清理不活跃的限流器 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) cleanup() { | 
					
						
							|  |  |  |  | 	m.mutex.Lock() | 
					
						
							|  |  |  |  | 	defer m.mutex.Unlock() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	now := time.Now() | 
					
						
							|  |  |  |  | 	for clientID, limiter := range m.limiters { | 
					
						
							|  |  |  |  | 		// 如果限流器在过去1小时内没有被使用,则删除它 | 
					
						
							|  |  |  |  | 		if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour { | 
					
						
							|  |  |  |  | 			delete(m.limiters, clientID) | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetStats 获取限流统计 | 
					
						
							|  |  |  |  | func (m *RateLimitMiddleware) GetStats() map[string]interface{} { | 
					
						
							|  |  |  |  | 	m.mutex.RLock() | 
					
						
							|  |  |  |  | 	defer m.mutex.RUnlock() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	return map[string]interface{}{ | 
					
						
							|  |  |  |  | 		"active_limiters": len(m.limiters), | 
					
						
							|  |  |  |  | 		"rate_limit": map[string]interface{}{ | 
					
						
							|  |  |  |  | 			"requests": m.config.RateLimit.Requests, | 
					
						
							|  |  |  |  | 			"window":   m.config.RateLimit.Window, | 
					
						
							|  |  |  |  | 			"burst":    m.config.RateLimit.Burst, | 
					
						
							|  |  |  |  | 		}, | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } |