Files
tyapi-server/internal/shared/middleware/daily_rate_limit.go
2025-08-10 15:19:10 +08:00

479 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	"tyapi-server/internal/config"
	"tyapi-server/internal/shared/interfaces"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)
// DailyRateLimitConfig configures DailyRateLimitMiddleware: daily request
// quotas, a per-client-IP concurrency cap, and optional request filters.
// Fields are populated through their mapstructure tags. Non-positive
// MaxRequestsPerDay/MaxRequestsPerIP/MaxConcurrent, an empty KeyPrefix and a
// zero TTL are replaced with defaults by NewDailyRateLimitMiddleware.
type DailyRateLimitConfig struct {
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // Max total requests per (local) calendar day across all clients (default 200)
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // Max requests per (local) calendar day from one client IP (default 10)
KeyPrefix string `mapstructure:"key_prefix"` // Prefix for every Redis key this middleware writes (default "daily_limit")
TTL time.Duration `mapstructure:"ttl"` // Expiry re-applied to the daily counter keys on each increment (default 24h)
// Security / request-filtering settings
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // If true, IPs that match no IPWhitelist entry are rejected
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP patterns admitted when the whitelist is enabled (matched by isIPMatch)
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // If true, IPs matching IPBlacklist are rejected (evaluated before the whitelist)
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP patterns rejected when the blacklist is enabled
EnableUserAgent bool `mapstructure:"enable_user_agent"` // If true, requests with a missing or blocked User-Agent are rejected
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // Case-insensitive substrings; a User-Agent containing any of them is rejected
EnableReferer bool `mapstructure:"enable_referer"` // If true, a Referer header is required
AllowedReferers []string `mapstructure:"allowed_referers"` // If non-empty, the Referer must contain one of these substrings
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // Geo-blocking switch. NOTE(review): not read anywhere in this file, appears unimplemented; confirm
BlockedCountries []string `mapstructure:"blocked_countries"` // Countries/regions to block. NOTE(review): not read anywhere in this file; confirm
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // If true, getClientIP consults proxy headers (CF-Connecting-IP, X-Forwarded-For, ...) first
MaxConcurrent int `mapstructure:"max_concurrent"` // Per-client-IP concurrency cap enforced by checkConcurrentLimit (default 5)
}
// DailyRateLimitMiddleware is a gin middleware that enforces per-day request
// quotas (endpoint-wide and per client IP) and a per-IP concurrency cap using
// Redis counters, with optional IP / User-Agent / Referer filtering.
type DailyRateLimitMiddleware struct {
config *config.Config // Application config. NOTE(review): not read anywhere in this file; confirm it is still needed
redis *redis.Client // Redis client that stores all counters
response interfaces.ResponseBuilder // Builds the error responses sent for rejected requests
logger *zap.Logger // Logs rejections and Redis failures
limitConfig DailyRateLimitConfig // Effective limits; defaults are filled in by NewDailyRateLimitMiddleware
}
// NewDailyRateLimitMiddleware builds a DailyRateLimitMiddleware, replacing
// unset or invalid limits with defaults:
//
//   - MaxRequestsPerDay <= 0 -> 200 per day
//   - MaxRequestsPerIP  <= 0 -> 10 per IP per day
//   - KeyPrefix == ""        -> "daily_limit"
//   - TTL <= 0               -> 24h
//   - MaxConcurrent <= 0     -> 5
//
// A negative TTL is treated as unset. Redis deletes a key when EXPIRE is
// given a non-positive timeout, so a negative TTL would make every daily
// counter disappear right after it is incremented and silently disable the
// limits.
func NewDailyRateLimitMiddleware(
	cfg *config.Config,
	redis *redis.Client,
	response interfaces.ResponseBuilder,
	logger *zap.Logger,
	limitConfig DailyRateLimitConfig,
) *DailyRateLimitMiddleware {
	if limitConfig.MaxRequestsPerDay <= 0 {
		limitConfig.MaxRequestsPerDay = 200
	}
	if limitConfig.MaxRequestsPerIP <= 0 {
		limitConfig.MaxRequestsPerIP = 10
	}
	if limitConfig.KeyPrefix == "" {
		limitConfig.KeyPrefix = "daily_limit"
	}
	if limitConfig.TTL <= 0 {
		limitConfig.TTL = 24 * time.Hour
	}
	if limitConfig.MaxConcurrent <= 0 {
		limitConfig.MaxConcurrent = 5
	}

	return &DailyRateLimitMiddleware{
		config:      cfg,
		redis:       redis,
		response:    response,
		logger:      logger,
		limitConfig: limitConfig,
	}
}
// GetName reports the identifier this middleware is registered under.
func (m *DailyRateLimitMiddleware) GetName() string {
	const middlewareName = "daily_rate_limit"
	return middlewareName
}
// GetPriority reports this middleware's priority value (85).
// The author's intent is for it to run after authentication and before
// business handlers. How priority values are ordered is decided by the
// middleware registry, which is not shown here.
func (m *DailyRateLimitMiddleware) GetPriority() int {
	const priority = 85
	return priority
}
// Handle returns the gin handler that enforces this middleware's checks.
//
// Checks run in this order, and the first failure aborts the request:
//
//  1. IP blacklist / whitelist          -> response.Forbidden
//  2. User-Agent filter                 -> response.Forbidden
//  3. Referer filter                    -> response.Forbidden
//  4. per-IP concurrency slot           -> response.TooManyRequests
//  5. endpoint-wide daily request total -> response.InternalError (limit details hidden)
//  6. per-IP daily request total        -> response.InternalError (limit details hidden)
//
// Requests that pass are counted (incrementCounters), receive monitoring
// headers, and continue down the chain. The concurrency slot reserved in
// step 4 is released when this handler returns. That happens either after
// the downstream handlers finish, or immediately if step 5 or 6 rejects the
// request, so the counter tracks in-flight requests instead of only ever
// growing.
func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		ctx := c.Request.Context()
		clientIP := m.getClientIP(c)
		requestID := c.GetString("request_id")

		// 1. IP whitelist / blacklist
		if err := m.checkIPAccess(clientIP); err != nil {
			m.logger.Warn("IP访问被拒绝",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.Forbidden(c, "访问被拒绝")
			c.Abort()
			return
		}

		// 2. User-Agent
		if err := m.checkUserAgent(c); err != nil {
			m.logger.Warn("User-Agent被阻止",
				zap.String("ip", clientIP),
				zap.String("user_agent", c.GetHeader("User-Agent")),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.Forbidden(c, "访问被拒绝")
			c.Abort()
			return
		}

		// 3. Referer
		if err := m.checkReferer(c); err != nil {
			m.logger.Warn("Referer检查失败",
				zap.String("ip", clientIP),
				zap.String("referer", c.GetHeader("Referer")),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.Forbidden(c, "访问被拒绝")
			c.Abort()
			return
		}

		// 4. Concurrency: on success a slot has been reserved in Redis.
		if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
			m.logger.Warn("并发请求超限",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.TooManyRequests(c, "系统繁忙,请稍后再试")
			c.Abort()
			return
		}
		// Give the slot back on every exit path from here on, including
		// after c.Next() returns.
		defer m.releaseConcurrentSlot(clientIP)

		// 5. Endpoint-wide daily total
		if err := m.checkTotalLimit(ctx); err != nil {
			m.logger.Warn("接口总请求次数超限",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			// Hide limit details behind a generic error.
			m.response.InternalError(c, "系统繁忙,请稍后再试")
			c.Abort()
			return
		}

		// 6. Per-IP daily total
		if err := m.checkIPLimit(ctx, clientIP); err != nil {
			m.logger.Warn("IP请求次数超限",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			// Hide limit details behind a generic error.
			m.response.InternalError(c, "系统繁忙,请稍后再试")
			c.Abort()
			return
		}

		// 7. Count this request.
		m.incrementCounters(ctx, clientIP)

		// 8. Monitoring headers.
		m.addHiddenHeaders(c, clientIP)

		c.Next()
	}
}

// releaseConcurrentSlot decrements the per-IP concurrency counter that
// checkConcurrentLimit incremented for clientIP.
//
// It uses its own short-lived context rather than the request context,
// because the request may already be cancelled (e.g. client disconnect) and
// the slot must still be returned. If the counter drops below zero, the key
// is deleted. This covers an increment that never happened, and a key that
// expired mid-request and was recreated by DECR without a TTL.
func (m *DailyRateLimitMiddleware) releaseConcurrentSlot(clientIP string) {
	key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	remaining, err := m.redis.Decr(ctx, key).Result()
	if err != nil {
		m.logger.Error("释放并发计数失败", zap.String("key", key), zap.Error(err))
		return
	}
	if remaining < 0 {
		if err := m.redis.Del(ctx, key).Err(); err != nil {
			m.logger.Error("重置并发计数失败", zap.String("key", key), zap.Error(err))
		}
	}
}
// IsGlobal reports whether the middleware applies to every route. It
// always returns false: this limiter must be attached to specific routes
// explicitly.
func (m *DailyRateLimitMiddleware) IsGlobal() bool {
	const global = false
	return global
}
// checkIPAccess applies the optional IP blacklist and whitelist to clientIP.
//
// The blacklist is evaluated first, so an address listed on both is
// rejected. When the whitelist is enabled, an address matching none of its
// patterns is rejected. Patterns are matched with isIPMatch.
func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
	cfg := &m.limitConfig

	matchesAny := func(patterns []string) bool {
		for _, p := range patterns {
			if m.isIPMatch(clientIP, p) {
				return true
			}
		}
		return false
	}

	if cfg.EnableIPBlacklist && matchesAny(cfg.IPBlacklist) {
		return fmt.Errorf("IP %s 在黑名单中", clientIP)
	}
	if cfg.EnableIPWhitelist && !matchesAny(cfg.IPWhitelist) {
		return fmt.Errorf("IP %s 不在白名单中", clientIP)
	}
	return nil
}
// isIPMatch reports whether clientIP matches pattern.
//
// Supported pattern forms:
//
//   - CIDR notation, e.g. "10.0.0.0/8" or "2001:db8::/32"
//   - dotted wildcards, e.g. "192.168.*.*" (each "*" matches one whole octet)
//   - a single address, e.g. "203.0.113.7"; equivalent spellings of the same
//     address (e.g. "::1" and "0:0:0:0:0:0:0:1") also match
//
// An unparsable CIDR pattern or an unparsable clientIP never matches a CIDR.
func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
	if strings.Contains(pattern, "/") {
		_, network, err := net.ParseCIDR(pattern)
		if err != nil {
			return false
		}
		ip := net.ParseIP(clientIP)
		return ip != nil && network.Contains(ip)
	}

	if strings.Contains(pattern, "*") {
		patternParts := strings.Split(pattern, ".")
		clientParts := strings.Split(clientIP, ".")
		if len(patternParts) != len(clientParts) {
			return false
		}
		for i, part := range patternParts {
			if part != "*" && part != clientParts[i] {
				return false
			}
		}
		return true
	}

	if clientIP == pattern {
		return true
	}
	// Compare parsed addresses so different textual forms of the same IP match.
	a, b := net.ParseIP(clientIP), net.ParseIP(pattern)
	return a != nil && b != nil && a.Equal(b)
}
// checkUserAgent enforces the optional User-Agent filter.
//
// When EnableUserAgent is on, a request is rejected if its User-Agent header
// is missing, or if it contains (case-insensitively) any entry of
// BlockedUserAgents. Blank entries are skipped: strings.Contains(x, "") is
// always true, so a single empty entry (e.g. from a stray list item in the
// config) would otherwise block every client.
func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
	if !m.limitConfig.EnableUserAgent {
		return nil
	}

	userAgent := c.GetHeader("User-Agent")
	if userAgent == "" {
		return fmt.Errorf("缺少User-Agent")
	}

	// Lower-case once instead of on every loop iteration.
	lowerUA := strings.ToLower(userAgent)
	for _, blocked := range m.limitConfig.BlockedUserAgents {
		if strings.TrimSpace(blocked) == "" {
			continue
		}
		if strings.Contains(lowerUA, strings.ToLower(blocked)) {
			return fmt.Errorf("User-Agent被阻止: %s", blocked)
		}
	}
	return nil
}
// checkReferer enforces the optional Referer filter.
//
// When EnableReferer is on, a missing Referer is always rejected. If
// AllowedReferers is non-empty, the Referer must also contain at least one
// of its entries as a substring.
//
// NOTE(review): the Referer header is entirely client-controlled, and
// substring matching is easy to satisfy deliberately (e.g.
// "https://evil.example/?x=allowed.example" contains "allowed.example").
// Treat this as a soft filter, not as access control.
func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
	if !m.limitConfig.EnableReferer {
		return nil
	}

	referer := c.GetHeader("Referer")
	if len(referer) == 0 {
		return fmt.Errorf("缺少Referer")
	}

	allowList := m.limitConfig.AllowedReferers
	if len(allowList) == 0 {
		return nil
	}
	for _, candidate := range allowList {
		if strings.Contains(referer, candidate) {
			return nil
		}
	}
	return fmt.Errorf("Referer不被允许: %s", referer)
}
// checkConcurrentLimit tries to reserve one concurrency slot for clientIP.
//
// It increments the per-IP counter first and then compares the result with
// MaxConcurrent. Redis INCR is atomic, so every concurrent caller sees a
// distinct post-increment value. This closes the GET-then-INCR race in
// which a burst of simultaneous requests all read the same count and all
// pass. If the new count exceeds the cap, the increment is rolled back and
// an error is returned. A Redis failure on the increment is also returned
// as an error.
//
// On success, the key's 30-second expiry is refreshed as a safety net for
// slots that are never released. Rejected requests do not refresh it, so a
// client that keeps getting rejected cannot keep its own counter alive.
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
	key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)

	count, err := m.redis.Incr(ctx, key).Result()
	if err != nil {
		return fmt.Errorf("增加并发计数失败: %w", err)
	}

	if count > int64(m.limitConfig.MaxConcurrent) {
		// Over the cap: hand back the slot just taken, then reject.
		if derr := m.redis.Decr(ctx, key).Err(); derr != nil {
			m.logger.Error("回滚并发计数失败", zap.String("key", key), zap.Error(derr))
		}
		// Report the count as it was before this request, as before.
		return fmt.Errorf("并发请求超限: %d", count-1)
	}

	if err := m.redis.Expire(ctx, key, 30*time.Second).Err(); err != nil {
		m.logger.Error("设置并发计数过期时间失败", zap.String("key", key), zap.Error(err))
	}
	return nil
}
// getClientIP resolves the client address used as the rate-limit identity.
//
// When EnableProxyCheck is on, the proxy headers below are consulted in
// order, and the first one that yields a syntactically valid IP address
// wins. Otherwise, and as the fallback, gin's Context.ClientIP is used. It
// honours the engine's trusted-proxy and RemoteIPHeaders settings, so
// X-Forwarded-For / X-Real-IP are only believed when they come from a
// trusted proxy.
//
// Security: every header consulted in the EnableProxyCheck branch is
// client-controlled. Enable it only when the service is reachable solely
// through proxies that overwrite these headers. Otherwise a client can spoof
// its IP to evade per-IP limits or the blacklist.
func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
	if m.limitConfig.EnableProxyCheck {
		proxyHeaders := []string{
			"CF-Connecting-IP", // Cloudflare
			"X-Forwarded-For",  // de-facto standard; originating client listed first
			"X-Real-IP",        // commonly set by nginx
			"X-Client-IP",      // commonly set by Apache and others
			"X-Forwarded",      // non-standard variant
			"Forwarded-For",    // non-standard variant
			"Forwarded",        // RFC 7239 (for=...;proto=...)
		}
		for _, header := range proxyHeaders {
			if ip := m.ipFromProxyHeader(header, c.GetHeader(header)); ip != "" {
				return ip
			}
		}
	}
	return c.ClientIP()
}

// ipFromProxyHeader extracts the originating client IP from one proxy header
// value. It returns "" when the value holds no valid IP address.
func (m *DailyRateLimitMiddleware) ipFromProxyHeader(header, value string) string {
	// Multi-hop headers list the originating client first.
	candidate := value
	if i := strings.IndexByte(candidate, ','); i >= 0 {
		candidate = candidate[:i]
	}
	candidate = strings.TrimSpace(candidate)

	// RFC 7239 "Forwarded" (and some "X-Forwarded" implementations) carry
	// key=value pairs, e.g. `for=192.0.2.60;proto=http`. Pick the "for" value.
	if header == "Forwarded" || header == "X-Forwarded" {
		for _, pair := range strings.Split(candidate, ";") {
			pair = strings.TrimSpace(pair)
			if eq := strings.IndexByte(pair, '='); eq >= 0 && strings.EqualFold(pair[:eq], "for") {
				candidate = strings.Trim(pair[eq+1:], `"`)
				break
			}
		}
	}

	// Drop an optional port and IPv6 brackets, e.g. "[2001:db8::1]:4711".
	if host, _, err := net.SplitHostPort(candidate); err == nil {
		candidate = host
	} else {
		candidate = strings.TrimSuffix(strings.TrimPrefix(candidate, "["), "]")
	}

	if net.ParseIP(candidate) == nil {
		return ""
	}
	return candidate
}
// checkTotalLimit returns an error once today's endpoint-wide request count
// has reached MaxRequestsPerDay, or when the counter cannot be read.
// It only reads the counter; incrementing happens separately.
//
// NOTE(review): because the read and the later increment are separate Redis
// calls, concurrent requests can overshoot the limit slightly.
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
	limit := m.limitConfig.MaxRequestsPerDay
	totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())

	used, err := m.getCounter(ctx, totalKey)
	switch {
	case err != nil:
		return fmt.Errorf("获取总请求计数失败: %w", err)
	case used >= limit:
		return fmt.Errorf("接口今日总请求次数已达上限 %d", limit)
	default:
		return nil
	}
}
// checkIPLimit returns an error once clientIP has made MaxRequestsPerIP
// requests today, or when the per-IP counter cannot be read. It only reads
// the counter; incrementing happens separately.
//
// NOTE(review): the read and the later increment are separate Redis calls,
// so concurrent requests from one IP can overshoot the limit slightly.
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
	perIPCap := m.limitConfig.MaxRequestsPerIP
	ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())

	used, err := m.getCounter(ctx, ipKey)
	if err != nil {
		return fmt.Errorf("获取IP计数失败: %w", err)
	}
	if used < perIPCap {
		return nil
	}
	return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, perIPCap)
}
// incrementCounters bumps today's endpoint-wide counter and the per-IP
// counter for clientIP.
//
// The date key is computed once, so both counters are attributed to the
// same day even when the call straddles midnight. Failures are logged by
// incrementCounter and otherwise ignored (best effort).
func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) {
	dateKey := m.getDateKey()
	prefix := m.limitConfig.KeyPrefix

	m.incrementCounter(ctx, fmt.Sprintf("%s:total:%s", prefix, dateKey))
	m.incrementCounter(ctx, fmt.Sprintf("%s:ip:%s:%s", prefix, clientIP, dateKey))
}
// getCounter reads an integer counter stored at key.
//
// A missing key counts as 0. Any other Redis error is returned unchanged.
// A stored value that is not an integer is returned as a wrapped parse
// error. On error the returned count is always 0.
func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) {
	raw, err := m.redis.Get(ctx, key).Result()
	switch {
	case err == redis.Nil:
		return 0, nil
	case err != nil:
		return 0, err
	}

	n, convErr := strconv.Atoi(raw)
	if convErr != nil {
		return 0, fmt.Errorf("解析计数失败: %w", convErr)
	}
	return n, nil
}
// incrementCounter sends INCR on key plus EXPIRE with the configured TTL in
// a single pipeline round trip, so the expiry is reset on every increment.
// Errors are only logged: counting is best effort and does not fail the
// caller.
func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) {
	ttl := m.limitConfig.TTL
	_, err := m.redis.Pipelined(ctx, func(p redis.Pipeliner) error {
		p.Incr(ctx, key)
		p.Expire(ctx, key, ttl)
		return nil
	})
	if err == nil {
		return
	}
	m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err))
}
// getDateKey returns today's date in the server's local time zone,
// formatted as YYYY-MM-DD (e.g. "2024-01-01"). It is the per-day suffix of
// the daily counter keys.
func (m *DailyRateLimitMiddleware) getDateKey() string {
	const dayLayout = "2006-01-02"
	today := time.Now()
	return today.Format(dayLayout)
}
// addHiddenHeaders sets monitoring response headers with today's counters
// and the next reset time (RFC 3339).
//
// NOTE(review): despite the name, these are ordinary response headers set
// via c.Header, so every client can see them. They expose the current
// endpoint-wide and per-IP counts, which undercuts Handle's decision to hide
// limit details behind a generic error. "X-Total-Count" may also collide
// with a pagination header of the same name. Confirm this is intended.
func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
	ctx := c.Request.Context()
	prefix := m.limitConfig.KeyPrefix

	// Errors are deliberately ignored: the headers are informational only,
	// and getCounter yields 0 on failure, so a Redis hiccup shows a zero count.
	totalCount, _ := m.getCounter(ctx, fmt.Sprintf("%s:total:%s", prefix, m.getDateKey()))
	ipCount, _ := m.getCounter(ctx, fmt.Sprintf("%s:ip:%s:%s", prefix, clientIP, m.getDateKey()))

	for _, h := range []struct{ name, value string }{
		{"X-System-Status", "normal"},
		{"X-Total-Count", strconv.Itoa(totalCount)},
		{"X-IP-Count", strconv.Itoa(ipCount)},
		{"X-Reset-Time", m.getResetTime().Format(time.RFC3339)},
	} {
		c.Header(h.name, h.value)
	}
}
// getResetTime returns when the daily counters roll over: 00:00 of the next
// calendar day in the server's local time zone.
func (m *DailyRateLimitMiddleware) getResetTime() time.Time {
	return m.nextMidnight(time.Now())
}

// nextMidnight returns 00:00 of the calendar day after t, in t's location.
//
// It advances by one calendar day through time.Date, which normalises day
// overflow (Jan 31 + 1 -> Feb 1, Dec 31 + 1 -> Jan 1), instead of adding
// 24h. On DST transition days a day lasts 23 or 25 hours, so t+24h can land
// on the same date (giving a "reset" in the past) or skip a date entirely.
func (m *DailyRateLimitMiddleware) nextMidnight(t time.Time) time.Time {
	return time.Date(t.Year(), t.Month(), t.Day()+1, 0, 0, 0, 0, t.Location())
}
// GetStats returns a snapshot of the limiter's configuration as stored on the
// middleware (defaults are filled in by NewDailyRateLimitMiddleware), for
// diagnostics. The "security_features" entry is a nested map of the filter
// switches.
func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} {
	cfg := &m.limitConfig

	security := map[string]interface{}{
		"ip_whitelist_enabled": cfg.EnableIPWhitelist,
		"ip_blacklist_enabled": cfg.EnableIPBlacklist,
		"user_agent_check":     cfg.EnableUserAgent,
		"referer_check":        cfg.EnableReferer,
		"proxy_check":          cfg.EnableProxyCheck,
	}

	stats := make(map[string]interface{}, 6)
	stats["max_requests_per_day"] = cfg.MaxRequestsPerDay
	stats["max_requests_per_ip"] = cfg.MaxRequestsPerIP
	stats["max_concurrent"] = cfg.MaxConcurrent
	stats["key_prefix"] = cfg.KeyPrefix
	stats["ttl"] = cfg.TTL.String()
	stats["security_features"] = security
	return stats
}