This commit is contained in:
Mrx
2026-02-27 14:49:29 +08:00
parent f17e22f4c8
commit d12529307b
16 changed files with 633 additions and 95 deletions

View File

@@ -48,6 +48,9 @@ type SendCodeCommand struct {
// 编码后的数据使用自定义编码方案的JSON字符串包含所有参数phone, scene, timestamp, nonce, signature
Data string `json:"data" binding:"required" example:"K8mN9vP2sL7kH3oB6yC1zA5uF0qE9tW..."` // 自定义编码后的数据
// 阿里云滑块验证码参数(直接接收,不参与编码)
CaptchaVerifyParam string `json:"captchaVerifyParam,omitempty" example:"..."` // 滑块验证码验证参数
// 以下字段从data解码后填充不直接接收
Phone string `json:"-"` // 从data解码后获取
Scene string `json:"-"` // 从data解码后获取

View File

@@ -13,7 +13,7 @@ func (s *UserApplicationServiceImpl) SendCode(ctx context.Context, cmd *commands
return err
}
err := s.smsCodeService.SendCode(ctx, cmd.Phone, entities.SMSScene(cmd.Scene), clientIP, userAgent)
err := s.smsCodeService.SendCode(ctx, cmd.Phone, entities.SMSScene(cmd.Scene), clientIP, userAgent, cmd.CaptchaVerifyParam)
if err != nil {
return err
}

View File

@@ -216,6 +216,11 @@ type SMSConfig struct {
// 签名验证配置
SignatureEnabled bool `mapstructure:"signature_enabled"` // 是否启用签名验证
SignatureSecret string `mapstructure:"signature_secret"` // 签名密钥
// 滑块验证码配置
CaptchaEnabled bool `mapstructure:"captcha_enabled"` // 是否启用滑块验证码
CaptchaSecret string `mapstructure:"captcha_secret"` // 阿里云验证码密钥
CaptchaEndpoint string `mapstructure:"captcha_endpoint"` // 阿里云验证码服务Endpoint
SceneID string `mapstructure:"scene_id"` // 阿里云验证码场景ID
}
// SMSRateLimit 短信限流配置

View File

@@ -37,6 +37,7 @@ import (
product_repo "tyapi-server/internal/infrastructure/database/repositories/product"
infra_events "tyapi-server/internal/infrastructure/events"
"tyapi-server/internal/infrastructure/external/alicloud"
"tyapi-server/internal/infrastructure/external/captcha"
"tyapi-server/internal/infrastructure/external/email"
"tyapi-server/internal/infrastructure/external/jiguang"
"tyapi-server/internal/infrastructure/external/muzi"
@@ -238,6 +239,19 @@ func NewContainer() *Container {
},
// 短信服务
sms.NewAliSMSService,
// 验证码服务
fx.Annotate(
func(cfg *config.Config) *captcha.CaptchaService {
return captcha.NewCaptchaService(captcha.CaptchaConfig{
AccessKeyID: cfg.SMS.AccessKeyID,
AccessKeySecret: cfg.SMS.AccessKeySecret,
EndpointURL: cfg.SMS.CaptchaEndpoint,
SceneID: cfg.SMS.SceneID,
EncryptKey: cfg.SMS.CaptchaSecret, // 加密模式 ekeyBase64 编码的 32 字节)
})
},
fx.ResultTags(`name:"captchaService"`),
),
// 邮件服务
fx.Annotate(
func(cfg *config.Config, logger *zap.Logger) *email.QQEmailService {
@@ -670,7 +684,10 @@ func NewContainer() *Container {
user_service.NewUserAggregateService,
),
user_service.NewUserAuthService,
user_service.NewSMSCodeService,
fx.Annotate(
user_service.NewSMSCodeService,
fx.ParamTags(``, ``, ``, `name:"captchaService"`),
),
user_service.NewContractAggregateService,
product_service.NewProductManagementService,
product_service.NewProductSubscriptionService,
@@ -1276,12 +1293,19 @@ func NewContainer() *Container {
) *handlers.UIComponentHandler {
return handlers.NewUIComponentHandler(uiComponentAppService, responseBuilder, validator, logger)
},
// 验证码HTTP处理器
fx.Annotate(
handlers.NewCaptchaHandler,
fx.ParamTags(`name:"captchaService"`, ``, ``, ``),
),
),
// 路由注册
fx.Provide(
// 用户路由
routes.NewUserRoutes,
// 验证码路由
routes.NewCaptchaRoutes,
// 认证路由
routes.NewCertificationRoutes,
// 财务路由
@@ -1408,6 +1432,7 @@ func RegisterMiddlewares(
func RegisterRoutes(
router *sharedhttp.GinRouter,
userRoutes *routes.UserRoutes,
captchaRoutes *routes.CaptchaRoutes,
certificationRoutes *routes.CertificationRoutes,
financeRoutes *routes.FinanceRoutes,
productRoutes *routes.ProductRoutes,
@@ -1432,6 +1457,7 @@ func RegisterRoutes(
// 所有域名路由路由
userRoutes.Register(router)
captchaRoutes.Register(router)
certificationRoutes.Register(router)
financeRoutes.Register(router)
productRoutes.Register(router)

View File

@@ -10,18 +10,20 @@ import (
"tyapi-server/internal/config"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/infrastructure/external/captcha"
"tyapi-server/internal/infrastructure/external/sms"
"tyapi-server/internal/shared/interfaces"
)
// SMSCodeService 短信验证码服务
type SMSCodeService struct {
repo repositories.SMSCodeRepository
smsClient *sms.AliSMSService
cache interfaces.CacheService
config config.SMSConfig
appConfig config.AppConfig
logger *zap.Logger
repo repositories.SMSCodeRepository
smsClient *sms.AliSMSService
cache interfaces.CacheService
captchaSvc *captcha.CaptchaService
config config.SMSConfig
appConfig config.AppConfig
logger *zap.Logger
}
// NewSMSCodeService 创建短信验证码服务
@@ -29,23 +31,36 @@ func NewSMSCodeService(
repo repositories.SMSCodeRepository,
smsClient *sms.AliSMSService,
cache interfaces.CacheService,
captchaSvc *captcha.CaptchaService,
config config.SMSConfig,
appConfig config.AppConfig,
logger *zap.Logger,
) *SMSCodeService {
return &SMSCodeService{
repo: repo,
smsClient: smsClient,
cache: cache,
config: config,
appConfig: appConfig,
logger: logger,
repo: repo,
smsClient: smsClient,
cache: cache,
captchaSvc: captchaSvc,
config: config,
appConfig: appConfig,
logger: logger,
}
}
// SendCode 发送验证码
func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent string) error {
// 0. 发送前安全限流检查
func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent, captchaVerifyParam string) error {
// 0. 验证滑块验证码(如果启用)
if s.config.CaptchaEnabled && s.captchaSvc != nil {
if err := s.captchaSvc.Verify(captchaVerifyParam); err != nil {
s.logger.Warn("滑块验证码校验失败",
zap.String("phone", phone),
zap.String("scene", string(scene)),
zap.Error(err))
return captcha.ErrCaptchaVerifyFailed
}
}
// 0.1. 发送前安全限流检查
if err := s.CheckRateLimit(ctx, phone, scene, clientIP, userAgent); err != nil {
return err
}

View File

@@ -0,0 +1,134 @@
package captcha
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"time"
"github.com/alibabacloud-go/tea/tea"
captcha20230305 "github.com/alibabacloud-go/captcha-20230305/client"
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
)
var (
ErrCaptchaVerifyFailed = errors.New("图形验证码校验失败")
ErrCaptchaConfig = errors.New("验证码配置错误")
ErrCaptchaEncryptMissing = errors.New("加密模式需要配置 EncryptKey控制台 ekey")
)
// CaptchaConfig 阿里云验证码配置
type CaptchaConfig struct {
AccessKeyID string
AccessKeySecret string
EndpointURL string
SceneID string
// EncryptKey 加密模式使用的密钥(控制台 ekeyBase64 编码的 32 字节),用于生成 EncryptedSceneId
EncryptKey string
}
// CaptchaService 阿里云验证码服务
type CaptchaService struct {
config CaptchaConfig
}
// NewCaptchaService 创建验证码服务实例
func NewCaptchaService(config CaptchaConfig) *CaptchaService {
return &CaptchaService{
config: config,
}
}
// Verify 验证滑块验证码
func (s *CaptchaService) Verify(captchaVerifyParam string) error {
if captchaVerifyParam == "" {
return ErrCaptchaVerifyFailed
}
if s.config.AccessKeyID == "" || s.config.AccessKeySecret == "" {
return ErrCaptchaConfig
}
clientCfg := &openapi.Config{
AccessKeyId: tea.String(s.config.AccessKeyID),
AccessKeySecret: tea.String(s.config.AccessKeySecret),
}
clientCfg.Endpoint = tea.String(s.config.EndpointURL)
client, err := captcha20230305.NewClient(clientCfg)
if err != nil {
return errors.Join(ErrCaptchaConfig, err)
}
req := &captcha20230305.VerifyIntelligentCaptchaRequest{
SceneId: tea.String(s.config.SceneID),
CaptchaVerifyParam: tea.String(captchaVerifyParam),
}
resp, err := client.VerifyIntelligentCaptcha(req)
if err != nil {
return errors.Join(ErrCaptchaVerifyFailed, err)
}
if resp.Body == nil || !tea.BoolValue(resp.Body.Result.VerifyResult) {
return ErrCaptchaVerifyFailed
}
return nil
}
// GetEncryptedSceneId 生成加密场景 IDEncryptedSceneId供前端加密模式初始化验证码使用。
// 算法AES-256-CBC明文 sceneId&timestamp&expireTime密钥为控制台 ekeyBase64 解码后 32 字节)。
// expireTimeSec 有效期为 186400 秒。
func (s *CaptchaService) GetEncryptedSceneId(expireTimeSec int) (string, error) {
if expireTimeSec <= 0 || expireTimeSec > 86400 {
return "", fmt.Errorf("expireTimeSec 必须在 186400 之间")
}
if s.config.EncryptKey == "" {
return "", ErrCaptchaEncryptMissing
}
if s.config.SceneID == "" {
return "", ErrCaptchaConfig
}
keyBytes, err := base64.StdEncoding.DecodeString(s.config.EncryptKey)
if err != nil || len(keyBytes) != 32 {
return "", errors.Join(ErrCaptchaConfig, fmt.Errorf("EncryptKey 必须为 Base64 编码的 32 字节"))
}
plaintext := fmt.Sprintf("%s&%d&%d", s.config.SceneID, time.Now().Unix(), expireTimeSec)
plainBytes := []byte(plaintext)
plainBytes = pkcs7Pad(plainBytes, aes.BlockSize)
block, err := aes.NewCipher(keyBytes)
if err != nil {
return "", errors.Join(ErrCaptchaConfig, err)
}
iv := make([]byte, aes.BlockSize)
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", err
}
mode := cipher.NewCBCEncrypter(block, iv)
ciphertext := make([]byte, len(plainBytes))
mode.CryptBlocks(ciphertext, plainBytes)
result := make([]byte, len(iv)+len(ciphertext))
copy(result, iv)
copy(result[len(iv):], ciphertext)
return base64.StdEncoding.EncodeToString(result), nil
}
func pkcs7Pad(data []byte, blockSize int) []byte {
n := blockSize - (len(data) % blockSize)
pad := make([]byte, n)
for i := range pad {
pad[i] = byte(n)
}
return append(data, pad...)
}

View File

@@ -0,0 +1,92 @@
package handlers
import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/config"
"tyapi-server/internal/infrastructure/external/captcha"
"tyapi-server/internal/shared/interfaces"
)
// CaptchaHandler 验证码滑块HTTP 处理器
type CaptchaHandler struct {
captchaService *captcha.CaptchaService
response interfaces.ResponseBuilder
config *config.Config
logger *zap.Logger
}
// NewCaptchaHandler 创建验证码处理器
func NewCaptchaHandler(
captchaService *captcha.CaptchaService,
response interfaces.ResponseBuilder,
cfg *config.Config,
logger *zap.Logger,
) *CaptchaHandler {
return &CaptchaHandler{
captchaService: captchaService,
response: response,
config: cfg,
logger: logger,
}
}
// EncryptedSceneIdReq 获取加密场景 ID 的请求(可选参数)
type EncryptedSceneIdReq struct {
ExpireSeconds *int `form:"expire_seconds" json:"expire_seconds"` // 有效期秒数186400默认 3600
}
// GetEncryptedSceneId 获取加密场景 ID供前端加密模式初始化阿里云验证码
// @Summary 获取验证码加密场景ID
// @Description 用于加密模式下发 EncryptedSceneId前端用此初始化滑块验证码
// @Tags 验证码
// @Accept json
// @Produce json
// @Param body body EncryptedSceneIdReq false "可选expire_seconds 有效期(1-86400)默认3600"
// @Success 200 {object} map[string]interface{} "encryptedSceneId"
// @Failure 400 {object} map[string]interface{} "配置未启用或参数错误"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/captcha/encryptedSceneId [post]
func (h *CaptchaHandler) GetEncryptedSceneId(c *gin.Context) {
expireSec := 3600
if c.Request.ContentLength > 0 {
var req EncryptedSceneIdReq
if err := c.ShouldBindJSON(&req); err == nil && req.ExpireSeconds != nil {
expireSec = *req.ExpireSeconds
}
}
if expireSec <= 0 || expireSec > 86400 {
h.response.BadRequest(c, "expire_seconds 必须在 186400 之间")
return
}
encrypted, err := h.captchaService.GetEncryptedSceneId(expireSec)
if err != nil {
if err == captcha.ErrCaptchaEncryptMissing || err == captcha.ErrCaptchaConfig {
h.logger.Warn("验证码加密场景ID生成失败", zap.Error(err))
h.response.BadRequest(c, "验证码加密模式未配置或配置错误")
return
}
h.logger.Error("验证码加密场景ID生成失败", zap.Error(err))
h.response.InternalError(c, "生成失败,请稍后重试")
return
}
h.response.Success(c, map[string]string{"encryptedSceneId": encrypted}, "ok")
}
// GetConfig 获取验证码前端配置是否启用、场景ID等便于前端决定是否展示滑块
// @Summary 获取验证码配置
// @Description 返回是否启用滑块、场景ID非加密模式用
// @Tags 验证码
// @Produce json
// @Success 200 {object} map[string]interface{} "captchaEnabled, sceneId"
// @Router /api/v1/captcha/config [get]
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
data := map[string]interface{}{
"captchaEnabled": h.config.SMS.CaptchaEnabled,
"sceneId": h.config.SMS.SceneID,
}
h.response.Success(c, data, "ok")
}

View File

@@ -68,7 +68,7 @@ type decodedSendCodeData struct {
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body commands.SendCodeCommand true "发送验证码请求(包含data字段"
// @Param request body commands.SendCodeCommand true "发送验证码请求包含data字段和可选的captchaVerifyParam字段"
// @Success 200 {object} map[string]interface{} "验证码发送成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 429 {object} map[string]interface{} "请求频率限制"
@@ -77,7 +77,7 @@ type decodedSendCodeData struct {
func (h *UserHandler) SendCode(c *gin.Context) {
var cmd commands.SendCodeCommand
// 绑定请求(包含data字段
// 绑定请求包含data字段和可选的captchaVerifyParam字段
if err := c.ShouldBindJSON(&cmd); err != nil {
h.response.BadRequest(c, "请求参数格式错误必须提供data字段")
return
@@ -123,11 +123,12 @@ func (h *UserHandler) SendCode(c *gin.Context) {
// 构建SendCodeCommand用于调用应用服务
serviceCmd := &commands.SendCodeCommand{
Phone: decodedData.Phone,
Scene: decodedData.Scene,
Timestamp: decodedData.Timestamp,
Nonce: decodedData.Nonce,
Signature: decodedData.Signature,
Phone: decodedData.Phone,
Scene: decodedData.Scene,
Timestamp: decodedData.Timestamp,
Nonce: decodedData.Nonce,
Signature: decodedData.Signature,
CaptchaVerifyParam: cmd.CaptchaVerifyParam,
}
clientIP := c.ClientIP()

View File

@@ -0,0 +1,33 @@
package routes
import (
"tyapi-server/internal/infrastructure/http/handlers"
sharedhttp "tyapi-server/internal/shared/http"
"go.uber.org/zap"
)
// CaptchaRoutes 验证码路由
type CaptchaRoutes struct {
handler *handlers.CaptchaHandler
logger *zap.Logger
}
// NewCaptchaRoutes 创建验证码路由
func NewCaptchaRoutes(handler *handlers.CaptchaHandler, logger *zap.Logger) *CaptchaRoutes {
return &CaptchaRoutes{
handler: handler,
logger: logger,
}
}
// Register 注册验证码相关路由
func (r *CaptchaRoutes) Register(router *sharedhttp.GinRouter) {
engine := router.GetEngine()
captchaGroup := engine.Group("/api/v1/captcha")
{
captchaGroup.POST("/encryptedSceneId", r.handler.GetEncryptedSceneId)
captchaGroup.GET("/config", r.handler.GetConfig)
}
r.logger.Info("验证码路由注册完成")
}

View File

@@ -20,39 +20,39 @@ type DailyRateLimitConfig struct {
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
// 新增安全配置
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
// 路径排除配置
ExcludePaths []string `mapstructure:"exclude_paths"` // 排除频率限制的路径
ExcludePaths []string `mapstructure:"exclude_paths"` // 排除频率限制的路径
// 域名排除配置
ExcludeDomains []string `mapstructure:"exclude_domains"` // 排除频率限制的域名
ExcludeDomains []string `mapstructure:"exclude_domains"` // 排除频率限制的域名
}
// DailyRateLimitMiddleware 每日请求限制中间件
type DailyRateLimitMiddleware struct {
config *config.Config
redis *redis.Client
response interfaces.ResponseBuilder
logger *zap.Logger
config *config.Config
redis *redis.Client
response interfaces.ResponseBuilder
logger *zap.Logger
limitConfig DailyRateLimitConfig
}
// NewDailyRateLimitMiddleware 创建每日请求限制中间件
func NewDailyRateLimitMiddleware(
cfg *config.Config,
redis *redis.Client,
cfg *config.Config,
redis *redis.Client,
response interfaces.ResponseBuilder,
logger *zap.Logger,
limitConfig DailyRateLimitConfig,
@@ -97,23 +97,23 @@ func (m *DailyRateLimitMiddleware) GetPriority() int {
func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
ctx := c.Request.Context()
// 检查是否在排除路径中
if m.isExcludedPath(c.Request.URL.Path) {
c.Next()
return
}
// 检查是否在排除域名中
host := c.Request.Host
if m.isExcludedDomain(host) {
c.Next()
return
}
// 获取客户端标识
clientIP := m.getClientIP(c)
// 1. 检查IP白名单/黑名单
if err := m.checkIPAccess(clientIP); err != nil {
m.logger.Warn("IP访问被拒绝",
@@ -124,7 +124,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 2. 检查User-Agent
if err := m.checkUserAgent(c); err != nil {
m.logger.Warn("User-Agent被阻止",
@@ -136,7 +136,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 3. 检查Referer
if err := m.checkReferer(c); err != nil {
m.logger.Warn("Referer检查失败",
@@ -148,7 +148,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 4. 检查并发限制
if err := m.checkConcurrentLimit(ctx, clientIP); err != nil {
m.logger.Warn("并发请求超限",
@@ -159,7 +159,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 5. 检查接口总请求次数限制
if err := m.checkTotalLimit(ctx); err != nil {
m.logger.Warn("接口总请求次数超限",
@@ -171,7 +171,7 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 6. 检查IP限制
if err := m.checkIPLimit(ctx, clientIP); err != nil {
m.logger.Warn("IP请求次数超限",
@@ -183,13 +183,13 @@ func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
c.Abort()
return
}
// 7. 增加计数
m.incrementCounters(ctx, clientIP)
// 8. 添加隐藏的响应头(仅用于内部监控)
m.addHiddenHeaders(c, clientIP)
c.Next()
}
}
@@ -267,7 +267,7 @@ func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
}
}
}
// 检查白名单(如果启用)
if m.limitConfig.EnableIPWhitelist {
allowed := false
@@ -281,7 +281,7 @@ func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
return fmt.Errorf("IP %s 不在白名单中", clientIP)
}
}
return nil
}
@@ -301,7 +301,7 @@ func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
}
return true
}
// 精确匹配
return clientIP == pattern
}
@@ -311,19 +311,19 @@ func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
if !m.limitConfig.EnableUserAgent {
return nil
}
userAgent := c.GetHeader("User-Agent")
if userAgent == "" {
return fmt.Errorf("缺少User-Agent")
}
// 检查被阻止的User-Agent
for _, blocked := range m.limitConfig.BlockedUserAgents {
if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) {
return fmt.Errorf("User-Agent被阻止: %s", blocked)
}
}
return nil
}
@@ -332,12 +332,12 @@ func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
if !m.limitConfig.EnableReferer {
return nil
}
referer := c.GetHeader("Referer")
if referer == "" {
return fmt.Errorf("缺少Referer")
}
// 检查允许的Referer
if len(m.limitConfig.AllowedReferers) > 0 {
allowed := false
@@ -351,41 +351,41 @@ func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
return fmt.Errorf("Referer不被允许: %s", referer)
}
}
return nil
}
// checkConcurrentLimit 检查并发限制
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
// 获取当前并发数
current, err := m.redis.Get(ctx, key).Result()
if err != nil && err != redis.Nil {
return fmt.Errorf("获取并发计数失败: %w", err)
}
currentCount := 0
if current != "" {
if count, err := strconv.Atoi(current); err == nil {
currentCount = count
}
}
if currentCount >= m.limitConfig.MaxConcurrent {
return fmt.Errorf("并发请求超限: %d", currentCount)
}
// 增加并发计数
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, 30*time.Second) // 30秒过期
_, err = pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
}
return nil
}
@@ -395,15 +395,15 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
if m.limitConfig.EnableProxyCheck {
// 检查常见的代理头部
proxyHeaders := []string{
"CF-Connecting-IP", // Cloudflare
"X-Forwarded-For", // 标准代理头
"X-Real-IP", // Nginx
"X-Client-IP", // Apache
"X-Forwarded", // 其他代理
"Forwarded-For", // RFC 7239
"Forwarded", // RFC 7239
"CF-Connecting-IP", // Cloudflare
"X-Forwarded-For", // 标准代理头
"X-Real-IP", // Nginx
"X-Client-IP", // Apache
"X-Forwarded", // 其他代理
"Forwarded-For", // RFC 7239
"Forwarded", // RFC 7239
}
for _, header := range proxyHeaders {
if ip := c.GetHeader(header); ip != "" {
// 如果X-Forwarded-For包含多个IP取第一个
@@ -414,7 +414,7 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
}
}
}
// 回退到标准方法
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
if strings.Contains(xff, ",") {
@@ -422,43 +422,43 @@ func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
}
return xff
}
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
return c.ClientIP()
}
// checkTotalLimit 检查接口总请求次数限制
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) error {
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取总请求计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerDay {
return fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
}
return nil
}
// checkIPLimit 检查IP限制
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) error {
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
count, err := m.getCounter(ctx, key)
if err != nil {
return fmt.Errorf("获取IP计数失败: %w", err)
}
if count >= m.limitConfig.MaxRequestsPerIP {
return fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
}
return nil
}
@@ -467,7 +467,7 @@ func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, client
// 增加总请求计数
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
m.incrementCounter(ctx, totalKey)
// 增加IP计数
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
m.incrementCounter(ctx, ipKey)
@@ -482,12 +482,12 @@ func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (
}
return 0, err
}
count, err := strconv.Atoi(val)
if err != nil {
return 0, fmt.Errorf("解析计数失败: %w", err)
}
return count, nil
}
@@ -497,7 +497,7 @@ func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key str
pipe := m.redis.Pipeline()
pipe.Incr(ctx, key)
pipe.Expire(ctx, key, m.limitConfig.TTL)
_, err := pipe.Exec(ctx)
if err != nil {
m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err))
@@ -512,14 +512,14 @@ func (m *DailyRateLimitMiddleware) getDateKey() string {
// addHiddenHeaders 添加隐藏的响应头(仅用于内部监控)
func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
ctx := c.Request.Context()
// 添加隐藏的监控头(客户端看不到)
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
totalCount, _ := m.getCounter(ctx, totalKey)
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
ipCount, _ := m.getCounter(ctx, ipKey)
// 使用非标准的头部名称,避免被客户端识别
c.Header("X-System-Status", "normal")
c.Header("X-Total-Count", strconv.Itoa(totalCount))