2025-06-30 19:21:56 +08:00
|
|
|
|
package middleware
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"fmt"
|
2025-11-22 15:43:24 +08:00
|
|
|
|
"strings"
|
2025-06-30 19:21:56 +08:00
|
|
|
|
"sync"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
|
"tyapi-server/internal/config"
|
2025-07-02 16:17:59 +08:00
|
|
|
|
"tyapi-server/internal/shared/interfaces"
|
2025-06-30 19:21:56 +08:00
|
|
|
|
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
|
|
|
"golang.org/x/time/rate"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// RateLimitMiddleware 限流中间件
|
|
|
|
|
|
type RateLimitMiddleware struct {
|
|
|
|
|
|
config *config.Config
|
2025-07-02 16:17:59 +08:00
|
|
|
|
response interfaces.ResponseBuilder
|
2025-06-30 19:21:56 +08:00
|
|
|
|
limiters map[string]*rate.Limiter
|
|
|
|
|
|
mutex sync.RWMutex
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// NewRateLimitMiddleware 创建限流中间件
|
2025-07-02 16:17:59 +08:00
|
|
|
|
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
|
2025-06-30 19:21:56 +08:00
|
|
|
|
return &RateLimitMiddleware{
|
|
|
|
|
|
config: cfg,
|
2025-07-02 16:17:59 +08:00
|
|
|
|
response: response,
|
2025-06-30 19:21:56 +08:00
|
|
|
|
limiters: make(map[string]*rate.Limiter),
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GetName 返回中间件名称
|
|
|
|
|
|
func (m *RateLimitMiddleware) GetName() string {
|
|
|
|
|
|
return "ratelimit"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GetPriority 返回中间件优先级
|
|
|
|
|
|
func (m *RateLimitMiddleware) GetPriority() int {
|
|
|
|
|
|
return 90 // 高优先级
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Handle 返回中间件处理函数
|
|
|
|
|
|
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
|
|
|
|
|
return func(c *gin.Context) {
|
2025-11-22 15:43:24 +08:00
|
|
|
|
// 检查是否在排除域名中
|
|
|
|
|
|
host := c.Request.Host
|
|
|
|
|
|
// 移除端口部分
|
|
|
|
|
|
if idx := strings.Index(host, ":"); idx != -1 {
|
|
|
|
|
|
host = host[:idx]
|
|
|
|
|
|
}
|
|
|
|
|
|
if m.isExcludedDomain(host) {
|
|
|
|
|
|
c.Next()
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 检查是否在排除路径中
|
|
|
|
|
|
if m.isExcludedPath(c.Request.URL.Path) {
|
|
|
|
|
|
c.Next()
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-06-30 19:21:56 +08:00
|
|
|
|
// 获取客户端标识(IP地址)
|
|
|
|
|
|
clientID := m.getClientID(c)
|
|
|
|
|
|
|
|
|
|
|
|
// 获取或创建限流器
|
|
|
|
|
|
limiter := m.getLimiter(clientID)
|
|
|
|
|
|
|
|
|
|
|
|
// 检查是否允许请求
|
|
|
|
|
|
if !limiter.Allow() {
|
2025-07-02 16:17:59 +08:00
|
|
|
|
// 添加限流头部信息
|
2025-06-30 19:21:56 +08:00
|
|
|
|
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
|
|
|
|
|
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
|
|
|
|
|
c.Header("Retry-After", "60")
|
|
|
|
|
|
|
2025-07-02 16:17:59 +08:00
|
|
|
|
// 使用统一的响应格式
|
|
|
|
|
|
m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
|
2025-06-30 19:21:56 +08:00
|
|
|
|
c.Abort()
|
|
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 添加限流头部信息
|
|
|
|
|
|
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
|
|
|
|
|
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
|
|
|
|
|
|
|
|
|
|
|
c.Next()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// IsGlobal 是否为全局中间件
|
|
|
|
|
|
func (m *RateLimitMiddleware) IsGlobal() bool {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getClientID 获取客户端标识
|
|
|
|
|
|
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
|
|
|
|
|
|
// 优先使用X-Forwarded-For头部
|
|
|
|
|
|
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
|
|
|
|
|
return xff
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 使用X-Real-IP头部
|
|
|
|
|
|
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
|
|
|
|
|
return xri
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 使用RemoteAddr
|
|
|
|
|
|
return c.ClientIP()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getLimiter returns the limiter for clientID, creating it on first use.
//
// The lookup first runs under the read lock. On a miss it takes the write lock
// and checks the map again, so concurrent first requests from one client share
// a single limiter. A new limiter refills one token every
// RateLimit.Window/RateLimit.Requests and allows bursts of RateLimit.Burst.
// Whenever an insertion takes the map from empty to exactly one entry, the
// background cleanupRoutine goroutine is started.
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
	// Fast path: shared read lock.
	m.mutex.RLock()
	cached, found := m.limiters[clientID]
	m.mutex.RUnlock()
	if found {
		return cached
	}

	// Slow path: another goroutine may have created the limiter between
	// RUnlock and Lock, so check once more under the exclusive lock.
	m.mutex.Lock()
	defer m.mutex.Unlock()
	if cached, found = m.limiters[clientID]; found {
		return cached
	}

	// Spacing between successive tokens.
	interval := m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests)
	created := rate.NewLimiter(rate.Every(interval), m.config.RateLimit.Burst)
	m.limiters[clientID] = created

	if len(m.limiters) == 1 {
		go m.cleanupRoutine()
	}
	return created
}
|
|
|
|
|
|
|
|
|
|
|
|
// cleanupRoutine prunes idle limiters from m.limiters every 10 minutes.
//
// getLimiter starts it whenever the limiter map goes from empty to one entry.
// Once cleanup reports that the map is empty, the routine exits, and the next
// new client starts a fresh one. Both the emptiness decision (in cleanup) and
// the start condition (in getLimiter) happen under m.mutex, so at most one
// routine keeps ticking at any time.
func (m *RateLimitMiddleware) cleanupRoutine() {
	ticker := time.NewTicker(10 * time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		if m.cleanup() {
			return // nothing left to watch
		}
	}
}

// cleanup removes every limiter whose token bucket is completely full, and
// reports whether the map is empty afterwards. The emptiness is checked while
// the write lock is still held.
//
// A full bucket means the client has been idle long enough to regain its whole
// burst. Rate limiters start with a full bucket, so such a limiter is in the
// same state as a newly created one, and dropping it does not loosen the limit.
// The one exception is a request that fetched the limiter just before it was
// removed. Tokens is used because it only reads the bucket; Reserve would
// consume a token from every client on each pass.
func (m *RateLimitMiddleware) cleanup() bool {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for clientID, limiter := range m.limiters {
		if limiter.Tokens() >= float64(limiter.Burst()) {
			delete(m.limiters, clientID) // deleting during range is safe in Go
		}
	}
	return len(m.limiters) == 0
}
|
|
|
|
|
|
|
2025-11-22 15:43:24 +08:00
|
|
|
|
// isExcludedDomain 检查域名是否在排除列表中
|
|
|
|
|
|
func (m *RateLimitMiddleware) isExcludedDomain(host string) bool {
|
|
|
|
|
|
for _, excludeDomain := range m.config.RateLimit.ExcludeDomains {
|
|
|
|
|
|
// 支持通配符匹配
|
|
|
|
|
|
if strings.HasPrefix(excludeDomain, "*") {
|
|
|
|
|
|
// 后缀匹配,如 "*.api.example.com" 匹配 "api.example.com"
|
|
|
|
|
|
if strings.HasSuffix(host, excludeDomain[1:]) {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
} else if strings.HasSuffix(excludeDomain, "*") {
|
|
|
|
|
|
// 前缀匹配,如 "api.*" 匹配 "api.example.com"
|
|
|
|
|
|
if strings.HasPrefix(host, excludeDomain[:len(excludeDomain)-1]) {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// 精确匹配
|
|
|
|
|
|
if host == excludeDomain {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// isExcludedPath 检查路径是否在排除列表中
|
|
|
|
|
|
func (m *RateLimitMiddleware) isExcludedPath(path string) bool {
|
|
|
|
|
|
for _, excludePath := range m.config.RateLimit.ExcludePaths {
|
|
|
|
|
|
// 支持多种匹配模式
|
|
|
|
|
|
if strings.HasPrefix(excludePath, "*") {
|
|
|
|
|
|
// 前缀匹配,如 "*api_name" 匹配 "/api/v1/any_api_name"
|
|
|
|
|
|
if strings.Contains(path, excludePath[1:]) {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
} else if strings.HasSuffix(excludePath, "*") {
|
|
|
|
|
|
// 后缀匹配,如 "/api/v1/*" 匹配 "/api/v1/any_api_name"
|
|
|
|
|
|
if strings.HasPrefix(path, excludePath[:len(excludePath)-1]) {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
} else if strings.Contains(excludePath, "*") {
|
|
|
|
|
|
// 中间通配符匹配,如 "/api/v1/*api_name" 匹配 "/api/v1/any_api_name"
|
|
|
|
|
|
parts := strings.Split(excludePath, "*")
|
|
|
|
|
|
if len(parts) == 2 {
|
|
|
|
|
|
prefix := parts[0]
|
|
|
|
|
|
suffix := parts[1]
|
|
|
|
|
|
if strings.HasPrefix(path, prefix) && strings.HasSuffix(path, suffix) {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
} else {
|
|
|
|
|
|
// 精确匹配
|
|
|
|
|
|
if path == excludePath {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
return false
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-06-30 19:21:56 +08:00
|
|
|
|
// GetStats 获取限流统计
|
|
|
|
|
|
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
|
|
|
|
|
|
m.mutex.RLock()
|
|
|
|
|
|
defer m.mutex.RUnlock()
|
|
|
|
|
|
|
|
|
|
|
|
return map[string]interface{}{
|
|
|
|
|
|
"active_limiters": len(m.limiters),
|
|
|
|
|
|
"rate_limit": map[string]interface{}{
|
|
|
|
|
|
"requests": m.config.RateLimit.Requests,
|
|
|
|
|
|
"window": m.config.RateLimit.Window,
|
|
|
|
|
|
"burst": m.config.RateLimit.Burst,
|
|
|
|
|
|
},
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|