554 lines
16 KiB
Go
554 lines
16 KiB
Go
package middleware
|
||
|
||
import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"time"

	"tyapi-server/internal/config"
	"tyapi-server/internal/shared/interfaces"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
	"go.uber.org/zap"
)
|
||
|
||
// DailyRateLimitConfig 每日限流配置
|
||
type DailyRateLimitConfig struct {
|
||
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
|
||
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
|
||
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
|
||
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
|
||
// 新增安全配置
|
||
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
|
||
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
|
||
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
|
||
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
|
||
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
|
||
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
|
||
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
|
||
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
|
||
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
|
||
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
|
||
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
|
||
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
|
||
// 路径排除配置
|
||
ExcludePaths []string `mapstructure:"exclude_paths"` // 排除频率限制的路径
|
||
// 域名排除配置
|
||
ExcludeDomains []string `mapstructure:"exclude_domains"` // 排除频率限制的域名
|
||
}
|
||
|
||
// DailyRateLimitMiddleware 每日请求限制中间件
|
||
type DailyRateLimitMiddleware struct {
|
||
config *config.Config
|
||
redis *redis.Client
|
||
response interfaces.ResponseBuilder
|
||
logger *zap.Logger
|
||
limitConfig DailyRateLimitConfig
|
||
}
|
||
|
||
// NewDailyRateLimitMiddleware builds a DailyRateLimitMiddleware. Any limit left
// unset in limitConfig is given a default:
//
//   - MaxRequestsPerDay: 200 (when <= 0)
//   - MaxRequestsPerIP:  10  (when <= 0)
//   - KeyPrefix:         "daily_limit" (when empty)
//   - TTL:               24h (when <= 0)
//   - MaxConcurrent:     5   (when <= 0)
//
// limitConfig is received by value, so applying defaults never mutates the
// caller's struct (its slice fields are still shared, not copied).
func NewDailyRateLimitMiddleware(
	cfg *config.Config,
	redis *redis.Client,
	response interfaces.ResponseBuilder,
	logger *zap.Logger,
	limitConfig DailyRateLimitConfig,
) *DailyRateLimitMiddleware {
	if limitConfig.MaxRequestsPerDay <= 0 {
		limitConfig.MaxRequestsPerDay = 200 // default: 200 requests per day
	}
	if limitConfig.MaxRequestsPerIP <= 0 {
		limitConfig.MaxRequestsPerIP = 10 // default: 10 requests per IP per day
	}
	if limitConfig.KeyPrefix == "" {
		limitConfig.KeyPrefix = "daily_limit"
	}
	// A negative TTL must not reach Redis: EXPIRE with a non-positive timeout
	// deletes the key, which would reset the daily counters on every increment
	// and effectively disable the quotas. Treat it like an unset value.
	if limitConfig.TTL <= 0 {
		limitConfig.TTL = 24 * time.Hour
	}
	if limitConfig.MaxConcurrent <= 0 {
		limitConfig.MaxConcurrent = 5 // default: 5 concurrent requests per IP
	}

	return &DailyRateLimitMiddleware{
		config:      cfg,
		redis:       redis,
		response:    response,
		logger:      logger,
		limitConfig: limitConfig,
	}
}
|
||
|
||
// GetName 返回中间件名称
|
||
func (m *DailyRateLimitMiddleware) GetName() string {
|
||
return "daily_rate_limit"
|
||
}
|
||
|
||
// GetPriority 返回中间件优先级
|
||
func (m *DailyRateLimitMiddleware) GetPriority() int {
|
||
return 85 // 在认证之后,业务处理之前
|
||
}
|
||
|
||
// Handle returns the gin handler that enforces the daily rate limit.
//
// Requests whose path or Host is excluded skip every check. Otherwise the
// checks below run in order and the first failure aborts the request:
//
//  1. IP blacklist/whitelist   -> Forbidden response
//  2. User-Agent               -> Forbidden response
//  3. Referer                  -> Forbidden response
//  4. per-IP concurrency       -> TooManyRequests response
//  5. global daily quota       -> generic InternalError response
//  6. per-IP daily quota       -> generic InternalError response
//
// Quota failures deliberately use a generic error so the limits are not
// revealed. When every check passes, both daily counters are incremented,
// monitoring headers are set and the rest of the handler chain runs.
func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		ctx := c.Request.Context()

		// Excluded paths and domains bypass the limiter entirely.
		if m.isExcludedPath(c.Request.URL.Path) {
			c.Next()
			return
		}
		host := c.Request.Host
		if m.isExcludedDomain(host) {
			c.Next()
			return
		}

		clientIP := m.getClientIP(c)
		requestID := c.GetString("request_id")

		// 1. IP blacklist / whitelist.
		if err := m.checkIPAccess(clientIP); err != nil {
			m.logger.Warn("IP访问被拒绝",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.Forbidden(c, "访问被拒绝")
			c.Abort()
			return
		}

		// 2. User-Agent.
		if err := m.checkUserAgent(c); err != nil {
			m.logger.Warn("User-Agent被阻止",
				zap.String("ip", clientIP),
				zap.String("user_agent", c.GetHeader("User-Agent")),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.Forbidden(c, "访问被拒绝")
			c.Abort()
			return
		}

		// 3. Referer.
		if err := m.checkReferer(c); err != nil {
			m.logger.Warn("Referer检查失败",
				zap.String("ip", clientIP),
				zap.String("referer", c.GetHeader("Referer")),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.Forbidden(c, "访问被拒绝")
			c.Abort()
			return
		}

		// 4. Per-IP concurrency.
		if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
			m.logger.Warn("并发请求超限",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			m.response.TooManyRequests(c, "系统繁忙,请稍后再试")
			c.Abort()
			return
		}
		// checkConcurrentLimit has counted this request as in flight. Release
		// that slot when this handler returns (after c.Next() completes, or
		// after a rejection below). Otherwise the counter only grows until its
		// TTL lapses, and it caps requests per window, not concurrent requests.
		defer m.releaseConcurrentSlot(clientIP)

		// 5. Global daily quota.
		if err := m.checkTotalLimit(ctx); err != nil {
			m.logger.Warn("接口总请求次数超限",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			// Hide the limit details behind a generic error.
			m.response.InternalError(c, "系统繁忙,请稍后再试")
			c.Abort()
			return
		}

		// 6. Per-IP daily quota.
		if err := m.checkIPLimit(ctx, clientIP); err != nil {
			m.logger.Warn("IP请求次数超限",
				zap.String("ip", clientIP),
				zap.String("request_id", requestID),
				zap.Error(err))
			// Hide the limit details behind a generic error.
			m.response.InternalError(c, "系统繁忙,请稍后再试")
			c.Abort()
			return
		}

		// 7. Count this request against both daily quotas.
		m.incrementCounters(ctx, clientIP)

		// 8. Monitoring headers.
		m.addHiddenHeaders(c, clientIP)

		c.Next()
	}
}

// releaseConcurrentSlot decrements the per-IP in-flight counter at
// "<prefix>:concurrent:<ip>". It runs after the request has finished, when the
// request context may already be cancelled, so it uses its own short timeout.
func (m *DailyRateLimitMiddleware) releaseConcurrentSlot(clientIP string) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
	remaining, err := m.redis.Decr(ctx, key).Result()
	if err != nil {
		m.logger.Error("释放并发计数失败", zap.String("key", key), zap.Error(err))
		return
	}
	// A negative result means the key had already expired (e.g. a request
	// outlived the 30s TTL) and DECR recreated it, without a TTL. Left in
	// place it would hand out an extra slot, so remove it.
	if remaining < 0 {
		if err := m.redis.Del(ctx, key).Err(); err != nil {
			m.logger.Error("清理并发计数失败", zap.String("key", key), zap.Error(err))
		}
	}
}
|
||
|
||
// isExcludedDomain reports whether the request host matches any ExcludeDomains
// pattern. Supported pattern forms:
//
//   - "*suffix": host ends with suffix ("*.example.com" matches "api.example.com")
//   - "prefix*": host starts with prefix ("api.*" matches "api.example.com")
//   - anything else: host must equal the pattern
//
// The Host header may include a port ("example.com:8080"), so each pattern is
// tried against both the raw value and the port-stripped hostname. Hostnames
// are case-insensitive, so both sides are lower-cased before comparison.
func (m *DailyRateLimitMiddleware) isExcludedDomain(host string) bool {
	if len(m.limitConfig.ExcludeDomains) == 0 {
		return false
	}

	candidates := []string{strings.ToLower(host)}
	if hostname, _, err := net.SplitHostPort(host); err == nil {
		candidates = append(candidates, strings.ToLower(hostname))
	}

	for _, rawPattern := range m.limitConfig.ExcludeDomains {
		pattern := strings.ToLower(rawPattern)
		for _, candidate := range candidates {
			switch {
			case strings.HasPrefix(pattern, "*"):
				if strings.HasSuffix(candidate, pattern[1:]) {
					return true
				}
			case strings.HasSuffix(pattern, "*"):
				if strings.HasPrefix(candidate, pattern[:len(pattern)-1]) {
					return true
				}
			default:
				if candidate == pattern {
					return true
				}
			}
		}
	}
	return false
}
|
||
|
||
// isExcludedPath reports whether path matches any ExcludePaths pattern.
// Supported pattern forms:
//
//   - "*fragment": path contains fragment anywhere ("*api_name" matches "/api/v1/any_api_name")
//   - "prefix*": path starts with prefix ("/api/v1/*" matches "/api/v1/any_api_name")
//   - "a*b" (one or more inner '*'): path starts with the first segment, ends
//     with the last one, and contains the middle segments in order between
//     them; each '*' matches any (possibly empty) run of characters
//   - anything else: path must equal the pattern
func (m *DailyRateLimitMiddleware) isExcludedPath(path string) bool {
	// matchesInner handles patterns whose '*' wildcards all sit strictly inside
	// the pattern. The length check keeps the head and tail from overlapping,
	// so "/ab*bc" does not match "/abc". Middle segments are matched greedily
	// left-to-right, which is sufficient when '*' matches any run of characters.
	matchesInner := func(pattern string) bool {
		segments := strings.Split(pattern, "*")
		head, tail := segments[0], segments[len(segments)-1]
		if len(path) < len(head)+len(tail) ||
			!strings.HasPrefix(path, head) ||
			!strings.HasSuffix(path, tail) {
			return false
		}
		middle := path[len(head) : len(path)-len(tail)]
		for _, segment := range segments[1 : len(segments)-1] {
			idx := strings.Index(middle, segment)
			if idx < 0 {
				return false
			}
			middle = middle[idx+len(segment):]
		}
		return true
	}

	for _, pattern := range m.limitConfig.ExcludePaths {
		switch {
		case strings.HasPrefix(pattern, "*"):
			if strings.Contains(path, pattern[1:]) {
				return true
			}
		case strings.HasSuffix(pattern, "*"):
			if strings.HasPrefix(path, pattern[:len(pattern)-1]) {
				return true
			}
		case strings.Contains(pattern, "*"):
			if matchesInner(pattern) {
				return true
			}
		default:
			if path == pattern {
				return true
			}
		}
	}
	return false
}
|
||
|
||
// IsGlobal 是否为全局中间件
|
||
func (m *DailyRateLimitMiddleware) IsGlobal() bool {
|
||
return false // 不是全局中间件,需要手动应用到特定路由
|
||
}
|
||
|
||
// checkIPAccess applies the IP blacklist and whitelist to clientIP.
//
// The blacklist (if enabled) is consulted first, and a match is rejected. The
// whitelist (if enabled) is consulted next, and an IP matching none of its
// entries is rejected. An enabled whitelist with no entries rejects every IP.
// Returns nil when the IP is allowed.
func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
	matchesAny := func(patterns []string) bool {
		for _, p := range patterns {
			if m.isIPMatch(clientIP, p) {
				return true
			}
		}
		return false
	}

	if m.limitConfig.EnableIPBlacklist && matchesAny(m.limitConfig.IPBlacklist) {
		return fmt.Errorf("IP %s 在黑名单中", clientIP)
	}
	if m.limitConfig.EnableIPWhitelist && !matchesAny(m.limitConfig.IPWhitelist) {
		return fmt.Errorf("IP %s 不在白名单中", clientIP)
	}
	return nil
}
|
||
|
||
// isIPMatch reports whether clientIP matches pattern. Supported pattern forms:
//
//   - CIDR notation ("10.0.0.0/8", "2001:db8::/32"): clientIP must parse as an
//     IP inside the network;
//   - dotted wildcards ("192.168.*.*"): same number of '.'-separated parts, and
//     each "*" part matches exactly one part of clientIP;
//   - anything else: exact string equality.
//
// A pattern containing '/' that is not valid CIDR falls through to the
// wildcard/equality rules, as before CIDR support existed.
func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
	if strings.Contains(pattern, "/") {
		if _, network, err := net.ParseCIDR(pattern); err == nil {
			ip := net.ParseIP(clientIP)
			return ip != nil && network.Contains(ip)
		}
	}

	if strings.Contains(pattern, "*") {
		patternParts := strings.Split(pattern, ".")
		clientParts := strings.Split(clientIP, ".")
		if len(patternParts) != len(clientParts) {
			return false
		}
		for i, part := range patternParts {
			if part != "*" && part != clientParts[i] {
				return false
			}
		}
		return true
	}

	return clientIP == pattern
}
|
||
|
||
// checkUserAgent screens the User-Agent header when EnableUserAgent is set.
//
// It returns an error if the header is missing, or if it contains any
// BlockedUserAgents entry as a case-insensitive substring. Empty entries are
// ignored: as substrings they would match every User-Agent and block all
// traffic. Returns nil when the check is disabled or passes.
func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
	if !m.limitConfig.EnableUserAgent {
		return nil
	}

	userAgent := c.GetHeader("User-Agent")
	if userAgent == "" {
		return fmt.Errorf("缺少User-Agent")
	}

	// Lower-case the header once instead of on every loop iteration.
	loweredUA := strings.ToLower(userAgent)
	for _, blocked := range m.limitConfig.BlockedUserAgents {
		if blocked == "" {
			continue
		}
		if strings.Contains(loweredUA, strings.ToLower(blocked)) {
			return fmt.Errorf("User-Agent被阻止: %s", blocked)
		}
	}
	return nil
}
|
||
|
||
// checkReferer screens the Referer header when EnableReferer is set.
//
// A missing Referer is rejected. If AllowedReferers is non-empty, the Referer
// must contain at least one of its entries as a substring. Returns nil when
// the check is disabled or passes.
//
// NOTE(review): the plain substring test accepts e.g.
// "https://evil.test/?x=example.com" for an allowed entry "example.com", and
// the header is client-controlled anyway; treat this as a weak filter only.
func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
	if !m.limitConfig.EnableReferer {
		return nil
	}

	referer := c.GetHeader("Referer")
	if referer == "" {
		return fmt.Errorf("缺少Referer")
	}

	// With no allow-list configured, any non-empty Referer is accepted.
	if len(m.limitConfig.AllowedReferers) == 0 {
		return nil
	}
	for _, fragment := range m.limitConfig.AllowedReferers {
		if strings.Contains(referer, fragment) {
			return nil
		}
	}
	return fmt.Errorf("Referer不被允许: %s", referer)
}
|
||
|
||
// checkConcurrentLimit reserves one in-flight slot for clientIP. It returns an
// error when the reservation would exceed MaxConcurrent, or when Redis fails.
//
// The counter lives at "<prefix>:concurrent:<ip>", and its TTL is reset to 30s
// on every reservation so leaked reservations eventually expire. The counter is
// incremented first and the returned value is then checked, which makes the
// check atomic. A separate GET followed by INCR would let concurrent requests
// observe the same count and overshoot the limit. A rejected request rolls its
// own reservation back.
//
// NOTE: an accepted reservation is freed only when a caller decrements the
// same key or the TTL lapses.
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
	key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)

	pipe := m.redis.Pipeline()
	incr := pipe.Incr(ctx, key)
	pipe.Expire(ctx, key, 30*time.Second) // expire after 30 seconds
	if _, err := pipe.Exec(ctx); err != nil {
		return fmt.Errorf("获取并发计数失败: %w", err)
	}

	inFlight := incr.Val()
	if inFlight > int64(m.limitConfig.MaxConcurrent) {
		if err := m.redis.Decr(ctx, key).Err(); err != nil {
			m.logger.Error("回滚并发计数失败", zap.String("key", key), zap.Error(err))
		}
		// Report the count as it stood before this request.
		return fmt.Errorf("并发请求超限: %d", inFlight-1)
	}
	return nil
}
|
||
|
||
// getClientIP resolves the client address used for IP filtering and quota keys.
//
// When EnableProxyCheck is set, the forwarding headers below are consulted in
// order. The first one whose (left-most, for comma-separated chains) value is
// a valid IP wins. These headers are client-supplied, so enable this only
// behind proxies that overwrite them.
//
// Otherwise, and as the fallback, gin's Context.ClientIP is used. It honours
// forwarding headers only as the engine's trusted-proxy configuration allows.
// Reading X-Forwarded-For / X-Real-IP directly here would let any client
// choose its own IP and bypass the per-IP quotas and the blacklist.
func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
	if m.limitConfig.EnableProxyCheck {
		proxyHeaders := []string{
			"CF-Connecting-IP", // Cloudflare
			"X-Forwarded-For",  // de-facto standard proxy chain
			"X-Real-IP",        // Nginx
			"X-Client-IP",      // Apache
			"X-Forwarded",      // other proxies
			"Forwarded-For",    // non-standard variant
			"Forwarded",        // RFC 7239; used only when it carries a bare IP
		}

		for _, header := range proxyHeaders {
			value := c.GetHeader(header)
			// Chains list the original client first.
			if i := strings.IndexByte(value, ','); i >= 0 {
				value = value[:i]
			}
			value = strings.TrimSpace(value)
			// Skip empty or malformed values so arbitrary header text never
			// becomes a Redis key or an IP-matching subject.
			if net.ParseIP(value) != nil {
				return value
			}
		}
	}

	return c.ClientIP()
}
|
||
|
||
// checkTotalLimit returns an error once today's global request count (key
// "<prefix>:total:<YYYY-MM-DD>") has reached MaxRequestsPerDay, or when the
// counter cannot be read. It only reads; incrementing happens separately.
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
	limit := m.limitConfig.MaxRequestsPerDay
	totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())

	used, err := m.getCounter(ctx, totalKey)
	switch {
	case err != nil:
		return fmt.Errorf("获取总请求计数失败: %w", err)
	case used >= limit:
		return fmt.Errorf("接口今日总请求次数已达上限 %d", limit)
	}
	return nil
}
|
||
|
||
// checkIPLimit returns an error once clientIP's request count for today (key
// "<prefix>:ip:<ip>:<YYYY-MM-DD>") has reached MaxRequestsPerIP, or when the
// counter cannot be read. It only reads; incrementing happens separately.
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
	limit := m.limitConfig.MaxRequestsPerIP
	ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())

	used, err := m.getCounter(ctx, ipKey)
	switch {
	case err != nil:
		return fmt.Errorf("获取IP计数失败: %w", err)
	case used >= limit:
		return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, limit)
	}
	return nil
}
|
||
|
||
// incrementCounters bumps today's global counter and then clientIP's counter
// for today. Failures are logged by incrementCounter and not returned.
func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) {
	prefix := m.limitConfig.KeyPrefix
	bump := func(format string, args ...interface{}) {
		m.incrementCounter(ctx, fmt.Sprintf(format, args...))
	}

	bump("%s:total:%s", prefix, m.getDateKey())
	bump("%s:ip:%s:%s", prefix, clientIP, m.getDateKey())
}
|
||
|
||
// getCounter reads the integer stored at key. A missing key counts as 0.
// Redis errors are returned unchanged, and a non-integer value yields a
// wrapped parse error. On any error the returned count is 0.
func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) {
	raw, err := m.redis.Get(ctx, key).Result()
	switch {
	case err == redis.Nil:
		// Key absent: nothing has been counted yet.
		return 0, nil
	case err != nil:
		return 0, err
	}

	n, convErr := strconv.Atoi(raw)
	if convErr != nil {
		return 0, fmt.Errorf("解析计数失败: %w", convErr)
	}
	return n, nil
}
|
||
|
||
// incrementCounter INCRs key and resets its expiry to limitConfig.TTL. Both
// commands are sent in one pipeline. A failure is logged, not returned, so
// counting stays best-effort.
func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) {
	batch := m.redis.Pipeline()
	batch.Incr(ctx, key)
	batch.Expire(ctx, key, m.limitConfig.TTL)

	if _, execErr := batch.Exec(ctx); execErr != nil {
		m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(execErr))
	}
}
|
||
|
||
// getDateKey 获取日期键(格式:2024-01-01)
|
||
func (m *DailyRateLimitMiddleware) getDateKey() string {
|
||
return time.Now().Format("2006-01-02")
|
||
}
|
||
|
||
// addHiddenHeaders sets monitoring headers on the response: X-System-Status,
// X-Total-Count (today's global count), X-IP-Count (today's count for
// clientIP) and X-Reset-Time (RFC 3339).
//
// Counter read errors are deliberately ignored: these headers are
// best-effort, and a failed read reports 0.
//
// NOTE(review): the original intent was "internal monitoring only, invisible
// to clients", but these are ordinary response headers and ARE sent to the
// client. That exposes the counters the handler otherwise hides behind a
// generic error. The method also costs two extra Redis reads per request.
// Consider emitting these values to logs/metrics instead.
func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
	ctx := c.Request.Context()
	prefix := m.limitConfig.KeyPrefix

	totalCount, _ := m.getCounter(ctx, fmt.Sprintf("%s:total:%s", prefix, m.getDateKey()))
	ipCount, _ := m.getCounter(ctx, fmt.Sprintf("%s:ip:%s:%s", prefix, clientIP, m.getDateKey()))

	// Non-standard header names were chosen by the original author to avoid
	// obvious rate-limit headers.
	c.Header("X-System-Status", "normal")
	c.Header("X-Total-Count", strconv.Itoa(totalCount))
	c.Header("X-IP-Count", strconv.Itoa(ipCount))
	resetAt := m.getResetTime()
	c.Header("X-Reset-Time", resetAt.Format(time.RFC3339))
}
|
||
|
||
// getResetTime returns the next local midnight, i.e. 00:00 tomorrow in the
// server's time zone, when the date-keyed daily counters roll over.
//
// Adding 24h to now is wrong on DST transition days. On a 25-hour day it can
// land on the same calendar day (yielding a midnight in the past), and on a
// 23-hour day it can skip a day. time.Date normalises Day()+1 across month and
// year ends, so calendar arithmetic is used instead.
func (m *DailyRateLimitMiddleware) getResetTime() time.Time {
	now := time.Now()
	return time.Date(now.Year(), now.Month(), now.Day()+1, 0, 0, 0, 0, now.Location())
}
|
||
|
||
// GetStats 获取限流统计
|
||
func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} {
|
||
return map[string]interface{}{
|
||
"max_requests_per_day": m.limitConfig.MaxRequestsPerDay,
|
||
"max_requests_per_ip": m.limitConfig.MaxRequestsPerIP,
|
||
"max_concurrent": m.limitConfig.MaxConcurrent,
|
||
"key_prefix": m.limitConfig.KeyPrefix,
|
||
"ttl": m.limitConfig.TTL.String(),
|
||
"security_features": map[string]interface{}{
|
||
"ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist,
|
||
"ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist,
|
||
"user_agent_check": m.limitConfig.EnableUserAgent,
|
||
"referer_check": m.limitConfig.EnableReferer,
|
||
"proxy_check": m.limitConfig.EnableProxyCheck,
|
||
},
|
||
}
|
||
}
|