Files
tyapi-server/internal/shared/middleware/ratelimit.go

167 lines
3.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
// RateLimitMiddleware is a gin middleware that applies a token-bucket rate
// limit (golang.org/x/time/rate) to each client independently.
type RateLimitMiddleware struct {
	// config supplies the RateLimit settings (Requests, Window, Burst).
	config *config.Config

	// mutex guards limiters.
	mutex sync.RWMutex
	// limiters maps a client identifier (see getClientID) to that client's
	// limiter; entries are created lazily by getLimiter.
	limiters map[string]*rate.Limiter
}
// NewRateLimitMiddleware builds a rate-limit middleware whose limits are read
// from cfg.RateLimit. It starts with no per-client limiters; they are created
// on demand as requests arrive.
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
	mw := &RateLimitMiddleware{config: cfg}
	mw.limiters = make(map[string]*rate.Limiter)
	return mw
}
// GetName returns this middleware's name, "ratelimit".
func (m *RateLimitMiddleware) GetName() string {
	const name = "ratelimit"
	return name
}
// GetPriority returns this middleware's priority value, 90.
// The original author labels this a high priority. How the value is ordered
// is decided by the caller — NOTE(review): confirm against the middleware
// registry.
func (m *RateLimitMiddleware) GetPriority() int {
	const priority = 90
	return priority
}
// Handle returns the gin handler that enforces the per-client rate limit.
//
// Every response carries X-RateLimit-Limit (configured requests per window)
// and X-RateLimit-Window headers. When the client's limiter has no token
// available, the request is aborted with 429 Too Many Requests, a JSON error
// body, and a Retry-After header derived from the limiter's refill rate.
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Look up (or lazily create) the limiter for this client's identifier.
		limiter := m.getLimiter(m.getClientID(c))

		// Advertise the configured limit on both allowed and rejected responses.
		c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
		c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())

		if !limiter.Allow() {
			c.Header("Retry-After", fmt.Sprintf("%d", retryAfterSeconds(limiter.Limit())))
			c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
				"success": false,
				"message": "Rate limit exceeded",
				"error":   "Too many requests",
			})
			return
		}
		c.Next()
	}
}

// retryAfterSeconds converts a limiter's refill rate (tokens per second) into
// a Retry-After value in whole seconds.
//
// When Allow fails, fewer than one token is available. With Burst >= 1, the
// next token therefore arrives within one refill interval (1/limit seconds).
// The value is rounded up so clients do not retry too early, and is never
// below 1. A non-positive limit never refills, so the function falls back to
// 60 seconds. Limits built by rate.Every from a time.Duration keep 1/limit
// well inside int64 range.
func retryAfterSeconds(limit rate.Limit) int64 {
	perSecond := float64(limit)
	if perSecond <= 0 {
		return 60
	}
	wait := 1 / perSecond
	secs := int64(wait)
	if float64(secs) < wait {
		secs++ // ceil
	}
	if secs < 1 {
		secs = 1
	}
	return secs
}
// IsGlobal reports whether this middleware is global; it always returns true.
func (m *RateLimitMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// getClientID returns the key used to bucket requests: the client IP as
// resolved by gin's Context.ClientIP.
//
// ClientIP reads the engine's RemoteIPHeaders (X-Forwarded-For, X-Real-IP by
// default) only when the direct peer is a trusted proxy. It validates the
// value as an IP address and otherwise falls back to the connection's remote
// address.
//
// Reading those headers directly, as this function used to, lets any client
// choose its own key. That bypasses the limit, and every forged value creates
// a new limiter, so memory grows without bound. It also keys on the raw
// comma-separated X-Forwarded-For chain instead of a single address.
//
// NOTE(review): gin trusts all proxies unless Engine.SetTrustedProxies is
// called. Confirm the engine is configured with the real proxy addresses so
// forwarded headers cannot be spoofed.
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
	return c.ClientIP()
}
// getLimiter returns the limiter for clientID, creating it on first use.
//
// Each limiter refills one token every Window/Requests and holds at most
// Burst tokens. A read lock serves the common case where the limiter already
// exists. Creation re-checks under the write lock so concurrent first
// requests from the same client share one limiter.
//
// The background sweeper (cleanupRoutine) runs only while the map is
// non-empty:
//   - getLimiter starts it when the map goes from empty to one entry.
//   - It exits once a sweep leaves the map empty.
//
// Both transitions happen under m.mutex, so at most one sweeper runs at a
// time.
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
	m.mutex.RLock()
	limiter, exists := m.limiters[clientID]
	m.mutex.RUnlock()
	if exists {
		return limiter
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Re-check: another request may have created it between the two locks.
	if existing, ok := m.limiters[clientID]; ok {
		return existing
	}

	// rate.Every turns "one token per interval" into a rate.Limit.
	interval := m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests)
	limiter = rate.NewLimiter(rate.Every(interval), m.config.RateLimit.Burst)
	m.limiters[clientID] = limiter

	// The map was empty before this insert. Any previous sweeper has already
	// committed to exiting (it saw the empty map under this same lock), so a
	// new one is needed.
	if len(m.limiters) == 1 {
		go m.cleanupRoutine()
	}
	return limiter
}

// cleanupRoutine sweeps idle limiters every 10 minutes. It returns as soon as
// a sweep leaves no limiters behind; getLimiter starts a fresh sweeper when
// the next client arrives.
func (m *RateLimitMiddleware) cleanupRoutine() {
	ticker := time.NewTicker(10 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		if m.cleanup() == 0 {
			return
		}
	}
}

// cleanup removes limiters whose token bucket has refilled to capacity and
// returns how many limiters remain.
//
// A full bucket means the client has been idle for at least Burst/rate.
// A newly created limiter also starts full, so deleting one loses no
// rate-limiting state: the client's next request gets an equivalent limiter
// from getLimiter.
//
// Tokens() only observes the bucket. Calling Reserve() here would consume a
// token from every client on every sweep.
//
// NOTE(review): a request that fetched a limiter just before it is deleted
// still calls Allow on the orphaned instance, allowing at most one extra
// request in that moment.
func (m *RateLimitMiddleware) cleanup() int {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	for clientID, limiter := range m.limiters {
		if limiter.Tokens() >= float64(limiter.Burst()) {
			delete(m.limiters, clientID)
		}
	}
	return len(m.limiters)
}
// GetStats reports how many per-client limiters are currently held
// ("active_limiters") together with the configured rate-limit parameters
// ("rate_limit": requests, window, burst).
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
	settings := map[string]interface{}{
		"requests": m.config.RateLimit.Requests,
		"window":   m.config.RateLimit.Window,
		"burst":    m.config.RateLimit.Burst,
	}

	m.mutex.RLock()
	active := len(m.limiters)
	m.mutex.RUnlock()

	stats := make(map[string]interface{}, 2)
	stats["active_limiters"] = active
	stats["rate_limit"] = settings
	return stats
}