add limit

This commit is contained in:
2025-08-10 14:40:02 +08:00
parent 9e6248efb2
commit bb291f6847
7 changed files with 841 additions and 14 deletions

View File

@@ -104,7 +104,48 @@ ocr:
ratelimit: ratelimit:
requests: 5000 requests: 5000
window: 60s window: 60s
burst: 100
# 每日请求限制配置
daily_ratelimit:
max_requests_per_day: 200 # 每日最大请求次数
max_requests_per_ip: 10 # 每个IP每日最大请求次数
key_prefix: "daily_limit" # Redis键前缀
ttl: 24h # 键过期时间
max_concurrent: 5 # 最大并发请求数
# 安全配置
enable_ip_whitelist: false # 是否启用IP白名单
ip_whitelist: # IP白名单列表
- "192.168.1.*" # 内网IP段
- "10.0.0.*" # 内网IP段
- "127.0.0.1" # 本地回环
enable_ip_blacklist: true # 是否启用IP黑名单
ip_blacklist: # IP黑名单列表
- "0.0.0.0" # 无效IP
- "255.255.255.255" # 广播IP
enable_user_agent: true # 是否检查User-Agent
blocked_user_agents: # 被阻止的User-Agent
- "bot" # 机器人
- "crawler" # 爬虫
- "spider" # 蜘蛛
- "scraper" # 抓取器
- "curl" # curl工具
- "wget" # wget工具
- "python" # Python脚本
- "java" # Java脚本
- "go-http-client" # Go HTTP客户端
enable_referer: true # 是否检查Referer
allowed_referers: # 允许的Referer
- "yourdomain.com" # 你的域名
- "api.yourdomain.com" # API域名
enable_proxy_check: true # 是否检查代理
enable_geo_block: false # 是否启用地理位置阻止
blocked_countries: # 被阻止的国家/地区
- "XX" # 示例国家代码
monitoring: monitoring:
metrics_enabled: true metrics_enabled: true

View File

@@ -0,0 +1,200 @@
# 每日请求限制中间件使用指南
## 概述
每日请求限制中间件DailyRateLimitMiddleware是一个专门用于防护恶意请求的高级安全中间件。它不仅限制每个IP地址的每日请求次数还集成了多种安全防护措施确保API接口不被恶意攻击者滥用保护系统资源和成本。
## 核心特性
### 1. 基础限流功能
- **每日限制**每个IP地址每日最多请求10次
- **并发控制**每个IP同时最多5个并发请求
- **Redis存储**使用Redis确保分布式环境中的计数一致性
- **自动过期**计数器自动在24小时后过期
### 2. 高级安全防护
- **IP白名单/黑名单**精确控制IP访问权限
- **User-Agent检测**:阻止机器人、爬虫、脚本工具等恶意请求
- **Referer验证**:确保请求来源合法
- **代理检测**识别并处理代理IP
- **地理位置阻止**:阻止特定国家/地区的访问
### 3. 隐蔽性设计
- **隐藏限制信息**:客户端无法获知被限制的真实原因
- **通用错误响应**:返回"系统繁忙"等通用错误信息
- **内部监控头**:仅用于系统内部监控,客户端不可见
## 配置说明
### 基础配置
```yaml
daily_ratelimit:
max_requests_per_day: 10 # 每日最大请求次数
max_requests_per_ip: 10 # 每个IP每日最大请求次数
key_prefix: "daily_limit" # Redis键前缀
ttl: 24h # 键过期时间
max_concurrent: 5 # 最大并发请求数
```
### 安全配置
```yaml
daily_ratelimit:
# IP访问控制
enable_ip_whitelist: false # 是否启用IP白名单
ip_whitelist: # IP白名单列表
- "192.168.1.*" # 内网IP段
- "10.0.0.*" # 内网IP段
enable_ip_blacklist: true # 是否启用IP黑名单
ip_blacklist: # IP黑名单列表
- "0.0.0.0" # 无效IP
- "255.255.255.255" # 广播IP
# User-Agent检测
enable_user_agent: true # 是否检查User-Agent
blocked_user_agents: # 被阻止的User-Agent
- "bot" # 机器人
- "crawler" # 爬虫
- "curl" # curl工具
- "python" # Python脚本
# Referer验证
enable_referer: true # 是否检查Referer
allowed_referers: # 允许的Referer
- "yourdomain.com" # 你的域名
# 代理检测
enable_proxy_check: true # 是否检查代理
```
## 防护恶意请求的措施
### 1. 机器人检测
- 自动识别并阻止常见的爬虫工具
- 阻止无User-Agent的请求
- 阻止自动化测试工具
### 2. IP地址控制
- 黑名单阻止已知的恶意IP
- 白名单仅允许受信任的IP访问
- 支持通配符匹配如192.168.1.*
### 3. 请求来源验证
- 验证Referer头部确保请求来源合法
- 阻止来自未知域名的请求
- 防止跨站请求伪造CSRF
### 4. 并发控制
- 限制单个IP的并发请求数
- 防止DDoS攻击
- 保护系统资源
### 5. 代理检测
- 识别各种代理头部
- 获取真实客户端IP
- 防止IP伪造攻击
## 使用示例
### 应用到特定路由
```go
// 在认证路由中应用
authGroup.POST("/enterprise-info", r.dailyRateLimit.Handle(), r.handler.SubmitEnterpriseInfo)
```
### 配置检查
```go
// 检查中间件状态
stats := dailyRateLimit.GetStats()
log.Printf("中间件配置: %+v", stats)
```
## 响应处理
### 正常响应
- 请求通过所有检查后正常处理
- 在隐藏的响应头中添加监控信息
### 被阻止的请求
- 返回通用错误信息:"访问被拒绝"或"系统繁忙"
- 不暴露具体的限制原因
- 记录详细的日志用于内部分析
### 隐藏的监控头
```http
X-System-Status: normal
X-Request-Count: 3
X-Reset-Time: 2024-01-02T00:00:00Z
```
## 监控和日志
### 日志记录
- 记录所有被阻止的请求
- 包含IP地址、User-Agent、Referer等信息
- 便于安全分析和威胁检测
### 统计信息
```go
stats := middleware.GetStats()
// 返回配置信息和安全特性状态
```
## 性能考虑
### Redis优化
- 使用Pipeline减少网络往返
- 设置合理的TTL避免内存泄漏
- 键名设计便于批量操作
### 内存使用
- 计数器自动过期
- 并发限制使用短期TTL
- 避免无限增长
## 安全最佳实践
### 1. 配置建议
- 启用User-Agent检测
- 启用Referer验证
- 启用代理检测
- 定期更新黑名单
### 2. 监控建议
- 监控被阻止的请求数量
- 分析攻击模式
- 及时调整安全策略
### 3. 部署建议
- 在生产环境中启用所有安全特性
- 根据业务需求调整限制参数
- 定期审查和更新配置
## 故障排除
### 常见问题
1. **误杀正常请求**
- 检查User-Agent白名单
- 调整Referer允许列表
- 检查IP白名单配置
2. **性能问题**
- 调整并发限制参数
- 检查Redis性能
- 优化键名设计
3. **配置不生效**
- 检查配置文件语法
- 确认中间件已正确注册
- 验证依赖注入
## 总结
每日请求限制中间件提供了一个强大而全面的解决方案来防护恶意请求。通过多层安全检查和隐蔽的限制机制有效防止API滥用保护系统资源同时为正常用户提供良好的使用体验。
该中间件特别适合需要控制成本的API接口通过精确的访问控制确保每次请求都来自合法用户避免恶意攻击造成的资源浪费。

View File

@@ -0,0 +1,57 @@
# 每日限流中间件使用指南
## 概述
每日限流中间件实现了多层限流策略:
- 接口一天最大请求200次
- 一个IP一天最多10次
- 支持并发限制和安全防护
## 主要特性
1. **多层限流策略**
- 接口总请求限制200次/天
- IP请求限制10次/天/IP
- 并发请求限制5个/IP
2. **安全防护**
- IP白名单/黑名单
- User-Agent检查
- Referer验证
- 代理检测
## 使用方法
```go
// 配置限流参数
limitConfig := middleware.DailyRateLimitConfig{
MaxRequestsPerDay: 200, // 接口一天最大请求200次
MaxRequestsPerIP: 10, // 一个IP一天最多10次
KeyPrefix: "api_limit", // Redis键前缀
TTL: 24 * time.Hour, // 24小时过期
MaxConcurrent: 5, // 最大并发5个
}
// 创建中间件实例
rateLimitMiddleware := middleware.NewDailyRateLimitMiddleware(
config, redisClient, response, logger, limitConfig)
// 应用到路由
router.Use(rateLimitMiddleware.Handle())
```
## 限流逻辑
1. 检查IP访问权限
2. 验证User-Agent和Referer
3. 检查并发限制
4. 检查接口总请求次数200次/天)
5. 检查IP请求次数10次/天)
6. 更新计数器
## 监控信息
响应头包含隐藏的监控信息:
- `X-Total-Count`: 当前总请求次数
- `X-IP-Count`: 当前IP请求次数
- `X-Reset-Time`: 重置时间

View File

@@ -17,7 +17,8 @@ type Config struct {
Email EmailConfig `mapstructure:"email"` Email EmailConfig `mapstructure:"email"`
Storage StorageConfig `mapstructure:"storage"` Storage StorageConfig `mapstructure:"storage"`
OCR OCRConfig `mapstructure:"ocr"` OCR OCRConfig `mapstructure:"ocr"`
RateLimit RateLimitConfig `mapstructure:"ratelimit"` RateLimit RateLimitConfig `mapstructure:"ratelimit"`
DailyRateLimit DailyRateLimitConfig `mapstructure:"daily_ratelimit"`
Monitoring MonitoringConfig `mapstructure:"monitoring"` Monitoring MonitoringConfig `mapstructure:"monitoring"`
Health HealthConfig `mapstructure:"health"` Health HealthConfig `mapstructure:"health"`
Resilience ResilienceConfig `mapstructure:"resilience"` Resilience ResilienceConfig `mapstructure:"resilience"`
@@ -118,6 +119,27 @@ type RateLimitConfig struct {
Burst int `mapstructure:"burst"` Burst int `mapstructure:"burst"`
} }
// DailyRateLimitConfig 每日限流配置
type DailyRateLimitConfig struct {
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
// 新增安全配置
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
}
// MonitoringConfig 监控配置 // MonitoringConfig 监控配置
type MonitoringConfig struct { type MonitoringConfig struct {
MetricsEnabled bool `mapstructure:"metrics_enabled"` MetricsEnabled bool `mapstructure:"metrics_enabled"`

View File

@@ -344,6 +344,29 @@ func NewContainer() *Container {
middleware.NewResponseTimeMiddleware, middleware.NewResponseTimeMiddleware,
middleware.NewCORSMiddleware, middleware.NewCORSMiddleware,
middleware.NewRateLimitMiddleware, middleware.NewRateLimitMiddleware,
// 每日限流中间件
func(cfg *config.Config, redis *redis.Client, response interfaces.ResponseBuilder, logger *zap.Logger) *middleware.DailyRateLimitMiddleware {
limitConfig := middleware.DailyRateLimitConfig{
MaxRequestsPerDay: cfg.DailyRateLimit.MaxRequestsPerDay,
MaxRequestsPerIP: cfg.DailyRateLimit.MaxRequestsPerIP,
KeyPrefix: cfg.DailyRateLimit.KeyPrefix,
TTL: cfg.DailyRateLimit.TTL,
MaxConcurrent: cfg.DailyRateLimit.MaxConcurrent,
// 安全配置
EnableIPWhitelist: cfg.DailyRateLimit.EnableIPWhitelist,
IPWhitelist: cfg.DailyRateLimit.IPWhitelist,
EnableIPBlacklist: cfg.DailyRateLimit.EnableIPBlacklist,
IPBlacklist: cfg.DailyRateLimit.IPBlacklist,
EnableUserAgent: cfg.DailyRateLimit.EnableUserAgent,
BlockedUserAgents: cfg.DailyRateLimit.BlockedUserAgents,
EnableReferer: cfg.DailyRateLimit.EnableReferer,
AllowedReferers: cfg.DailyRateLimit.AllowedReferers,
EnableGeoBlock: cfg.DailyRateLimit.EnableGeoBlock,
BlockedCountries: cfg.DailyRateLimit.BlockedCountries,
EnableProxyCheck: cfg.DailyRateLimit.EnableProxyCheck,
}
return middleware.NewDailyRateLimitMiddleware(cfg, redis, response, logger, limitConfig)
},
NewRequestLoggerMiddlewareWrapper, NewRequestLoggerMiddlewareWrapper,
middleware.NewJWTAuthMiddleware, middleware.NewJWTAuthMiddleware,
middleware.NewOptionalAuthMiddleware, middleware.NewOptionalAuthMiddleware,
@@ -701,6 +724,7 @@ func RegisterMiddlewares(
responseTime *middleware.ResponseTimeMiddleware, responseTime *middleware.ResponseTimeMiddleware,
cors *middleware.CORSMiddleware, cors *middleware.CORSMiddleware,
rateLimit *middleware.RateLimitMiddleware, rateLimit *middleware.RateLimitMiddleware,
dailyRateLimit *middleware.DailyRateLimitMiddleware,
requestLogger *middleware.RequestLoggerMiddleware, requestLogger *middleware.RequestLoggerMiddleware,
traceIDMiddleware *middleware.TraceIDMiddleware, traceIDMiddleware *middleware.TraceIDMiddleware,
errorTrackingMiddleware *middleware.ErrorTrackingMiddleware, errorTrackingMiddleware *middleware.ErrorTrackingMiddleware,
@@ -714,6 +738,7 @@ func RegisterMiddlewares(
router.RegisterMiddleware(responseTime) router.RegisterMiddleware(responseTime)
router.RegisterMiddleware(cors) router.RegisterMiddleware(cors)
router.RegisterMiddleware(rateLimit) router.RegisterMiddleware(rateLimit)
router.RegisterMiddleware(dailyRateLimit)
router.RegisterMiddleware(requestLogger) router.RegisterMiddleware(requestLogger)
router.RegisterMiddleware(traceIDMiddleware) router.RegisterMiddleware(traceIDMiddleware)
router.RegisterMiddleware(errorTrackingMiddleware) router.RegisterMiddleware(errorTrackingMiddleware)

View File

@@ -10,11 +10,12 @@ import (
// CertificationRoutes 认证路由 // CertificationRoutes 认证路由
type CertificationRoutes struct { type CertificationRoutes struct {
handler *handlers.CertificationHandler handler *handlers.CertificationHandler
router *http.GinRouter router *http.GinRouter
logger *zap.Logger logger *zap.Logger
auth *middleware.JWTAuthMiddleware auth *middleware.JWTAuthMiddleware
optional *middleware.OptionalAuthMiddleware optional *middleware.OptionalAuthMiddleware
dailyRateLimit *middleware.DailyRateLimitMiddleware
} }
// NewCertificationRoutes 创建认证路由 // NewCertificationRoutes 创建认证路由
@@ -24,13 +25,15 @@ func NewCertificationRoutes(
logger *zap.Logger, logger *zap.Logger,
auth *middleware.JWTAuthMiddleware, auth *middleware.JWTAuthMiddleware,
optional *middleware.OptionalAuthMiddleware, optional *middleware.OptionalAuthMiddleware,
dailyRateLimit *middleware.DailyRateLimitMiddleware,
) *CertificationRoutes { ) *CertificationRoutes {
return &CertificationRoutes{ return &CertificationRoutes{
handler: handler, handler: handler,
router: router, router: router,
logger: logger, logger: logger,
auth: auth, auth: auth,
optional: optional, optional: optional,
dailyRateLimit: dailyRateLimit,
} }
} }
@@ -48,8 +51,8 @@ func (r *CertificationRoutes) Register(router *http.GinRouter) {
// 1. 获取认证详情 // 1. 获取认证详情
authGroup.GET("/details", r.handler.GetCertification) authGroup.GET("/details", r.handler.GetCertification)
// 2. 提交企业信息 // 2. 提交企业信息(应用每日限流)
authGroup.POST("/enterprise-info", r.handler.SubmitEnterpriseInfo) authGroup.POST("/enterprise-info", r.dailyRateLimit.Handle(), r.handler.SubmitEnterpriseInfo)
// 3. 申请合同签署 // 3. 申请合同签署
authGroup.POST("/apply-contract", r.handler.ApplyContract) authGroup.POST("/apply-contract", r.handler.ApplyContract)

View File

@@ -0,0 +1,479 @@
package middleware
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"tyapi-server/internal/config"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// DailyRateLimitConfig 每日限流配置
type DailyRateLimitConfig struct {
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
// 新增安全配置
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
}
// DailyRateLimitMiddleware 每日请求限制中间件
type DailyRateLimitMiddleware struct {
config *config.Config
redis *redis.Client
response interfaces.ResponseBuilder
logger *zap.Logger
limitConfig DailyRateLimitConfig
}
// NewDailyRateLimitMiddleware 创建每日请求限制中间件
func NewDailyRateLimitMiddleware(
cfg *config.Config,
redis *redis.Client,
response interfaces.ResponseBuilder,
logger *zap.Logger,
limitConfig DailyRateLimitConfig,
) *DailyRateLimitMiddleware {
// 设置默认值
if limitConfig.MaxRequestsPerDay <= 0 {
limitConfig.MaxRequestsPerDay = 200 // 默认每日200次
}
if limitConfig.MaxRequestsPerIP <= 0 {
limitConfig.MaxRequestsPerIP = 10 // 默认每个IP每日10次
}
if limitConfig.KeyPrefix == "" {
limitConfig.KeyPrefix = "daily_limit"
}
if limitConfig.TTL == 0 {
limitConfig.TTL = 24 * time.Hour // 默认24小时过期
}
if limitConfig.MaxConcurrent <= 0 {
limitConfig.MaxConcurrent = 5 // 默认最大并发5个
}
return &DailyRateLimitMiddleware{
config: cfg,
redis: redis,
response: response,
logger: logger,
limitConfig: limitConfig,
}
}
// GetName 返回中间件名称
func (m *DailyRateLimitMiddleware) GetName() string {
return "daily_rate_limit"
}
// GetPriority 返回中间件优先级
func (m *DailyRateLimitMiddleware) GetPriority() int {
return 85 // 在认证之后,业务处理之前
}
// Handle 返回中间件处理函数
func (c *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()
// 获取客户端标识
clientIP := m.getClientIP(c)
// 1. 检查IP白名单/黑名单
if err := m.checkIPAccess(clientIP); err != nil {
m.logger.Warn("IP访问被拒绝",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.Forbidden(c, "访问被拒绝")
c.Abort()
return
}
// 2. 检查User-Agent
if err := m.checkUserAgent(c); err != nil {
m.logger.Warn("User-Agent被阻止",
zap.String("ip", clientIP),
zap.String("user_agent", c.GetHeader("User-Agent")),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.Forbidden(c, "访问被拒绝")
c.Abort()
return
}
// 3. 检查Referer
if err := m.checkReferer(c); err != nil {
m.logger.Warn("Referer检查失败",
zap.String("ip", clientIP),
zap.String("referer", c.GetHeader("Referer")),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.Forbidden(c, "访问被拒绝")
c.Abort()
return
}
// 4. 检查并发限制
if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
m.logger.Warn("并发请求超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.TooManyRequests(c, "系统繁忙,请稍后再试")
c.Abort()
return
}
// 5. 检查接口总请求次数限制
if err := m.checkTotalLimit(ctx); err != nil {
m.logger.Warn("接口总请求次数超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
// 隐藏限制信息,返回通用错误
m.response.InternalServerError(c, "系统繁忙,请稍后再试")
c.Abort()
return
}
// 6. 检查IP限制
if err := m.checkIPLimit(ctx, clientIP); err != nil {
m.logger.Warn("IP请求次数超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
// 隐藏限制信息,返回通用错误
m.response.InternalServerError(c, "系统繁忙,请稍后再试")
c.Abort()
return
}
// 7. 增加计数
m.incrementCounters(ctx, clientIP)
// 7. 添加隐藏的响应头(仅用于内部监控)
m.addHiddenHeaders(c, clientIP)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *DailyRateLimitMiddleware) IsGlobal() bool {
return false // 不是全局中间件,需要手动应用到特定路由
}
// checkIPAccess 检查IP访问权限
func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
// 检查黑名单
if m.limitConfig.EnableIPBlacklist {
for _, blockedIP := range m.limitConfig.IPBlacklist {
if m.isIPMatch(clientIP, blockedIP) {
return fmt.Errorf("IP %s 在黑名单中", clientIP)
}
}
}
// 检查白名单(如果启用)
if m.limitConfig.EnableIPWhitelist {
allowed := false
for _, allowedIP := range m.limitConfig.IPWhitelist {
if m.isIPMatch(clientIP, allowedIP) {
allowed = true
break
}
}
if !allowed {
return fmt.Errorf("IP %s 不在白名单中", clientIP)
}
}
return nil
}
// isIPMatch 检查IP是否匹配支持CIDR和通配符
func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
// 简单的通配符匹配
if strings.Contains(pattern, "*") {
parts := strings.Split(pattern, ".")
clientParts := strings.Split(clientIP, ".")
if len(parts) != len(clientParts) {
return false
}
for i, part := range parts {
if part != "*" && part != clientParts[i] {
return false
}
}
return true
}
// 精确匹配
return clientIP == pattern
}
// checkUserAgent 检查User-Agent
func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
if !m.limitConfig.EnableUserAgent {
return nil
}
userAgent := c.GetHeader("User-Agent")
if userAgent == "" {
return fmt.Errorf("缺少User-Agent")
}
// 检查被阻止的User-Agent
for _, blocked := range m.limitConfig.BlockedUserAgents {
if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) {
return fmt.Errorf("User-Agent被阻止: %s", blocked)
}
}
return nil
}
// checkReferer 检查Referer
func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
if !m.limitConfig.EnableReferer {
return nil
}
referer := c.GetHeader("Referer")
if referer == "" {
return fmt.Errorf("缺少Referer")
}
// 检查允许的Referer
if len(m.limitConfig.AllowedReferers) > 0 {
allowed := false
for _, allowedRef := range m.limitConfig.AllowedReferers {
if strings.Contains(referer, allowedRef) {
allowed = true
break
}
}
if !allowed {
return fmt.Errorf("Referer不被允许: %s", referer)
}
}
return nil
}
// checkConcurrentLimit 检查并发限制
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
// 获取当前并发数
current, err := m.redis.Get(ctx, key).Result()
if err != nil && err != redis.Nil {
return fmt.Errorf("获取并发计数失败: %w", err)
}
currentCount := 0
if current != "" {
if count, err := strconv.Atoi(current); err == nil {
currentCount = count
}
}
if currentCount >= m.limitConfig.MaxConcurrent {
return fmt.Errorf("并发请求超限: %d", currentCount)
}
// 增加并发计数
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, 30*time.Second) // 30秒过期
_, err = pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
}
return nil
}
// getClientIP 获取客户端IP地址增强版
func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
// 检查是否为代理IP
if m.limitConfig.EnableProxyCheck {
// 检查常见的代理头部
proxyHeaders := []string{
"CF-Connecting-IP", // Cloudflare
"X-Forwarded-For", // 标准代理头
"X-Real-IP", // Nginx
"X-Client-IP", // Apache
"X-Forwarded", // 其他代理
"Forwarded-For", // RFC 7239
"Forwarded", // RFC 7239
}
for _, header := range proxyHeaders {
if ip := c.GetHeader(header); ip != "" {
// 如果X-Forwarded-For包含多个IP取第一个
if header == "X-Forwarded-For" && strings.Contains(ip, ",") {
ip = strings.TrimSpace(strings.Split(ip, ",")[0])
}
return ip
}
}
}
// 回退到标准方法
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
if strings.Contains(xff, ",") {
return strings.TrimSpace(strings.Split(xff, ",")[0])
}
return xff
}
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// checkTotalLimit 检查接口总请求次数限制
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取总请求计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerDay {
return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
}
return nil
}
// checkIPLimit 检查IP限制
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取IP计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerIP {
return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
}
return nil
}
// incrementCounters 增加计数器
func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) {
// 增加总请求计数
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
m.incrementCounter(ctx, totalKey)
// 增加IP计数
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
m.incrementCounter(ctx, ipKey)
}
// getCounter 获取计数器值
func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) {
val, err := m.redis.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return 0, nil // 键不存在计数为0
}
return 0, err
}
count, err := strconv.Atoi(val)
if err != nil {
return 0, fmt.Errorf("解析计数失败: %w", err)
}
return count
}
// incrementCounter 增加计数器
func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) {
// 使用Redis的INCR命令增加计数
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, m.limitConfig.TTL)
_, err := pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err))
}
}
// getDateKey 获取日期键格式2024-01-01
func (m *DailyRateLimitMiddleware) getDateKey() string {
return time.Now().Format("2006-01-02")
}
// addHiddenHeaders 添加隐藏的响应头(仅用于内部监控)
func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
ctx := c.Request.Context()
// 添加隐藏的监控头(客户端看不到)
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
totalCount, _ := m.getCounter(ctx, totalKey)
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
ipCount, _ := m.getCounter(ctx, ipKey)
// 使用非标准的头部名称,避免被客户端识别
c.Header("X-System-Status", "normal")
c.Header("X-Total-Count", strconv.Itoa(totalCount))
c.Header("X-IP-Count", strconv.Itoa(ipCount))
c.Header("X-Reset-Time", m.getResetTime().Format(time.RFC3339))
}
// getResetTime 获取重置时间明天0点
func (m *DailyRateLimitMiddleware) getResetTime() time.Time {
now := time.Now()
tomorrow := now.Add(24 * time.Hour)
return time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, tomorrow.Location())
}
// GetStats 获取限流统计
func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} {
return map[string]interface{}{
"max_requests_per_day": m.limitConfig.MaxRequestsPerDay,
"max_requests_per_ip": m.limitConfig.MaxRequestsPerIP,
"max_concurrent": m.limitConfig.MaxConcurrent,
"key_prefix": m.limitConfig.KeyPrefix,
"ttl": m.limitConfig.TTL.String(),
"security_features": map[string]interface{}{
"ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist,
"ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist,
"user_agent_check": m.limitConfig.EnableUserAgent,
"referer_check": m.limitConfig.EnableReferer,
"proxy_check": m.limitConfig.EnableProxyCheck,
},
}
}