This commit is contained in:
2025-08-18 18:12:37 +08:00
7 changed files with 840 additions and 14 deletions

View File

@@ -17,7 +17,8 @@ type Config struct {
Email EmailConfig `mapstructure:"email"`
Storage StorageConfig `mapstructure:"storage"`
OCR OCRConfig `mapstructure:"ocr"`
RateLimit RateLimitConfig `mapstructure:"ratelimit"`
RateLimit RateLimitConfig `mapstructure:"ratelimit"`
DailyRateLimit DailyRateLimitConfig `mapstructure:"daily_ratelimit"`
Monitoring MonitoringConfig `mapstructure:"monitoring"`
Health HealthConfig `mapstructure:"health"`
Resilience ResilienceConfig `mapstructure:"resilience"`
@@ -118,6 +119,27 @@ type RateLimitConfig struct {
Burst int `mapstructure:"burst"`
}
// DailyRateLimitConfig 每日限流配置
type DailyRateLimitConfig struct {
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
// 新增安全配置
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
}
// MonitoringConfig 监控配置
type MonitoringConfig struct {
MetricsEnabled bool `mapstructure:"metrics_enabled"`

View File

@@ -344,6 +344,29 @@ func NewContainer() *Container {
middleware.NewResponseTimeMiddleware,
middleware.NewCORSMiddleware,
middleware.NewRateLimitMiddleware,
// 每日限流中间件
func(cfg *config.Config, redis *redis.Client, response interfaces.ResponseBuilder, logger *zap.Logger) *middleware.DailyRateLimitMiddleware {
limitConfig := middleware.DailyRateLimitConfig{
MaxRequestsPerDay: cfg.DailyRateLimit.MaxRequestsPerDay,
MaxRequestsPerIP: cfg.DailyRateLimit.MaxRequestsPerIP,
KeyPrefix: cfg.DailyRateLimit.KeyPrefix,
TTL: cfg.DailyRateLimit.TTL,
MaxConcurrent: cfg.DailyRateLimit.MaxConcurrent,
// 安全配置
EnableIPWhitelist: cfg.DailyRateLimit.EnableIPWhitelist,
IPWhitelist: cfg.DailyRateLimit.IPWhitelist,
EnableIPBlacklist: cfg.DailyRateLimit.EnableIPBlacklist,
IPBlacklist: cfg.DailyRateLimit.IPBlacklist,
EnableUserAgent: cfg.DailyRateLimit.EnableUserAgent,
BlockedUserAgents: cfg.DailyRateLimit.BlockedUserAgents,
EnableReferer: cfg.DailyRateLimit.EnableReferer,
AllowedReferers: cfg.DailyRateLimit.AllowedReferers,
EnableGeoBlock: cfg.DailyRateLimit.EnableGeoBlock,
BlockedCountries: cfg.DailyRateLimit.BlockedCountries,
EnableProxyCheck: cfg.DailyRateLimit.EnableProxyCheck,
}
return middleware.NewDailyRateLimitMiddleware(cfg, redis, response, logger, limitConfig)
},
NewRequestLoggerMiddlewareWrapper,
middleware.NewJWTAuthMiddleware,
middleware.NewOptionalAuthMiddleware,
@@ -701,6 +724,7 @@ func RegisterMiddlewares(
responseTime *middleware.ResponseTimeMiddleware,
cors *middleware.CORSMiddleware,
rateLimit *middleware.RateLimitMiddleware,
dailyRateLimit *middleware.DailyRateLimitMiddleware,
requestLogger *middleware.RequestLoggerMiddleware,
traceIDMiddleware *middleware.TraceIDMiddleware,
errorTrackingMiddleware *middleware.ErrorTrackingMiddleware,
@@ -714,6 +738,7 @@ func RegisterMiddlewares(
router.RegisterMiddleware(responseTime)
router.RegisterMiddleware(cors)
router.RegisterMiddleware(rateLimit)
router.RegisterMiddleware(dailyRateLimit)
router.RegisterMiddleware(requestLogger)
router.RegisterMiddleware(traceIDMiddleware)
router.RegisterMiddleware(errorTrackingMiddleware)

View File

@@ -10,11 +10,12 @@ import (
// CertificationRoutes 认证路由
type CertificationRoutes struct {
handler *handlers.CertificationHandler
router *http.GinRouter
logger *zap.Logger
auth *middleware.JWTAuthMiddleware
optional *middleware.OptionalAuthMiddleware
handler *handlers.CertificationHandler
router *http.GinRouter
logger *zap.Logger
auth *middleware.JWTAuthMiddleware
optional *middleware.OptionalAuthMiddleware
dailyRateLimit *middleware.DailyRateLimitMiddleware
}
// NewCertificationRoutes 创建认证路由
@@ -24,13 +25,15 @@ func NewCertificationRoutes(
logger *zap.Logger,
auth *middleware.JWTAuthMiddleware,
optional *middleware.OptionalAuthMiddleware,
dailyRateLimit *middleware.DailyRateLimitMiddleware,
) *CertificationRoutes {
return &CertificationRoutes{
handler: handler,
router: router,
logger: logger,
auth: auth,
optional: optional,
handler: handler,
router: router,
logger: logger,
auth: auth,
optional: optional,
dailyRateLimit: dailyRateLimit,
}
}
@@ -48,8 +51,8 @@ func (r *CertificationRoutes) Register(router *http.GinRouter) {
// 1. 获取认证详情
authGroup.GET("/details", r.handler.GetCertification)
// 2. 提交企业信息
authGroup.POST("/enterprise-info", r.handler.SubmitEnterpriseInfo)
// 2. 提交企业信息(应用每日限流)
authGroup.POST("/enterprise-info", r.dailyRateLimit.Handle(), r.handler.SubmitEnterpriseInfo)
// 3. 申请合同签署
authGroup.POST("/apply-contract", r.handler.ApplyContract)

View File

@@ -0,0 +1,478 @@
package middleware
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"tyapi-server/internal/config"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
)
// DailyRateLimitConfig 每日限流配置
type DailyRateLimitConfig struct {
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
// 新增安全配置
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
}
// DailyRateLimitMiddleware 每日请求限制中间件
type DailyRateLimitMiddleware struct {
config *config.Config
redis *redis.Client
response interfaces.ResponseBuilder
logger *zap.Logger
limitConfig DailyRateLimitConfig
}
// NewDailyRateLimitMiddleware 创建每日请求限制中间件
func NewDailyRateLimitMiddleware(
cfg *config.Config,
redis *redis.Client,
response interfaces.ResponseBuilder,
logger *zap.Logger,
limitConfig DailyRateLimitConfig,
) *DailyRateLimitMiddleware {
// 设置默认值
if limitConfig.MaxRequestsPerDay <= 0 {
limitConfig.MaxRequestsPerDay = 200 // 默认每日200次
}
if limitConfig.MaxRequestsPerIP <= 0 {
limitConfig.MaxRequestsPerIP = 10 // 默认每个IP每日10次
}
if limitConfig.KeyPrefix == "" {
limitConfig.KeyPrefix = "daily_limit"
}
if limitConfig.TTL == 0 {
limitConfig.TTL = 24 * time.Hour // 默认24小时过期
}
if limitConfig.MaxConcurrent <= 0 {
limitConfig.MaxConcurrent = 5 // 默认最大并发5个
}
return &DailyRateLimitMiddleware{
config: cfg,
redis: redis,
response: response,
logger: logger,
limitConfig: limitConfig,
}
}
// GetName 返回中间件名称
func (m *DailyRateLimitMiddleware) GetName() string {
return "daily_rate_limit"
}
// GetPriority 返回中间件优先级
func (m *DailyRateLimitMiddleware) GetPriority() int {
return 85 // 在认证之后,业务处理之前
}
// Handle 返回中间件处理函数
func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()
// 获取客户端标识
clientIP := m.getClientIP(c)
// 1. 检查IP白名单/黑名单
if err := m.checkIPAccess(clientIP); err != nil {
m.logger.Warn("IP访问被拒绝",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.Forbidden(c, "访问被拒绝")
c.Abort()
return
}
// 2. 检查User-Agent
if err := m.checkUserAgent(c); err != nil {
m.logger.Warn("User-Agent被阻止",
zap.String("ip", clientIP),
zap.String("user_agent", c.GetHeader("User-Agent")),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.Forbidden(c, "访问被拒绝")
c.Abort()
return
}
// 3. 检查Referer
if err := m.checkReferer(c); err != nil {
m.logger.Warn("Referer检查失败",
zap.String("ip", clientIP),
zap.String("referer", c.GetHeader("Referer")),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.Forbidden(c, "访问被拒绝")
c.Abort()
return
}
// 4. 检查并发限制
if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
m.logger.Warn("并发请求超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
m.response.TooManyRequests(c, "系统繁忙,请稍后再试")
c.Abort()
return
}
// 5. 检查接口总请求次数限制
if err := m.checkTotalLimit(ctx); err != nil {
m.logger.Warn("接口总请求次数超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
// 隐藏限制信息,返回通用错误
m.response.InternalError(c, "系统繁忙,请稍后再试")
c.Abort()
return
}
// 6. 检查IP限制
if err := m.checkIPLimit(ctx, clientIP); err != nil {
m.logger.Warn("IP请求次数超限",
zap.String("ip", clientIP),
zap.String("request_id", c.GetString("request_id")),
zap.Error(err))
// 隐藏限制信息,返回通用错误
m.response.InternalError(c, "系统繁忙,请稍后再试")
c.Abort()
return
}
// 7. 增加计数
m.incrementCounters(ctx, clientIP)
// 8. 添加隐藏的响应头(仅用于内部监控)
m.addHiddenHeaders(c, clientIP)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *DailyRateLimitMiddleware) IsGlobal() bool {
return false // 不是全局中间件,需要手动应用到特定路由
}
// checkIPAccess 检查IP访问权限
func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
// 检查黑名单
if m.limitConfig.EnableIPBlacklist {
for _, blockedIP := range m.limitConfig.IPBlacklist {
if m.isIPMatch(clientIP, blockedIP) {
return fmt.Errorf("IP %s 在黑名单中", clientIP)
}
}
}
// 检查白名单(如果启用)
if m.limitConfig.EnableIPWhitelist {
allowed := false
for _, allowedIP := range m.limitConfig.IPWhitelist {
if m.isIPMatch(clientIP, allowedIP) {
allowed = true
break
}
}
if !allowed {
return fmt.Errorf("IP %s 不在白名单中", clientIP)
}
}
return nil
}
// isIPMatch 检查IP是否匹配支持CIDR和通配符
func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
// 简单的通配符匹配
if strings.Contains(pattern, "*") {
parts := strings.Split(pattern, ".")
clientParts := strings.Split(clientIP, ".")
if len(parts) != len(clientParts) {
return false
}
for i, part := range parts {
if part != "*" && part != clientParts[i] {
return false
}
}
return true
}
// 精确匹配
return clientIP == pattern
}
// checkUserAgent 检查User-Agent
func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
if !m.limitConfig.EnableUserAgent {
return nil
}
userAgent := c.GetHeader("User-Agent")
if userAgent == "" {
return fmt.Errorf("缺少User-Agent")
}
// 检查被阻止的User-Agent
for _, blocked := range m.limitConfig.BlockedUserAgents {
if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) {
return fmt.Errorf("User-Agent被阻止: %s", blocked)
}
}
return nil
}
// checkReferer 检查Referer
func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
if !m.limitConfig.EnableReferer {
return nil
}
referer := c.GetHeader("Referer")
if referer == "" {
return fmt.Errorf("缺少Referer")
}
// 检查允许的Referer
if len(m.limitConfig.AllowedReferers) > 0 {
allowed := false
for _, allowedRef := range m.limitConfig.AllowedReferers {
if strings.Contains(referer, allowedRef) {
allowed = true
break
}
}
if !allowed {
return fmt.Errorf("Referer不被允许: %s", referer)
}
}
return nil
}
// checkConcurrentLimit 检查并发限制
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
// 获取当前并发数
current, err := m.redis.Get(ctx, key).Result()
if err != nil && err != redis.Nil {
return fmt.Errorf("获取并发计数失败: %w", err)
}
currentCount := 0
if current != "" {
if count, err := strconv.Atoi(current); err == nil {
currentCount = count
}
}
if currentCount >= m.limitConfig.MaxConcurrent {
return fmt.Errorf("并发请求超限: %d", currentCount)
}
// 增加并发计数
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, 30*time.Second) // 30秒过期
_, err = pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
}
return nil
}
// getClientIP 获取客户端IP地址增强版
func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
// 检查是否为代理IP
if m.limitConfig.EnableProxyCheck {
// 检查常见的代理头部
proxyHeaders := []string{
"CF-Connecting-IP", // Cloudflare
"X-Forwarded-For", // 标准代理头
"X-Real-IP", // Nginx
"X-Client-IP", // Apache
"X-Forwarded", // 其他代理
"Forwarded-For", // RFC 7239
"Forwarded", // RFC 7239
}
for _, header := range proxyHeaders {
if ip := c.GetHeader(header); ip != "" {
// 如果X-Forwarded-For包含多个IP取第一个
if header == "X-Forwarded-For" && strings.Contains(ip, ",") {
ip = strings.TrimSpace(strings.Split(ip, ",")[0])
}
return ip
}
}
}
// 回退到标准方法
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
if strings.Contains(xff, ",") {
return strings.TrimSpace(strings.Split(xff, ",")[0])
}
return xff
}
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// checkTotalLimit 检查接口总请求次数限制
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取总请求计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerDay {
return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
}
return nil
}
// checkIPLimit 检查IP限制
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取IP计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerIP {
return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
}
return nil
}
// incrementCounters 增加计数器
func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) {
// 增加总请求计数
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
m.incrementCounter(ctx, totalKey)
// 增加IP计数
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
m.incrementCounter(ctx, ipKey)
}
// getCounter 获取计数器值
func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) {
val, err := m.redis.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return 0, nil // 键不存在计数为0
}
return 0, err
}
count, err := strconv.Atoi(val)
if err != nil {
return 0, fmt.Errorf("解析计数失败: %w", err)
}
return count, nil
}
// incrementCounter 增加计数器
func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) {
// 使用Redis的INCR命令增加计数
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, m.limitConfig.TTL)
_, err := pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err))
}
}
// getDateKey 获取日期键格式2024-01-01
func (m *DailyRateLimitMiddleware) getDateKey() string {
return time.Now().Format("2006-01-02")
}
// addHiddenHeaders 添加隐藏的响应头(仅用于内部监控)
func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
ctx := c.Request.Context()
// 添加隐藏的监控头(客户端看不到)
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
totalCount, _ := m.getCounter(ctx, totalKey)
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
ipCount, _ := m.getCounter(ctx, ipKey)
// 使用非标准的头部名称,避免被客户端识别
c.Header("X-System-Status", "normal")
c.Header("X-Total-Count", strconv.Itoa(totalCount))
c.Header("X-IP-Count", strconv.Itoa(ipCount))
c.Header("X-Reset-Time", m.getResetTime().Format(time.RFC3339))
}
// getResetTime 获取重置时间明天0点
func (m *DailyRateLimitMiddleware) getResetTime() time.Time {
now := time.Now()
tomorrow := now.Add(24 * time.Hour)
return time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, tomorrow.Location())
}
// GetStats 获取限流统计
func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} {
return map[string]interface{}{
"max_requests_per_day": m.limitConfig.MaxRequestsPerDay,
"max_requests_per_ip": m.limitConfig.MaxRequestsPerIP,
"max_concurrent": m.limitConfig.MaxConcurrent,
"key_prefix": m.limitConfig.KeyPrefix,
"ttl": m.limitConfig.TTL.String(),
"security_features": map[string]interface{}{
"ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist,
"ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist,
"user_agent_check": m.limitConfig.EnableUserAgent,
"referer_check": m.limitConfig.EnableReferer,
"proxy_check": m.limitConfig.EnableProxyCheck,
},
}
}