diff --git a/config.yaml b/config.yaml index 87ac3ab..bd7ea91 100644 --- a/config.yaml +++ b/config.yaml @@ -104,7 +104,48 @@ ocr: ratelimit: requests: 5000 window: 60s - burst: 100 + +# 每日请求限制配置 +daily_ratelimit: + max_requests_per_day: 200 # 每日最大请求次数 + max_requests_per_ip: 10 # 每个IP每日最大请求次数 + key_prefix: "daily_limit" # Redis键前缀 + ttl: 24h # 键过期时间 + max_concurrent: 5 # 最大并发请求数 + + # 安全配置 + enable_ip_whitelist: false # 是否启用IP白名单 + ip_whitelist: # IP白名单列表 + - "192.168.1.*" # 内网IP段 + - "10.0.0.*" # 内网IP段 + - "127.0.0.1" # 本地回环 + + enable_ip_blacklist: true # 是否启用IP黑名单 + ip_blacklist: # IP黑名单列表 + - "0.0.0.0" # 无效IP + - "255.255.255.255" # 广播IP + + enable_user_agent: true # 是否检查User-Agent + blocked_user_agents: # 被阻止的User-Agent + - "bot" # 机器人 + - "crawler" # 爬虫 + - "spider" # 蜘蛛 + - "scraper" # 抓取器 + - "curl" # curl工具 + - "wget" # wget工具 + - "python" # Python脚本 + - "java" # Java脚本 + - "go-http-client" # Go HTTP客户端 + + enable_referer: true # 是否检查Referer + allowed_referers: # 允许的Referer + - "https://console.tianyuanapi.com" # 天元API控制台 + - "https://consoletest.tianyuanapi.com" # 天元API测试控制台 + + enable_proxy_check: true # 是否检查代理 + enable_geo_block: false # 是否启用地理位置阻止 + blocked_countries: # 被阻止的国家/地区 + - "XX" # 示例国家代码 monitoring: metrics_enabled: true diff --git a/docs/每日请求限制中间件使用指南.md b/docs/每日请求限制中间件使用指南.md new file mode 100644 index 0000000..a4ae79c --- /dev/null +++ b/docs/每日请求限制中间件使用指南.md @@ -0,0 +1,200 @@ +# 每日请求限制中间件使用指南 + +## 概述 + +每日请求限制中间件(DailyRateLimitMiddleware)是一个专门用于防护恶意请求的高级安全中间件。它不仅限制每个IP地址的每日请求次数,还集成了多种安全防护措施,确保API接口不被恶意攻击者滥用,保护系统资源和成本。 + +## 核心特性 + +### 1. 基础限流功能 +- **每日限制**:每个IP地址每日最多请求10次 +- **并发控制**:每个IP同时最多5个并发请求 +- **Redis存储**:使用Redis确保分布式环境中的计数一致性 +- **自动过期**:计数器自动在24小时后过期 + +### 2. 高级安全防护 +- **IP白名单/黑名单**:精确控制IP访问权限 +- **User-Agent检测**:阻止机器人、爬虫、脚本工具等恶意请求 +- **Referer验证**:确保请求来源合法 +- **代理检测**:识别并处理代理IP +- **地理位置阻止**:阻止特定国家/地区的访问 + +### 3. 隐蔽性设计 +- **隐藏限制信息**:客户端无法获知被限制的真实原因 +- **通用错误响应**:返回"系统繁忙"等通用错误信息 +- **内部监控头**:仅用于系统内部监控,客户端不可见 + +## 配置说明 + +### 基础配置 + +```yaml +daily_ratelimit: + max_requests_per_day: 10 # 每日最大请求次数 + max_requests_per_ip: 10 # 每个IP每日最大请求次数 + key_prefix: "daily_limit" # Redis键前缀 + ttl: 24h # 键过期时间 + max_concurrent: 5 # 最大并发请求数 +``` + +### 安全配置 + +```yaml +daily_ratelimit: + # IP访问控制 + enable_ip_whitelist: false # 是否启用IP白名单 + ip_whitelist: # IP白名单列表 + - "192.168.1.*" # 内网IP段 + - "10.0.0.*" # 内网IP段 + + enable_ip_blacklist: true # 是否启用IP黑名单 + ip_blacklist: # IP黑名单列表 + - "0.0.0.0" # 无效IP + - "255.255.255.255" # 广播IP + + # User-Agent检测 + enable_user_agent: true # 是否检查User-Agent + blocked_user_agents: # 被阻止的User-Agent + - "bot" # 机器人 + - "crawler" # 爬虫 + - "curl" # curl工具 + - "python" # Python脚本 + + # Referer验证 + enable_referer: true # 是否检查Referer + allowed_referers: # 允许的Referer + - "yourdomain.com" # 你的域名 + + # 代理检测 + enable_proxy_check: true # 是否检查代理 +``` + +## 防护恶意请求的措施 + +### 1. 机器人检测 +- 自动识别并阻止常见的爬虫工具 +- 阻止无User-Agent的请求 +- 阻止自动化测试工具 + +### 2. IP地址控制 +- 黑名单:阻止已知的恶意IP +- 白名单:仅允许受信任的IP访问 +- 支持通配符匹配(如192.168.1.*) + +### 3. 请求来源验证 +- 验证Referer头部,确保请求来源合法 +- 阻止来自未知域名的请求 +- 防止跨站请求伪造(CSRF) + +### 4. 并发控制 +- 限制单个IP的并发请求数 +- 防止DDoS攻击 +- 保护系统资源 + +### 5. 代理检测 +- 识别各种代理头部 +- 获取真实客户端IP +- 防止IP伪造攻击 + +## 使用示例 + +### 应用到特定路由 + +```go +// 在认证路由中应用 +authGroup.POST("/enterprise-info", r.dailyRateLimit.Handle(), r.handler.SubmitEnterpriseInfo) +``` + +### 配置检查 + +```go +// 检查中间件状态 +stats := dailyRateLimit.GetStats() +log.Printf("中间件配置: %+v", stats) +``` + +## 响应处理 + +### 正常响应 +- 请求通过所有检查后正常处理 +- 在隐藏的响应头中添加监控信息 + +### 被阻止的请求 +- 返回通用错误信息:"访问被拒绝"或"系统繁忙" +- 不暴露具体的限制原因 +- 记录详细的日志用于内部分析 + +### 隐藏的监控头 +```http +X-System-Status: normal +X-Request-Count: 3 +X-Reset-Time: 2024-01-02T00:00:00Z +``` + +## 监控和日志 + +### 日志记录 +- 记录所有被阻止的请求 +- 包含IP地址、User-Agent、Referer等信息 +- 便于安全分析和威胁检测 + +### 统计信息 +```go +stats := middleware.GetStats() +// 返回配置信息和安全特性状态 +``` + +## 性能考虑 + +### Redis优化 +- 使用Pipeline减少网络往返 +- 设置合理的TTL避免内存泄漏 +- 键名设计便于批量操作 + +### 内存使用 +- 计数器自动过期 +- 并发限制使用短期TTL +- 避免无限增长 + +## 安全最佳实践 + +### 1. 配置建议 +- 启用User-Agent检测 +- 启用Referer验证 +- 启用代理检测 +- 定期更新黑名单 + +### 2. 监控建议 +- 监控被阻止的请求数量 +- 分析攻击模式 +- 及时调整安全策略 + +### 3. 部署建议 +- 在生产环境中启用所有安全特性 +- 根据业务需求调整限制参数 +- 定期审查和更新配置 + +## 故障排除 + +### 常见问题 + +1. **误杀正常请求** + - 检查User-Agent白名单 + - 调整Referer允许列表 + - 检查IP白名单配置 + +2. **性能问题** + - 调整并发限制参数 + - 检查Redis性能 + - 优化键名设计 + +3. **配置不生效** + - 检查配置文件语法 + - 确认中间件已正确注册 + - 验证依赖注入 + +## 总结 + +每日请求限制中间件提供了一个强大而全面的解决方案来防护恶意请求。通过多层安全检查和隐蔽的限制机制,有效防止API滥用,保护系统资源,同时为正常用户提供良好的使用体验。 + +该中间件特别适合需要控制成本的API接口,通过精确的访问控制,确保每次请求都来自合法用户,避免恶意攻击造成的资源浪费。 diff --git a/docs/每日限流中间件使用指南.md b/docs/每日限流中间件使用指南.md new file mode 100644 index 0000000..bee3a27 --- /dev/null +++ b/docs/每日限流中间件使用指南.md @@ -0,0 +1,57 @@ +# 每日限流中间件使用指南 + +## 概述 + +每日限流中间件实现了多层限流策略: +- 接口一天最大请求200次 +- 一个IP一天最多10次 +- 支持并发限制和安全防护 + +## 主要特性 + +1. **多层限流策略** + - 接口总请求限制:200次/天 + - IP请求限制:10次/天/IP + - 并发请求限制:5个/IP + +2. **安全防护** + - IP白名单/黑名单 + - User-Agent检查 + - Referer验证 + - 代理检测 + +## 使用方法 + +```go +// 配置限流参数 +limitConfig := middleware.DailyRateLimitConfig{ + MaxRequestsPerDay: 200, // 接口一天最大请求200次 + MaxRequestsPerIP: 10, // 一个IP一天最多10次 + KeyPrefix: "api_limit", // Redis键前缀 + TTL: 24 * time.Hour, // 24小时过期 + MaxConcurrent: 5, // 最大并发5个 +} + +// 创建中间件实例 +rateLimitMiddleware := middleware.NewDailyRateLimitMiddleware( + config, redisClient, response, logger, limitConfig) + +// 应用到路由 +router.Use(rateLimitMiddleware.Handle()) +``` + +## 限流逻辑 + +1. 检查IP访问权限 +2. 验证User-Agent和Referer +3. 检查并发限制 +4. 检查接口总请求次数(200次/天) +5. 检查IP请求次数(10次/天) +6. 更新计数器 + +## 监控信息 + +响应头包含隐藏的监控信息: +- `X-Total-Count`: 当前总请求次数 +- `X-IP-Count`: 当前IP请求次数 +- `X-Reset-Time`: 重置时间 diff --git a/internal/config/config.go b/internal/config/config.go index 2a5e2bb..8d732ef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -17,7 +17,8 @@ type Config struct { Email EmailConfig `mapstructure:"email"` Storage StorageConfig `mapstructure:"storage"` OCR OCRConfig `mapstructure:"ocr"` - RateLimit RateLimitConfig `mapstructure:"ratelimit"` + RateLimit RateLimitConfig `mapstructure:"ratelimit"` + DailyRateLimit DailyRateLimitConfig `mapstructure:"daily_ratelimit"` Monitoring MonitoringConfig `mapstructure:"monitoring"` Health HealthConfig `mapstructure:"health"` Resilience ResilienceConfig `mapstructure:"resilience"` @@ -118,6 +119,27 @@ type RateLimitConfig struct { Burst int `mapstructure:"burst"` } +// DailyRateLimitConfig 每日限流配置 +type DailyRateLimitConfig struct { + MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数 + MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数 + KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀 + TTL time.Duration `mapstructure:"ttl"` // 键过期时间 + // 新增安全配置 + EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单 + IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单 + EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单 + IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单 + EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent + BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent + EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer + AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer + EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止 + BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区 + EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理 + MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数 +} + // MonitoringConfig 监控配置 type MonitoringConfig struct { MetricsEnabled bool `mapstructure:"metrics_enabled"` diff --git a/internal/container/container.go b/internal/container/container.go index 10b92c4..2d026a5 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -344,6 +344,29 @@ func NewContainer() *Container { middleware.NewResponseTimeMiddleware, middleware.NewCORSMiddleware, middleware.NewRateLimitMiddleware, + // 每日限流中间件 + func(cfg *config.Config, redis *redis.Client, response interfaces.ResponseBuilder, logger *zap.Logger) *middleware.DailyRateLimitMiddleware { + limitConfig := middleware.DailyRateLimitConfig{ + MaxRequestsPerDay: cfg.DailyRateLimit.MaxRequestsPerDay, + MaxRequestsPerIP: cfg.DailyRateLimit.MaxRequestsPerIP, + KeyPrefix: cfg.DailyRateLimit.KeyPrefix, + TTL: cfg.DailyRateLimit.TTL, + MaxConcurrent: cfg.DailyRateLimit.MaxConcurrent, + // 安全配置 + EnableIPWhitelist: cfg.DailyRateLimit.EnableIPWhitelist, + IPWhitelist: cfg.DailyRateLimit.IPWhitelist, + EnableIPBlacklist: cfg.DailyRateLimit.EnableIPBlacklist, + IPBlacklist: cfg.DailyRateLimit.IPBlacklist, + EnableUserAgent: cfg.DailyRateLimit.EnableUserAgent, + BlockedUserAgents: cfg.DailyRateLimit.BlockedUserAgents, + EnableReferer: cfg.DailyRateLimit.EnableReferer, + AllowedReferers: cfg.DailyRateLimit.AllowedReferers, + EnableGeoBlock: cfg.DailyRateLimit.EnableGeoBlock, + BlockedCountries: cfg.DailyRateLimit.BlockedCountries, + EnableProxyCheck: cfg.DailyRateLimit.EnableProxyCheck, + } + return middleware.NewDailyRateLimitMiddleware(cfg, redis, response, logger, limitConfig) + }, NewRequestLoggerMiddlewareWrapper, middleware.NewJWTAuthMiddleware, middleware.NewOptionalAuthMiddleware, @@ -701,6 +724,7 @@ func RegisterMiddlewares( responseTime *middleware.ResponseTimeMiddleware, cors *middleware.CORSMiddleware, rateLimit *middleware.RateLimitMiddleware, + dailyRateLimit *middleware.DailyRateLimitMiddleware, requestLogger *middleware.RequestLoggerMiddleware, traceIDMiddleware *middleware.TraceIDMiddleware, errorTrackingMiddleware *middleware.ErrorTrackingMiddleware, @@ -714,6 +738,7 @@ func RegisterMiddlewares( router.RegisterMiddleware(responseTime) router.RegisterMiddleware(cors) router.RegisterMiddleware(rateLimit) + router.RegisterMiddleware(dailyRateLimit) router.RegisterMiddleware(requestLogger) router.RegisterMiddleware(traceIDMiddleware) router.RegisterMiddleware(errorTrackingMiddleware) diff --git a/internal/infrastructure/http/routes/certification_routes.go b/internal/infrastructure/http/routes/certification_routes.go index fe1d712..3219b45 100644 --- a/internal/infrastructure/http/routes/certification_routes.go +++ b/internal/infrastructure/http/routes/certification_routes.go @@ -10,11 +10,12 @@ import ( // CertificationRoutes 认证路由 type CertificationRoutes struct { - handler *handlers.CertificationHandler - router *http.GinRouter - logger *zap.Logger - auth *middleware.JWTAuthMiddleware - optional *middleware.OptionalAuthMiddleware + handler *handlers.CertificationHandler + router *http.GinRouter + logger *zap.Logger + auth *middleware.JWTAuthMiddleware + optional *middleware.OptionalAuthMiddleware + dailyRateLimit *middleware.DailyRateLimitMiddleware } // NewCertificationRoutes 创建认证路由 @@ -24,13 +25,15 @@ func NewCertificationRoutes( logger *zap.Logger, auth *middleware.JWTAuthMiddleware, optional *middleware.OptionalAuthMiddleware, + dailyRateLimit *middleware.DailyRateLimitMiddleware, ) *CertificationRoutes { return &CertificationRoutes{ - handler: handler, - router: router, - logger: logger, - auth: auth, - optional: optional, + handler: handler, + router: router, + logger: logger, + auth: auth, + optional: optional, + dailyRateLimit: dailyRateLimit, } } @@ -48,8 +51,8 @@ func (r *CertificationRoutes) Register(router *http.GinRouter) { // 1. 获取认证详情 authGroup.GET("/details", r.handler.GetCertification) - // 2. 提交企业信息 - authGroup.POST("/enterprise-info", r.handler.SubmitEnterpriseInfo) + // 2. 提交企业信息(应用每日限流) + authGroup.POST("/enterprise-info", r.dailyRateLimit.Handle(), r.handler.SubmitEnterpriseInfo) // 3. 申请合同签署 authGroup.POST("/apply-contract", r.handler.ApplyContract) diff --git a/internal/shared/middleware/daily_rate_limit.go b/internal/shared/middleware/daily_rate_limit.go new file mode 100644 index 0000000..6fe3204 --- /dev/null +++ b/internal/shared/middleware/daily_rate_limit.go @@ -0,0 +1,478 @@ +package middleware + +import ( + "context" + "fmt" + "strconv" + "strings" + "time" + + "tyapi-server/internal/config" + "tyapi-server/internal/shared/interfaces" + + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +// DailyRateLimitConfig 每日限流配置 +type DailyRateLimitConfig struct { + MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数 + MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数 + KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀 + TTL time.Duration `mapstructure:"ttl"` // 键过期时间 + // 新增安全配置 + EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单 + IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单 + EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单 + IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单 + EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent + BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent + EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer + AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer + EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止 + BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区 + EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理 + MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数 +} + +// DailyRateLimitMiddleware 每日请求限制中间件 +type DailyRateLimitMiddleware struct { + config *config.Config + redis *redis.Client + response interfaces.ResponseBuilder + logger *zap.Logger + limitConfig DailyRateLimitConfig +} + +// NewDailyRateLimitMiddleware 创建每日请求限制中间件 +func NewDailyRateLimitMiddleware( + cfg *config.Config, + redis *redis.Client, + response interfaces.ResponseBuilder, + logger *zap.Logger, + limitConfig DailyRateLimitConfig, +) *DailyRateLimitMiddleware { + // 设置默认值 + if limitConfig.MaxRequestsPerDay <= 0 { + limitConfig.MaxRequestsPerDay = 200 // 默认每日200次 + } + if limitConfig.MaxRequestsPerIP <= 0 { + limitConfig.MaxRequestsPerIP = 10 // 默认每个IP每日10次 + } + if limitConfig.KeyPrefix == "" { + limitConfig.KeyPrefix = "daily_limit" + } + if limitConfig.TTL == 0 { + limitConfig.TTL = 24 * time.Hour // 默认24小时过期 + } + if limitConfig.MaxConcurrent <= 0 { + limitConfig.MaxConcurrent = 5 // 默认最大并发5个 + } + + return &DailyRateLimitMiddleware{ + config: cfg, + redis: redis, + response: response, + logger: logger, + limitConfig: limitConfig, + } +} + +// GetName 返回中间件名称 +func (m *DailyRateLimitMiddleware) GetName() string { + return "daily_rate_limit" +} + +// GetPriority 返回中间件优先级 +func (m *DailyRateLimitMiddleware) GetPriority() int { + return 85 // 在认证之后,业务处理之前 +} + +// Handle 返回中间件处理函数 +func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc { + return func(c *gin.Context) { + ctx := c.Request.Context() + + // 获取客户端标识 + clientIP := m.getClientIP(c) + + // 1. 检查IP白名单/黑名单 + if err := m.checkIPAccess(clientIP); err != nil { + m.logger.Warn("IP访问被拒绝", + zap.String("ip", clientIP), + zap.String("request_id", c.GetString("request_id")), + zap.Error(err)) + m.response.Forbidden(c, "访问被拒绝") + c.Abort() + return + } + + // 2. 检查User-Agent + if err := m.checkUserAgent(c); err != nil { + m.logger.Warn("User-Agent被阻止", + zap.String("ip", clientIP), + zap.String("user_agent", c.GetHeader("User-Agent")), + zap.String("request_id", c.GetString("request_id")), + zap.Error(err)) + m.response.Forbidden(c, "访问被拒绝") + c.Abort() + return + } + + // 3. 检查Referer + if err := m.checkReferer(c); err != nil { + m.logger.Warn("Referer检查失败", + zap.String("ip", clientIP), + zap.String("referer", c.GetHeader("Referer")), + zap.String("request_id", c.GetString("request_id")), + zap.Error(err)) + m.response.Forbidden(c, "访问被拒绝") + c.Abort() + return + } + + // 4. 检查并发限制 + if err := m.checkConcurrentLimit(ctx, clientIP); err != nil { + m.logger.Warn("并发请求超限", + zap.String("ip", clientIP), + zap.String("request_id", c.GetString("request_id")), + zap.Error(err)) + m.response.TooManyRequests(c, "系统繁忙,请稍后再试") + c.Abort() + return + } + + // 5. 检查接口总请求次数限制 + if err := m.checkTotalLimit(ctx); err != nil { + m.logger.Warn("接口总请求次数超限", + zap.String("ip", clientIP), + zap.String("request_id", c.GetString("request_id")), + zap.Error(err)) + // 隐藏限制信息,返回通用错误 + m.response.InternalError(c, "系统繁忙,请稍后再试") + c.Abort() + return + } + + // 6. 检查IP限制 + if err := m.checkIPLimit(ctx, clientIP); err != nil { + m.logger.Warn("IP请求次数超限", + zap.String("ip", clientIP), + zap.String("request_id", c.GetString("request_id")), + zap.Error(err)) + // 隐藏限制信息,返回通用错误 + m.response.InternalError(c, "系统繁忙,请稍后再试") + c.Abort() + return + } + + // 7. 增加计数 + m.incrementCounters(ctx, clientIP) + + // 8. 添加隐藏的响应头(仅用于内部监控) + m.addHiddenHeaders(c, clientIP) + + c.Next() + } +} + +// IsGlobal 是否为全局中间件 +func (m *DailyRateLimitMiddleware) IsGlobal() bool { + return false // 不是全局中间件,需要手动应用到特定路由 +} + +// checkIPAccess 检查IP访问权限 +func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error { + // 检查黑名单 + if m.limitConfig.EnableIPBlacklist { + for _, blockedIP := range m.limitConfig.IPBlacklist { + if m.isIPMatch(clientIP, blockedIP) { + return fmt.Errorf("IP %s 在黑名单中", clientIP) + } + } + } + + // 检查白名单(如果启用) + if m.limitConfig.EnableIPWhitelist { + allowed := false + for _, allowedIP := range m.limitConfig.IPWhitelist { + if m.isIPMatch(clientIP, allowedIP) { + allowed = true + break + } + } + if !allowed { + return fmt.Errorf("IP %s 不在白名单中", clientIP) + } + } + + return nil +} + +// isIPMatch 检查IP是否匹配(支持CIDR和通配符) +func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool { + // 简单的通配符匹配 + if strings.Contains(pattern, "*") { + parts := strings.Split(pattern, ".") + clientParts := strings.Split(clientIP, ".") + if len(parts) != len(clientParts) { + return false + } + for i, part := range parts { + if part != "*" && part != clientParts[i] { + return false + } + } + return true + } + + // 精确匹配 + return clientIP == pattern +} + +// checkUserAgent 检查User-Agent +func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error { + if !m.limitConfig.EnableUserAgent { + return nil + } + + userAgent := c.GetHeader("User-Agent") + if userAgent == "" { + return fmt.Errorf("缺少User-Agent") + } + + // 检查被阻止的User-Agent + for _, blocked := range m.limitConfig.BlockedUserAgents { + if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) { + return fmt.Errorf("User-Agent被阻止: %s", blocked) + } + } + + return nil +} + +// checkReferer 检查Referer +func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error { + if !m.limitConfig.EnableReferer { + return nil + } + + referer := c.GetHeader("Referer") + if referer == "" { + return fmt.Errorf("缺少Referer") + } + + // 检查允许的Referer + if len(m.limitConfig.AllowedReferers) > 0 { + allowed := false + for _, allowedRef := range m.limitConfig.AllowedReferers { + if strings.Contains(referer, allowedRef) { + allowed = true + break + } + } + if !allowed { + return fmt.Errorf("Referer不被允许: %s", referer) + } + } + + return nil +} + +// checkConcurrentLimit 检查并发限制 +func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error { + key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP) + + // 获取当前并发数 + current, err := m.redis.Get(ctx, key).Result() + if err != nil && err != redis.Nil { + return fmt.Errorf("获取并发计数失败: %w", err) + } + + currentCount := 0 + if current != "" { + if count, err := strconv.Atoi(current); err == nil { + currentCount = count + } + } + + if currentCount >= m.limitConfig.MaxConcurrent { + return fmt.Errorf("并发请求超限: %d", currentCount) + } + + // 增加并发计数 + pipe := m.redis.Pipeline() + pipe.Incr(ctx, key) + pipe.Expire(ctx, key, 30*time.Second) // 30秒过期 + + _, err = pipe.Exec(ctx) + if err != nil { + m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err)) + } + + return nil +} + +// getClientIP 获取客户端IP地址(增强版) +func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string { + // 检查是否为代理IP + if m.limitConfig.EnableProxyCheck { + // 检查常见的代理头部 + proxyHeaders := []string{ + "CF-Connecting-IP", // Cloudflare + "X-Forwarded-For", // 标准代理头 + "X-Real-IP", // Nginx + "X-Client-IP", // Apache + "X-Forwarded", // 其他代理 + "Forwarded-For", // RFC 7239 + "Forwarded", // RFC 7239 + } + + for _, header := range proxyHeaders { + if ip := c.GetHeader(header); ip != "" { + // 如果X-Forwarded-For包含多个IP,取第一个 + if header == "X-Forwarded-For" && strings.Contains(ip, ",") { + ip = strings.TrimSpace(strings.Split(ip, ",")[0]) + } + return ip + } + } + } + + // 回退到标准方法 + if xff := c.GetHeader("X-Forwarded-For"); xff != "" { + if strings.Contains(xff, ",") { + return strings.TrimSpace(strings.Split(xff, ",")[0]) + } + return xff + } + + if xri := c.GetHeader("X-Real-IP"); xri != "" { + return xri + } + + return c.ClientIP() +} + +// checkTotalLimit 检查接口总请求次数限制 +func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error { + key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) + + count, err := m.getCounter(ctx, key) + if err != nil { + return fmt.Errorf("获取总请求计数失败: %w", err) + } + + if count >= m.limitConfig.MaxRequestsPerDay { + return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay) + } + + return nil +} + +// checkIPLimit 检查IP限制 +func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error { + key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) + + count, err := m.getCounter(ctx, key) + if err != nil { + return fmt.Errorf("获取IP计数失败: %w", err) + } + + if count >= m.limitConfig.MaxRequestsPerIP { + return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP) + } + + return nil +} + +// incrementCounters 增加计数器 +func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) { + // 增加总请求计数 + totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) + m.incrementCounter(ctx, totalKey) + + // 增加IP计数 + ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) + m.incrementCounter(ctx, ipKey) +} + +// getCounter 获取计数器值 +func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) { + val, err := m.redis.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + return 0, nil // 键不存在,计数为0 + } + return 0, err + } + + count, err := strconv.Atoi(val) + if err != nil { + return 0, fmt.Errorf("解析计数失败: %w", err) + } + + return count, nil +} + +// incrementCounter 增加计数器 +func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) { + // 使用Redis的INCR命令增加计数 + pipe := m.redis.Pipeline() + pipe.Incr(ctx, key) + pipe.Expire(ctx, key, m.limitConfig.TTL) + + _, err := pipe.Exec(ctx) + if err != nil { + m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err)) + } +} + +// getDateKey 获取日期键(格式:2024-01-01) +func (m *DailyRateLimitMiddleware) getDateKey() string { + return time.Now().Format("2006-01-02") +} + +// addHiddenHeaders 添加隐藏的响应头(仅用于内部监控) +func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) { + ctx := c.Request.Context() + + // 添加隐藏的监控头(客户端看不到) + totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey()) + totalCount, _ := m.getCounter(ctx, totalKey) + + ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey()) + ipCount, _ := m.getCounter(ctx, ipKey) + + // 使用非标准的头部名称,避免被客户端识别 + c.Header("X-System-Status", "normal") + c.Header("X-Total-Count", strconv.Itoa(totalCount)) + c.Header("X-IP-Count", strconv.Itoa(ipCount)) + c.Header("X-Reset-Time", m.getResetTime().Format(time.RFC3339)) +} + +// getResetTime 获取重置时间(明天0点) +func (m *DailyRateLimitMiddleware) getResetTime() time.Time { + now := time.Now() + tomorrow := now.Add(24 * time.Hour) + return time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, tomorrow.Location()) +} + +// GetStats 获取限流统计 +func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} { + return map[string]interface{}{ + "max_requests_per_day": m.limitConfig.MaxRequestsPerDay, + "max_requests_per_ip": m.limitConfig.MaxRequestsPerIP, + "max_concurrent": m.limitConfig.MaxConcurrent, + "key_prefix": m.limitConfig.KeyPrefix, + "ttl": m.limitConfig.TTL.String(), + "security_features": map[string]interface{}{ + "ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist, + "ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist, + "user_agent_check": m.limitConfig.EnableUserAgent, + "referer_check": m.limitConfig.EnableReferer, + "proxy_check": m.limitConfig.EnableProxyCheck, + }, + } +}