add
This commit is contained in:
@@ -5,25 +5,32 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
"tyapi-server/internal/config"
|
||||
securityEntities "tyapi-server/internal/domains/security/entities"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
response interfaces.ResponseBuilder
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
|
||||
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder, db *gorm.DB, logger *zap.Logger) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
response: response,
|
||||
db: db,
|
||||
logger: logger,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
@@ -49,6 +56,8 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
m.recordSuspiciousRequest(c, clientID, "rate_limit")
|
||||
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
@@ -68,6 +77,28 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
|
||||
if m.db == nil {
|
||||
return
|
||||
}
|
||||
windowSeconds := int(m.config.RateLimit.Window.Seconds())
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 1
|
||||
}
|
||||
record := securityEntities.SuspiciousIPRecord{
|
||||
IP: ip,
|
||||
Path: c.Request.URL.Path,
|
||||
Method: c.Request.Method,
|
||||
RequestCount: 1,
|
||||
WindowSeconds: windowSeconds,
|
||||
TriggerReason: reason,
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
}
|
||||
if err := m.db.Create(&record).Error; err != nil && m.logger != nil {
|
||||
m.logger.Warn("记录可疑IP失败", zap.String("ip", ip), zap.String("path", record.Path), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RateLimitMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
|
||||
Reference in New Issue
Block a user