This commit is contained in:
Mrx
2026-03-20 13:24:45 +08:00
parent 3779a7d66d
commit 521bfeb4ef
15 changed files with 530 additions and 40 deletions

View File

@@ -3,16 +3,19 @@ package middleware
import (
"context"
"fmt"
"math"
"strconv"
"strings"
"time"
"tyapi-server/internal/config"
securityEntities "tyapi-server/internal/domains/security/entities"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"gorm.io/gorm"
)
// DailyRateLimitConfig 每日限流配置
@@ -45,6 +48,7 @@ type DailyRateLimitConfig struct {
type DailyRateLimitMiddleware struct {
config *config.Config
redis *redis.Client
db *gorm.DB
response interfaces.ResponseBuilder
logger *zap.Logger
limitConfig DailyRateLimitConfig
@@ -54,6 +58,7 @@ type DailyRateLimitMiddleware struct {
func NewDailyRateLimitMiddleware(
cfg *config.Config,
redis *redis.Client,
db *gorm.DB,
response interfaces.ResponseBuilder,
logger *zap.Logger,
limitConfig DailyRateLimitConfig,
@@ -78,6 +83,7 @@ func NewDailyRateLimitMiddleware(
return &DailyRateLimitMiddleware{
config: cfg,
redis: redis,
db: db,
response: response,
logger: logger,
limitConfig: limitConfig,
@@ -154,7 +160,9 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
}
// 4. 检查并发限制
if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
concurrentCount, err := m.checkConcurrentLimit(ctx, clientIP)
if err != nil {
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
m.logger.Warn("并发请求超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
@@ -163,9 +171,14 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
if m.shouldRecordNearLimit(concurrentCount, m.limitConfig.MaxConcurrent) {
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
}
// 5. 检查接口总请求次数限制
if err := m.checkTotalLimit(ctx); err != nil {
totalCount, err := m.checkTotalLimit(ctx)
if err != nil {
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
m.logger.Warn("接口总请求次数超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
@@ -175,9 +188,14 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
if m.shouldRecordNearLimit(totalCount+1, m.limitConfig.MaxRequestsPerDay) {
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
}
// 6. 检查IP限制
if err := m.checkIPLimit(ctx, clientIP); err != nil {
ipCount, err := m.checkIPLimit(ctx, clientIP)
if err != nil {
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
m.logger.Warn("IP请求次数超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
@@ -187,6 +205,9 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
if m.shouldRecordNearLimit(ipCount+1, m.limitConfig.MaxRequestsPerIP) {
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
}
// 7. 增加计数
m.incrementCounters(ctx, clientIP)
@@ -198,6 +219,38 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
}
}
func (m *DailyRateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
if m.db == nil {
return
}
record := securityEntities.SuspiciousIPRecord{
IP: ip,
Path: c.Request.URL.Path,
Method: c.Request.Method,
RequestCount: 1,
WindowSeconds: int(m.limitConfig.TTL.Seconds()),
TriggerReason: reason,
UserAgent: c.GetHeader("User-Agent"),
}
if record.WindowSeconds <= 0 {
record.WindowSeconds = 10
}
if err := m.db.Create(&record).Error; err != nil {
m.logger.Warn("记录每日限流可疑IP失败", zap.String("ip", ip), zap.String("reason", reason), zap.Error(err))
}
}
func (m *DailyRateLimitMiddleware) shouldRecordNearLimit(current, max int) bool {
if max <= 0 {
return false
}
threshold := int(math.Ceil(float64(max) * 0.8))
if threshold < 1 {
threshold = 1
}
return current >= threshold
}
// isExcludedDomain 检查域名是否在排除列表中
func (m *DailyRateLimitMiddleware) isExcludedDomain(host string) bool {
for _, excludeDomain := range m.limitConfig.ExcludeDomains {
@@ -360,13 +413,13 @@ func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
}
// checkConcurrentLimit 检查并发限制
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) (int, error) {
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
// 获取当前并发数
current, err := m.redis.Get(ctx, key).Result()
if err != nil && err != redis.Nil {
return fmt.Errorf("获取并发计数失败: %w", err)
return 0, fmt.Errorf("获取并发计数失败: %w", err)
}
currentCount := 0
@@ -377,7 +430,7 @@ func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, cli
}
if currentCount >= m.limitConfig.MaxConcurrent {
return fmt.Errorf("并发请求超限: %d", currentCount)
return currentCount, fmt.Errorf("并发请求超限: %d", currentCount)
}
// 增加并发计数
@@ -390,7 +443,7 @@ func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, cli
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
}
return nil
return currentCount + 1, nil
}
// getClientIP 获取客户端IP地址增强版
@@ -435,35 +488,35 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
}
// checkTotalLimit 检查接口总请求次数限制
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) (int, error) {
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取总请求计数失败: %w", err)
return 0, fmt.Errorf("获取总请求计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerDay {
return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
return count, fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
}
return nil
return count, nil
}
// checkIPLimit 检查IP限制
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) (int, error) {
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取IP计数失败: %w", err)
return 0, fmt.Errorf("获取IP计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerIP {
return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
return count, fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
}
return nil
return count, nil
}
// incrementCounters 增加计数器