Files
tyapi-server/internal/shared/middleware/ratelimit.go

167 lines
3.9 KiB
Go
Raw Normal View History

package middleware
import (
"fmt"
"sync"
"time"
"tyapi-server/internal/config"
2025-07-02 16:17:59 +08:00
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
config *config.Config
2025-07-02 16:17:59 +08:00
response interfaces.ResponseBuilder
limiters map[string]*rate.Limiter
mutex sync.RWMutex
}
// NewRateLimitMiddleware 创建限流中间件
2025-07-02 16:17:59 +08:00
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
return &RateLimitMiddleware{
config: cfg,
2025-07-02 16:17:59 +08:00
response: response,
limiters: make(map[string]*rate.Limiter),
}
}
// GetName reports the identifier this middleware is registered under.
func (m *RateLimitMiddleware) GetName() string {
	const middlewareName = "ratelimit"
	return middlewareName
}
// GetPriority reports this middleware's ordering priority (90, considered high).
func (m *RateLimitMiddleware) GetPriority() int {
	const priority = 90
	return priority
}
// Handle 返回中间件处理函数
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取客户端标识IP地址
clientID := m.getClientID(c)
// 获取或创建限流器
limiter := m.getLimiter(clientID)
// 检查是否允许请求
if !limiter.Allow() {
2025-07-02 16:17:59 +08:00
// 添加限流头部信息
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Header("Retry-After", "60")
2025-07-02 16:17:59 +08:00
// 使用统一的响应格式
m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
c.Abort()
return
}
// 添加限流头部信息
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Next()
}
}
// IsGlobal reports whether this middleware is applied globally; it always is.
func (m *RateLimitMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// getClientID returns the key requests are rate-limited by: the client IP as
// resolved by gin.
//
// The previous version read X-Forwarded-For / X-Real-IP verbatim. Any client
// could then choose its own key, and sending a new header value per request
// bypassed the limit entirely. It also used the whole comma-separated
// X-Forwarded-For list as the key.
//
// c.ClientIP() consults forwarding headers (by default X-Forwarded-For and
// X-Real-IP) only as permitted by the engine's trusted-proxy configuration. It
// parses out a single IP and otherwise falls back to the connection's remote
// address.
//
// NOTE(review): gin trusts every proxy unless Engine.SetTrustedProxies is
// called — confirm the server configures it for the real proxy chain.
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
	return c.ClientIP()
}
// getLimiter returns the token-bucket limiter for clientID, creating it on
// first use.
//
// Lookups take only the read lock. Creation re-checks the map under the write
// lock so that concurrent first requests from one client share one limiter.
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
	m.mutex.RLock()
	limiter, exists := m.limiters[clientID]
	m.mutex.RUnlock()
	if exists {
		return limiter
	}

	m.mutex.Lock()
	defer m.mutex.Unlock()

	// Another goroutine may have created it while we waited for the write lock.
	if limiter, exists := m.limiters[clientID]; exists {
		return limiter
	}

	// Refill one token every Window/Requests. With Requests == 0 that division
	// panics (integer divide by zero), and a negative Requests flips the
	// interval's sign. Treat any non-positive value as "no limit", which is
	// what rate.Every already returns for a non-positive interval.
	limit := rate.Inf
	if m.config.RateLimit.Requests > 0 {
		limit = rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
	}
	limiter = rate.NewLimiter(limit, m.config.RateLimit.Burst)
	m.limiters[clientID] = limiter

	// The map just went from empty to non-empty: start a background sweeper.
	if len(m.limiters) == 1 {
		go m.cleanupRoutine()
	}
	return limiter
}
// cleanupRoutine sweeps idle limiters every 10 minutes while any exist.
//
// It returns as soon as a sweep leaves the map empty. getLimiter starts a new
// routine when it next inserts into an empty map (len == 1 after insert).
// Because both transitions happen under m.mutex, at most one sweeper is active
// at a time and no goroutine lingers once there is nothing left to sweep.
func (m *RateLimitMiddleware) cleanupRoutine() {
	ticker := time.NewTicker(10 * time.Minute)
	defer ticker.Stop()
	for range ticker.C {
		if m.sweep() == 0 {
			return
		}
	}
}

// cleanup removes limiters that are indistinguishable from freshly created
// ones (see sweep).
func (m *RateLimitMiddleware) cleanup() {
	m.sweep()
}

// sweep deletes every limiter whose token bucket has fully refilled and
// returns how many limiters remain.
//
// No last-use timestamp is kept, so "idle" is defined as "bucket full". A full
// limiter behaves exactly like the new one getLimiter would create for that
// client, so dropping it loses no rate-limit state.
//
// The previous check (now.Sub(time.Now()) > time.Hour) was never true, so
// nothing was ever freed. It also called Reserve(), which spent a token from
// every client on each sweep. A request that fetched a limiter just before it
// is removed is charged to the discarded instance; the client's next request
// gets a fresh one.
func (m *RateLimitMiddleware) sweep() int {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	now := time.Now()
	for clientID, limiter := range m.limiters {
		if limiter.TokensAt(now) >= float64(limiter.Burst()) {
			delete(m.limiters, clientID) // deleting during range is safe in Go
		}
	}
	return len(m.limiters)
}
// GetStats reports how many per-client limiters are currently tracked,
// together with the configured rate-limit policy under "rate_limit".
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	policy := map[string]interface{}{
		"requests": m.config.RateLimit.Requests,
		"window":   m.config.RateLimit.Window,
		"burst":    m.config.RateLimit.Burst,
	}

	stats := make(map[string]interface{}, 2)
	stats["active_limiters"] = len(m.limiters)
	stats["rate_limit"] = policy
	return stats
}