This commit is contained in:
Mrx
2026-03-20 13:24:45 +08:00
parent 3779a7d66d
commit 521bfeb4ef
15 changed files with 530 additions and 40 deletions

View File

@@ -5,25 +5,32 @@ import (
"sync"
"time"
"tyapi-server/internal/config"
securityEntities "tyapi-server/internal/domains/security/entities"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"golang.org/x/time/rate"
"gorm.io/gorm"
)
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
config *config.Config
response interfaces.ResponseBuilder
db *gorm.DB
logger *zap.Logger
limiters map[string]*rate.Limiter
mutex sync.RWMutex
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder, db *gorm.DB, logger *zap.Logger) *RateLimitMiddleware {
return &RateLimitMiddleware{
config: cfg,
response: response,
db: db,
logger: logger,
limiters: make(map[string]*rate.Limiter),
}
}
@@ -49,6 +56,8 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
// 检查是否允许请求
if !limiter.Allow() {
m.recordSuspiciousRequest(c, clientID, "rate_limit")
// 添加限流头部信息
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
@@ -68,6 +77,28 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
}
}
func (m *RateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
if m.db == nil {
return
}
windowSeconds := int(m.config.RateLimit.Window.Seconds())
if windowSeconds <= 0 {
windowSeconds = 1
}
record := securityEntities.SuspiciousIPRecord{
IP: ip,
Path: c.Request.URL.Path,
Method: c.Request.Method,
RequestCount: 1,
WindowSeconds: windowSeconds,
TriggerReason: reason,
UserAgent: c.GetHeader("User-Agent"),
}
if err := m.db.Create(&record).Error; err != nil && m.logger != nil {
m.logger.Warn("记录可疑IP失败", zap.String("ip", ip), zap.String("path", record.Path), zap.Error(err))
}
}
// IsGlobal 是否为全局中间件
func (m *RateLimitMiddleware) IsGlobal() bool {
return true