add
This commit is contained in:
@@ -3,16 +3,19 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
securityEntities "tyapi-server/internal/domains/security/entities"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DailyRateLimitConfig 每日限流配置
|
||||
@@ -45,6 +48,7 @@ type DailyRateLimitConfig struct {
|
||||
type DailyRateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
response interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
limitConfig DailyRateLimitConfig
|
||||
@@ -54,6 +58,7 @@ type DailyRateLimitMiddleware struct {
|
||||
func NewDailyRateLimitMiddleware(
|
||||
cfg *config.Config,
|
||||
redis *redis.Client,
|
||||
db *gorm.DB,
|
||||
response interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
limitConfig DailyRateLimitConfig,
|
||||
@@ -78,6 +83,7 @@ func NewDailyRateLimitMiddleware(
|
||||
return &DailyRateLimitMiddleware{
|
||||
config: cfg,
|
||||
redis: redis,
|
||||
db: db,
|
||||
response: response,
|
||||
logger: logger,
|
||||
limitConfig: limitConfig,
|
||||
@@ -154,7 +160,9 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
// 4. 检查并发限制
|
||||
if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
|
||||
concurrentCount, err := m.checkConcurrentLimit(ctx, clientIP)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
|
||||
m.logger.Warn("并发请求超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
@@ -163,9 +171,14 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(concurrentCount, m.limitConfig.MaxConcurrent) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
|
||||
}
|
||||
|
||||
// 5. 检查接口总请求次数限制
|
||||
if err := m.checkTotalLimit(ctx); err != nil {
|
||||
totalCount, err := m.checkTotalLimit(ctx)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
|
||||
m.logger.Warn("接口总请求次数超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
@@ -175,9 +188,14 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(totalCount+1, m.limitConfig.MaxRequestsPerDay) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
|
||||
}
|
||||
|
||||
// 6. 检查IP限制
|
||||
if err := m.checkIPLimit(ctx, clientIP); err != nil {
|
||||
ipCount, err := m.checkIPLimit(ctx, clientIP)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
|
||||
m.logger.Warn("IP请求次数超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
@@ -187,6 +205,9 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(ipCount+1, m.limitConfig.MaxRequestsPerIP) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
|
||||
}
|
||||
|
||||
// 7. 增加计数
|
||||
m.incrementCounters(ctx, clientIP)
|
||||
@@ -198,6 +219,38 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DailyRateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
|
||||
if m.db == nil {
|
||||
return
|
||||
}
|
||||
record := securityEntities.SuspiciousIPRecord{
|
||||
IP: ip,
|
||||
Path: c.Request.URL.Path,
|
||||
Method: c.Request.Method,
|
||||
RequestCount: 1,
|
||||
WindowSeconds: int(m.limitConfig.TTL.Seconds()),
|
||||
TriggerReason: reason,
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
}
|
||||
if record.WindowSeconds <= 0 {
|
||||
record.WindowSeconds = 10
|
||||
}
|
||||
if err := m.db.Create(&record).Error; err != nil {
|
||||
m.logger.Warn("记录每日限流可疑IP失败", zap.String("ip", ip), zap.String("reason", reason), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DailyRateLimitMiddleware) shouldRecordNearLimit(current, max int) bool {
|
||||
if max <= 0 {
|
||||
return false
|
||||
}
|
||||
threshold := int(math.Ceil(float64(max) * 0.8))
|
||||
if threshold < 1 {
|
||||
threshold = 1
|
||||
}
|
||||
return current >= threshold
|
||||
}
|
||||
|
||||
// isExcludedDomain 检查域名是否在排除列表中
|
||||
func (m *DailyRateLimitMiddleware) isExcludedDomain(host string) bool {
|
||||
for _, excludeDomain := range m.limitConfig.ExcludeDomains {
|
||||
@@ -360,13 +413,13 @@ func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
|
||||
}
|
||||
|
||||
// checkConcurrentLimit 检查并发限制
|
||||
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
|
||||
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) (int, error) {
|
||||
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
|
||||
|
||||
// 获取当前并发数
|
||||
current, err := m.redis.Get(ctx, key).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("获取并发计数失败: %w", err)
|
||||
return 0, fmt.Errorf("获取并发计数失败: %w", err)
|
||||
}
|
||||
|
||||
currentCount := 0
|
||||
@@ -377,7 +430,7 @@ func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, cli
|
||||
}
|
||||
|
||||
if currentCount >= m.limitConfig.MaxConcurrent {
|
||||
return fmt.Errorf("并发请求超限: %d", currentCount)
|
||||
return currentCount, fmt.Errorf("并发请求超限: %d", currentCount)
|
||||
}
|
||||
|
||||
// 增加并发计数
|
||||
@@ -390,7 +443,7 @@ func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, cli
|
||||
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
return currentCount + 1, nil
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP地址(增强版)
|
||||
@@ -435,35 +488,35 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
|
||||
}
|
||||
|
||||
// checkTotalLimit 检查接口总请求次数限制
|
||||
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
|
||||
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) (int, error) {
|
||||
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
|
||||
|
||||
count, err := m.getCounter(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取总请求计数失败: %w", err)
|
||||
return 0, fmt.Errorf("获取总请求计数失败: %w", err)
|
||||
}
|
||||
|
||||
if count >= m.limitConfig.MaxRequestsPerDay {
|
||||
return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
|
||||
return count, fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
|
||||
}
|
||||
|
||||
return nil
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// checkIPLimit 检查IP限制
|
||||
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
|
||||
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) (int, error) {
|
||||
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
|
||||
|
||||
count, err := m.getCounter(ctx, key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取IP计数失败: %w", err)
|
||||
return 0, fmt.Errorf("获取IP计数失败: %w", err)
|
||||
}
|
||||
|
||||
if count >= m.limitConfig.MaxRequestsPerIP {
|
||||
return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
|
||||
return count, fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
|
||||
}
|
||||
|
||||
return nil
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// incrementCounters 增加计数器
|
||||
|
||||
Reference in New Issue
Block a user