f
This commit is contained in:
50
internal/shared/middleware/api_auth.go
Normal file
50
internal/shared/middleware/api_auth.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ApiAuthMiddleware API认证中间件
|
||||
type ApiAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
}
|
||||
|
||||
// NewApiAuthMiddleware 创建API认证中间件
|
||||
func NewApiAuthMiddleware(cfg *config.Config, logger *zap.Logger, responseBuilder interfaces.ResponseBuilder) *ApiAuthMiddleware {
|
||||
return &ApiAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
responseBuilder: responseBuilder,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ApiAuthMiddleware) GetName() string {
|
||||
return "api_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ApiAuthMiddleware) GetPriority() int {
|
||||
return 60 // 中等优先级,在日志之后,业务处理之前
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ApiAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取客户端IP地址,并存入上下文
|
||||
clientIP := c.ClientIP()
|
||||
c.Set("client_ip", clientIP)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ApiAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
379
internal/shared/middleware/auth.go
Normal file
379
internal/shared/middleware/auth.go
Normal file
@@ -0,0 +1,379 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT认证中间件
|
||||
type JWTAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewJWTAuthMiddleware 创建JWT认证中间件
|
||||
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
|
||||
return &JWTAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *JWTAuthMiddleware) GetName() string {
|
||||
return "jwt_auth"
|
||||
}
|
||||
|
||||
// GetExpiresIn 返回JWT过期时间
|
||||
func (m *JWTAuthMiddleware) GetExpiresIn() time.Duration {
|
||||
return m.config.JWT.ExpiresIn
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *JWTAuthMiddleware) GetPriority() int {
|
||||
return 60 // 中等优先级,在日志之后,业务处理之前
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
m.respondUnauthorized(c, "缺少认证头部")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
m.respondUnauthorized(c, "认证头部格式无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 提取token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
if tokenString == "" {
|
||||
m.respondUnauthorized(c, "缺少认证令牌")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := m.validateToken(tokenString)
|
||||
if err != nil {
|
||||
m.logger.Warn("无效的认证令牌",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", c.GetString("request_id")))
|
||||
m.respondUnauthorized(c, "认证令牌无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息添加到上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("phone", claims.Phone)
|
||||
c.Set("user_type", claims.UserType)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *JWTAuthMiddleware) IsGlobal() bool {
|
||||
return false // 不是全局中间件,需要手动应用到需要认证的路由
|
||||
}
|
||||
|
||||
// JWTClaims JWT声明结构
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
Phone string `json:"phone"`
|
||||
UserType string `json:"user_type"` // 新增:用户类型
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// validateToken 验证JWT token
|
||||
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// respondUnauthorized 返回未授权响应
|
||||
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "认证失败",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (m *JWTAuthMiddleware) GenerateToken(userID, phone, email, userType string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: userID,
|
||||
Username: phone, // 普通用户用手机号,管理员用用户名
|
||||
Email: email,
|
||||
Phone: phone,
|
||||
UserType: userType, // 新增:用户类型
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "hyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"hyapi-client"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新token
|
||||
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &jwt.RegisteredClaims{
|
||||
Issuer: "hyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"hyapi-refresh"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 验证刷新token
|
||||
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
// 检查是否为刷新token
|
||||
if len(claims.Audience) == 0 || claims.Audience[0] != "hyapi-refresh" {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims.Subject, nil
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
|
||||
type OptionalAuthMiddleware struct {
|
||||
jwtAuth *JWTAuthMiddleware
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware 创建可选认证中间件
|
||||
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
|
||||
return &OptionalAuthMiddleware{
|
||||
jwtAuth: jwtAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *OptionalAuthMiddleware) GetName() string {
|
||||
return "optional_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *OptionalAuthMiddleware) GetPriority() int {
|
||||
return 60 // 与JWT认证中间件相同
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// 没有认证头部,设置匿名用户标识
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 提取并验证token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
claims, err := m.jwtAuth.validateToken(tokenString)
|
||||
if err != nil {
|
||||
// token无效,但不返回错误,设置为未认证
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// token有效,设置用户信息
|
||||
c.Set("is_authenticated", true)
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("phone", claims.Phone)
|
||||
c.Set("user_type", claims.UserType)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *OptionalAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// AdminAuthMiddleware 管理员认证中间件
|
||||
type AdminAuthMiddleware struct {
|
||||
jwtAuth *JWTAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAdminAuthMiddleware 创建管理员认证中间件
|
||||
func NewAdminAuthMiddleware(jwtAuth *JWTAuthMiddleware, logger *zap.Logger) *AdminAuthMiddleware {
|
||||
return &AdminAuthMiddleware{
|
||||
jwtAuth: jwtAuth,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *AdminAuthMiddleware) GetName() string {
|
||||
return "admin_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *AdminAuthMiddleware) GetPriority() int {
|
||||
return 60 // 与JWT认证中间件相同
|
||||
}
|
||||
|
||||
// Handle 管理员认证处理
|
||||
func (m *AdminAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 首先进行JWT认证
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
m.respondUnauthorized(c, "缺少认证头部")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
m.respondUnauthorized(c, "认证头部格式无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 提取token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
if tokenString == "" {
|
||||
m.respondUnauthorized(c, "缺少认证令牌")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := m.jwtAuth.validateToken(tokenString)
|
||||
if err != nil {
|
||||
m.logger.Warn("无效的认证令牌",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", c.GetString("request_id")))
|
||||
m.respondUnauthorized(c, "认证令牌无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查用户类型是否为管理员
|
||||
if claims.UserType != "admin" {
|
||||
m.respondForbidden(c, "需要管理员权限")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置用户信息到上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("phone", claims.Phone)
|
||||
c.Set("user_type", claims.UserType)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *AdminAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// respondForbidden 返回禁止访问响应
|
||||
func (m *AdminAuthMiddleware) respondForbidden(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"success": false,
|
||||
"message": "权限不足",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// respondUnauthorized 返回未授权响应
|
||||
func (m *AdminAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "认证失败",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
442
internal/shared/middleware/comprehensive_logger.go
Normal file
442
internal/shared/middleware/comprehensive_logger.go
Normal file
@@ -0,0 +1,442 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ComprehensiveLoggerMiddleware 全面日志中间件
|
||||
type ComprehensiveLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
config *ComprehensiveLoggerConfig
|
||||
}
|
||||
|
||||
// ComprehensiveLoggerConfig 全面日志配置
|
||||
type ComprehensiveLoggerConfig struct {
|
||||
EnableRequestLogging bool // 是否记录请求日志
|
||||
EnableResponseLogging bool // 是否记录响应日志
|
||||
EnableRequestBodyLogging bool // 是否记录请求体
|
||||
EnableErrorLogging bool // 是否记录错误日志
|
||||
EnableBusinessLogging bool // 是否记录业务日志
|
||||
EnablePerformanceLogging bool // 是否记录性能日志
|
||||
MaxBodySize int64 // 最大记录体大小
|
||||
ExcludePaths []string // 排除的路径
|
||||
IncludePaths []string // 包含的路径
|
||||
}
|
||||
|
||||
// NewComprehensiveLoggerMiddleware 创建全面日志中间件
|
||||
func NewComprehensiveLoggerMiddleware(logger *zap.Logger, config *ComprehensiveLoggerConfig) *ComprehensiveLoggerMiddleware {
|
||||
if config == nil {
|
||||
config = &ComprehensiveLoggerConfig{
|
||||
EnableRequestLogging: true,
|
||||
EnableResponseLogging: true,
|
||||
EnableRequestBodyLogging: false, // 生产环境默认关闭
|
||||
EnableErrorLogging: true,
|
||||
EnableBusinessLogging: true,
|
||||
EnablePerformanceLogging: true,
|
||||
MaxBodySize: 1024 * 10, // 10KB
|
||||
ExcludePaths: []string{"/health", "/metrics", "/favicon.ico"},
|
||||
}
|
||||
}
|
||||
|
||||
return &ComprehensiveLoggerMiddleware{
|
||||
logger: logger,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ComprehensiveLoggerMiddleware) GetName() string {
|
||||
return "comprehensive_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ComprehensiveLoggerMiddleware) GetPriority() int {
|
||||
return 90 // 高优先级,在panic恢复之后
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ComprehensiveLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 检查是否应该记录此路径
|
||||
if !m.shouldLogPath(c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
requestID := c.GetString("request_id")
|
||||
traceID := c.GetString("trace_id")
|
||||
userID := c.GetString("user_id")
|
||||
|
||||
// 记录请求开始
|
||||
if m.config.EnableRequestLogging {
|
||||
m.logRequest(c, startTime, requestID, traceID, userID)
|
||||
}
|
||||
|
||||
// 捕获请求体(如果需要)
|
||||
var requestBody []byte
|
||||
if m.config.EnableRequestBodyLogging && m.shouldLogRequestBody(c) {
|
||||
requestBody = m.captureRequestBody(c)
|
||||
}
|
||||
|
||||
// 创建响应写入器包装器
|
||||
responseWriter := &responseWriterWrapper{
|
||||
ResponseWriter: c.Writer,
|
||||
logger: m.logger,
|
||||
config: m.config,
|
||||
requestID: requestID,
|
||||
traceID: traceID,
|
||||
userID: userID,
|
||||
startTime: startTime,
|
||||
path: c.Request.URL.Path,
|
||||
method: c.Request.Method,
|
||||
}
|
||||
c.Writer = responseWriter
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 记录响应
|
||||
if m.config.EnableResponseLogging {
|
||||
m.logResponse(c, responseWriter, startTime, requestID, traceID, userID, requestBody)
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if m.config.EnableErrorLogging && len(c.Errors) > 0 {
|
||||
m.logErrors(c, requestID, traceID, userID)
|
||||
}
|
||||
|
||||
// 记录性能指标
|
||||
if m.config.EnablePerformanceLogging {
|
||||
m.logPerformance(c, startTime, requestID, traceID, userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ComprehensiveLoggerMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldLogPath 检查是否应该记录此路径
|
||||
func (m *ComprehensiveLoggerMiddleware) shouldLogPath(path string) bool {
|
||||
// 检查排除路径
|
||||
for _, excludePath := range m.config.ExcludePaths {
|
||||
if strings.HasPrefix(path, excludePath) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查包含路径(如果指定了)
|
||||
if len(m.config.IncludePaths) > 0 {
|
||||
for _, includePath := range m.config.IncludePaths {
|
||||
if strings.HasPrefix(path, includePath) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// shouldLogRequestBody 检查是否应该记录请求体
|
||||
func (m *ComprehensiveLoggerMiddleware) shouldLogRequestBody(c *gin.Context) bool {
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
return strings.Contains(contentType, "application/json") ||
|
||||
strings.Contains(contentType, "application/x-www-form-urlencoded")
|
||||
}
|
||||
|
||||
// captureRequestBody 捕获请求体
|
||||
func (m *ComprehensiveLoggerMiddleware) captureRequestBody(c *gin.Context) []byte {
|
||||
if c.Request.Body == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 重新设置body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
// 限制大小
|
||||
if int64(len(body)) > m.config.MaxBodySize {
|
||||
return body[:m.config.MaxBodySize]
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// logRequest 记录请求日志
|
||||
func (m *ComprehensiveLoggerMiddleware) logRequest(c *gin.Context, startTime time.Time, requestID, traceID, userID string) {
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "request"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("query", c.Request.URL.RawQuery),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.String("user_agent", c.Request.UserAgent()),
|
||||
zap.String("referer", c.Request.Referer()),
|
||||
zap.Int64("content_length", c.Request.ContentLength),
|
||||
zap.String("content_type", c.GetHeader("Content-Type")),
|
||||
zap.Time("timestamp", startTime),
|
||||
}
|
||||
|
||||
m.logger.Info("收到HTTP请求", logFields...)
|
||||
}
|
||||
|
||||
// logResponse 记录响应日志
|
||||
func (m *ComprehensiveLoggerMiddleware) logResponse(c *gin.Context, responseWriter *responseWriterWrapper, startTime time.Time, requestID, traceID, userID string, requestBody []byte) {
|
||||
duration := time.Since(startTime)
|
||||
statusCode := responseWriter.Status()
|
||||
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "response"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.Duration("duration", duration),
|
||||
zap.Int("response_size", responseWriter.Size()),
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
|
||||
// 添加请求体(如果记录了)
|
||||
if len(requestBody) > 0 {
|
||||
logFields = append(logFields, zap.String("request_body", string(requestBody)))
|
||||
}
|
||||
|
||||
// 根据状态码选择日志级别
|
||||
if statusCode >= 500 {
|
||||
m.logger.Error("HTTP响应错误", logFields...)
|
||||
} else if statusCode >= 400 {
|
||||
m.logger.Warn("HTTP响应警告", logFields...)
|
||||
} else {
|
||||
m.logger.Info("HTTP响应成功", logFields...)
|
||||
}
|
||||
}
|
||||
|
||||
// logErrors 记录错误日志
|
||||
func (m *ComprehensiveLoggerMiddleware) logErrors(c *gin.Context, requestID, traceID, userID string) {
|
||||
for _, ginErr := range c.Errors {
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "error"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Uint64("error_type", uint64(ginErr.Type)),
|
||||
zap.Error(ginErr.Err),
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
|
||||
m.logger.Error("请求处理错误", logFields...)
|
||||
}
|
||||
}
|
||||
|
||||
// logPerformance 记录性能日志
|
||||
func (m *ComprehensiveLoggerMiddleware) logPerformance(c *gin.Context, startTime time.Time, requestID, traceID, userID string) {
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// 记录慢请求
|
||||
if duration > 1*time.Second {
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "performance"),
|
||||
zap.String("performance_type", "slow_request"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Duration("duration", duration),
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
|
||||
m.logger.Warn("检测到慢请求", logFields...)
|
||||
}
|
||||
|
||||
// 记录性能指标
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "performance"),
|
||||
zap.String("performance_type", "request_metrics"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.Duration("duration", duration),
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
|
||||
m.logger.Debug("请求性能指标", logFields...)
|
||||
}
|
||||
|
||||
// responseWriterWrapper 响应写入器包装器
|
||||
type responseWriterWrapper struct {
|
||||
gin.ResponseWriter
|
||||
logger *zap.Logger
|
||||
config *ComprehensiveLoggerConfig
|
||||
requestID string
|
||||
traceID string
|
||||
userID string
|
||||
startTime time.Time
|
||||
path string
|
||||
method string
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
// Write 实现Write方法
|
||||
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
|
||||
size, err := w.ResponseWriter.Write(b)
|
||||
w.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
// WriteHeader 实现WriteHeader方法
|
||||
func (w *responseWriterWrapper) WriteHeader(code int) {
|
||||
w.status = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
// WriteString 实现WriteString方法
|
||||
func (w *responseWriterWrapper) WriteString(s string) (int, error) {
|
||||
size, err := w.ResponseWriter.WriteString(s)
|
||||
w.size += size
|
||||
return size, err
|
||||
}
|
||||
|
||||
// Status 获取状态码
|
||||
func (w *responseWriterWrapper) Status() int {
|
||||
return w.status
|
||||
}
|
||||
|
||||
// Size 获取响应大小
|
||||
func (w *responseWriterWrapper) Size() int {
|
||||
return w.size
|
||||
}
|
||||
|
||||
// BusinessLogger 业务日志记录器
|
||||
type BusinessLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewBusinessLogger 创建业务日志记录器
|
||||
func NewBusinessLogger(logger *zap.Logger) *BusinessLogger {
|
||||
return &BusinessLogger{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// LogUserAction 记录用户操作
|
||||
func (bl *BusinessLogger) LogUserAction(ctx context.Context, action string, details map[string]interface{}) {
|
||||
requestID := bl.getRequestIDFromContext(ctx)
|
||||
traceID := bl.getTraceIDFromContext(ctx)
|
||||
userID := bl.getUserIDFromContext(ctx)
|
||||
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "business"),
|
||||
zap.String("business_type", "user_action"),
|
||||
zap.String("action", action),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
|
||||
// 添加详细信息
|
||||
for key, value := range details {
|
||||
logFields = append(logFields, zap.Any(key, value))
|
||||
}
|
||||
|
||||
bl.logger.Info("用户操作记录", logFields...)
|
||||
}
|
||||
|
||||
// LogBusinessEvent 记录业务事件
|
||||
func (bl *BusinessLogger) LogBusinessEvent(ctx context.Context, event string, details map[string]interface{}) {
|
||||
requestID := bl.getRequestIDFromContext(ctx)
|
||||
traceID := bl.getTraceIDFromContext(ctx)
|
||||
userID := bl.getUserIDFromContext(ctx)
|
||||
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "business"),
|
||||
zap.String("business_type", "business_event"),
|
||||
zap.String("event", event),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
|
||||
// 添加详细信息
|
||||
for key, value := range details {
|
||||
logFields = append(logFields, zap.Any(key, value))
|
||||
}
|
||||
|
||||
bl.logger.Info("业务事件记录", logFields...)
|
||||
}
|
||||
|
||||
// LogSystemEvent 记录系统事件
|
||||
func (bl *BusinessLogger) LogSystemEvent(ctx context.Context, event string, details map[string]interface{}) {
|
||||
requestID := bl.getRequestIDFromContext(ctx)
|
||||
traceID := bl.getTraceIDFromContext(ctx)
|
||||
|
||||
logFields := []zap.Field{
|
||||
zap.String("log_type", "business"),
|
||||
zap.String("business_type", "system_event"),
|
||||
zap.String("event", event),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.Time("timestamp", time.Now()),
|
||||
}
|
||||
|
||||
// 添加详细信息
|
||||
for key, value := range details {
|
||||
logFields = append(logFields, zap.Any(key, value))
|
||||
}
|
||||
|
||||
bl.logger.Info("系统事件记录", logFields...)
|
||||
}
|
||||
|
||||
// 辅助方法
|
||||
func (bl *BusinessLogger) getRequestIDFromContext(ctx context.Context) string {
|
||||
if requestID := ctx.Value("request_id"); requestID != nil {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (bl *BusinessLogger) getTraceIDFromContext(ctx context.Context) string {
|
||||
if traceID := ctx.Value("trace_id"); traceID != nil {
|
||||
if id, ok := traceID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (bl *BusinessLogger) getUserIDFromContext(ctx context.Context) string {
|
||||
if userID := ctx.Value("user_id"); userID != nil {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
142
internal/shared/middleware/cors.go
Normal file
142
internal/shared/middleware/cors.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"hyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware CORS中间件
|
||||
type CORSMiddleware struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewCORSMiddleware 创建CORS中间件
|
||||
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
|
||||
return &CORSMiddleware{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *CORSMiddleware) GetName() string {
|
||||
return "cors"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *CORSMiddleware) GetPriority() int {
|
||||
return 95 // 在PanicRecovery(100)之后,SecurityHeaders(85)之前执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.config.Development.EnableCors {
|
||||
// 如果没有启用CORS,返回空处理函数
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// 获取CORS配置
|
||||
origins := m.getAllowedOrigins()
|
||||
methods := m.getAllowedMethods()
|
||||
headers := m.getAllowedHeaders()
|
||||
|
||||
config := cors.Config{
|
||||
AllowAllOrigins: false,
|
||||
AllowOrigins: origins,
|
||||
AllowMethods: methods,
|
||||
AllowHeaders: headers,
|
||||
ExposeHeaders: []string{
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
"Access-Control-Allow-Origin",
|
||||
"Access-Control-Allow-Methods",
|
||||
"Access-Control-Allow-Headers",
|
||||
},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 86400, // 24小时
|
||||
// 增加Chrome兼容性
|
||||
AllowWildcard: false,
|
||||
AllowBrowserExtensions: false,
|
||||
}
|
||||
|
||||
// 创建CORS中间件
|
||||
corsMiddleware := cors.New(config)
|
||||
|
||||
// 返回包装后的中间件
|
||||
return func(c *gin.Context) {
|
||||
// 调用实际的CORS中间件
|
||||
corsMiddleware(c)
|
||||
|
||||
// 继续处理下一个中间件或处理器
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *CORSMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getAllowedOrigins 获取允许的来源
|
||||
func (m *CORSMiddleware) getAllowedOrigins() []string {
|
||||
if m.config.Development.CorsOrigins == "" {
|
||||
return []string{"http://localhost:3000", "http://localhost:8080"}
|
||||
}
|
||||
|
||||
// 解析配置中的origins字符串,按逗号分隔
|
||||
origins := strings.Split(m.config.Development.CorsOrigins, ",")
|
||||
// 去除空格
|
||||
for i, origin := range origins {
|
||||
origins[i] = strings.TrimSpace(origin)
|
||||
}
|
||||
return origins
|
||||
}
|
||||
|
||||
// getAllowedMethods 获取允许的方法
|
||||
func (m *CORSMiddleware) getAllowedMethods() []string {
|
||||
if m.config.Development.CorsMethods == "" {
|
||||
return []string{
|
||||
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
|
||||
}
|
||||
}
|
||||
|
||||
// 解析配置中的methods字符串,按逗号分隔
|
||||
methods := strings.Split(m.config.Development.CorsMethods, ",")
|
||||
// 去除空格
|
||||
for i, method := range methods {
|
||||
methods[i] = strings.TrimSpace(method)
|
||||
}
|
||||
return methods
|
||||
}
|
||||
|
||||
// getAllowedHeaders 获取允许的头部
|
||||
func (m *CORSMiddleware) getAllowedHeaders() []string {
|
||||
if m.config.Development.CorsHeaders == "" {
|
||||
return []string{
|
||||
"Origin",
|
||||
"Content-Type",
|
||||
"Content-Length",
|
||||
"Accept",
|
||||
"Accept-Encoding",
|
||||
"Accept-Language",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"X-Request-ID",
|
||||
"Access-Id",
|
||||
}
|
||||
}
|
||||
|
||||
// 解析配置中的headers字符串,按逗号分隔
|
||||
headers := strings.Split(m.config.Development.CorsHeaders, ",")
|
||||
// 去除空格
|
||||
for i, header := range headers {
|
||||
headers[i] = strings.TrimSpace(header)
|
||||
}
|
||||
return headers
|
||||
}
|
||||
616
internal/shared/middleware/daily_rate_limit.go
Normal file
616
internal/shared/middleware/daily_rate_limit.go
Normal file
@@ -0,0 +1,616 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
securityEntities "hyapi-server/internal/domains/security/entities"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// DailyRateLimitConfig 每日限流配置
|
||||
type DailyRateLimitConfig struct {
|
||||
MaxRequestsPerDay int `mapstructure:"max_requests_per_day"` // 每日最大请求次数
|
||||
MaxRequestsPerIP int `mapstructure:"max_requests_per_ip"` // 每个IP每日最大请求次数
|
||||
KeyPrefix string `mapstructure:"key_prefix"` // Redis键前缀
|
||||
TTL time.Duration `mapstructure:"ttl"` // 键过期时间
|
||||
// 新增安全配置
|
||||
EnableIPWhitelist bool `mapstructure:"enable_ip_whitelist"` // 是否启用IP白名单
|
||||
IPWhitelist []string `mapstructure:"ip_whitelist"` // IP白名单
|
||||
EnableIPBlacklist bool `mapstructure:"enable_ip_blacklist"` // 是否启用IP黑名单
|
||||
IPBlacklist []string `mapstructure:"ip_blacklist"` // IP黑名单
|
||||
EnableUserAgent bool `mapstructure:"enable_user_agent"` // 是否检查User-Agent
|
||||
BlockedUserAgents []string `mapstructure:"blocked_user_agents"` // 被阻止的User-Agent
|
||||
EnableReferer bool `mapstructure:"enable_referer"` // 是否检查Referer
|
||||
AllowedReferers []string `mapstructure:"allowed_referers"` // 允许的Referer
|
||||
EnableGeoBlock bool `mapstructure:"enable_geo_block"` // 是否启用地理位置阻止
|
||||
BlockedCountries []string `mapstructure:"blocked_countries"` // 被阻止的国家/地区
|
||||
EnableProxyCheck bool `mapstructure:"enable_proxy_check"` // 是否检查代理
|
||||
MaxConcurrent int `mapstructure:"max_concurrent"` // 最大并发请求数
|
||||
|
||||
// 路径排除配置
|
||||
ExcludePaths []string `mapstructure:"exclude_paths"` // 排除频率限制的路径
|
||||
// 域名排除配置
|
||||
ExcludeDomains []string `mapstructure:"exclude_domains"` // 排除频率限制的域名
|
||||
}
|
||||
|
||||
// DailyRateLimitMiddleware 每日请求限制中间件
|
||||
type DailyRateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
redis *redis.Client
|
||||
db *gorm.DB
|
||||
response interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
limitConfig DailyRateLimitConfig
|
||||
}
|
||||
|
||||
// NewDailyRateLimitMiddleware 创建每日请求限制中间件
|
||||
func NewDailyRateLimitMiddleware(
|
||||
cfg *config.Config,
|
||||
redis *redis.Client,
|
||||
db *gorm.DB,
|
||||
response interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
limitConfig DailyRateLimitConfig,
|
||||
) *DailyRateLimitMiddleware {
|
||||
// 设置默认值
|
||||
if limitConfig.MaxRequestsPerDay <= 0 {
|
||||
limitConfig.MaxRequestsPerDay = 200 // 默认每日200次
|
||||
}
|
||||
if limitConfig.MaxRequestsPerIP <= 0 {
|
||||
limitConfig.MaxRequestsPerIP = 10 // 默认每个IP每日10次
|
||||
}
|
||||
if limitConfig.KeyPrefix == "" {
|
||||
limitConfig.KeyPrefix = "daily_limit"
|
||||
}
|
||||
if limitConfig.TTL == 0 {
|
||||
limitConfig.TTL = 24 * time.Hour // 默认24小时过期
|
||||
}
|
||||
if limitConfig.MaxConcurrent <= 0 {
|
||||
limitConfig.MaxConcurrent = 5 // 默认最大并发5个
|
||||
}
|
||||
|
||||
return &DailyRateLimitMiddleware{
|
||||
config: cfg,
|
||||
redis: redis,
|
||||
db: db,
|
||||
response: response,
|
||||
logger: logger,
|
||||
limitConfig: limitConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *DailyRateLimitMiddleware) GetName() string {
|
||||
return "daily_rate_limit"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *DailyRateLimitMiddleware) GetPriority() int {
|
||||
return 85 // 在认证之后,业务处理之前
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *DailyRateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ctx := c.Request.Context()
|
||||
if m.config.App.IsDevelopment() {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
// 检查是否在排除路径中
|
||||
if m.isExcludedPath(c.Request.URL.Path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
// 开发环境debug模式下跳过
|
||||
if m.config.Development.Debug {
|
||||
m.logger.Info("开发环境debug模式下跳过每日限流", zap.String("path", c.Request.URL.Path))
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否在排除域名中
|
||||
host := c.Request.Host
|
||||
if m.isExcludedDomain(host) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 获取客户端标识
|
||||
clientIP := m.getClientIP(c)
|
||||
|
||||
// 1. 检查IP白名单/黑名单
|
||||
if err := m.checkIPAccess(clientIP); err != nil {
|
||||
m.logger.Warn("IP访问被拒绝",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
zap.Error(err))
|
||||
m.response.Forbidden(c, "访问被拒绝")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 检查User-Agent
|
||||
if err := m.checkUserAgent(c); err != nil {
|
||||
m.logger.Warn("User-Agent被阻止",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("user_agent", c.GetHeader("User-Agent")),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
zap.Error(err))
|
||||
m.response.Forbidden(c, "访问被拒绝")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 检查Referer
|
||||
if err := m.checkReferer(c); err != nil {
|
||||
m.logger.Warn("Referer检查失败",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("referer", c.GetHeader("Referer")),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
zap.Error(err))
|
||||
m.response.Forbidden(c, "访问被拒绝")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 检查并发限制
|
||||
concurrentCount, err := m.checkConcurrentLimit(ctx, clientIP)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
|
||||
m.logger.Warn("并发请求超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
zap.Error(err))
|
||||
m.response.TooManyRequests(c, "系统繁忙,请稍后再试")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(concurrentCount, m.limitConfig.MaxConcurrent) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_concurrent_limit")
|
||||
}
|
||||
|
||||
// 5. 检查接口总请求次数限制
|
||||
totalCount, err := m.checkTotalLimit(ctx)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
|
||||
m.logger.Warn("接口总请求次数超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
zap.Error(err))
|
||||
// 隐藏限制信息,返回通用错误
|
||||
m.response.InternalError(c, "系统繁忙,请稍后再试")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(totalCount+1, m.limitConfig.MaxRequestsPerDay) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_total_limit")
|
||||
}
|
||||
|
||||
// 6. 检查IP限制
|
||||
ipCount, err := m.checkIPLimit(ctx, clientIP)
|
||||
if err != nil {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
|
||||
m.logger.Warn("IP请求次数超限",
|
||||
zap.String("ip", clientIP),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
zap.Error(err))
|
||||
// 隐藏限制信息,返回通用错误
|
||||
m.response.InternalError(c, "系统繁忙,请稍后再试")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if m.shouldRecordNearLimit(ipCount+1, m.limitConfig.MaxRequestsPerIP) {
|
||||
m.recordSuspiciousRequest(c, clientIP, "daily_ip_limit")
|
||||
}
|
||||
|
||||
// 7. 增加计数
|
||||
m.incrementCounters(ctx, clientIP)
|
||||
|
||||
// 8. 添加隐藏的响应头(仅用于内部监控)
|
||||
m.addHiddenHeaders(c, clientIP)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DailyRateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
|
||||
if m.db == nil {
|
||||
return
|
||||
}
|
||||
record := securityEntities.SuspiciousIPRecord{
|
||||
IP: ip,
|
||||
Path: c.Request.URL.Path,
|
||||
Method: c.Request.Method,
|
||||
RequestCount: 1,
|
||||
WindowSeconds: int(m.limitConfig.TTL.Seconds()),
|
||||
TriggerReason: reason,
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
}
|
||||
if record.WindowSeconds <= 0 {
|
||||
record.WindowSeconds = 10
|
||||
}
|
||||
if err := m.db.Create(&record).Error; err != nil {
|
||||
m.logger.Warn("记录每日限流可疑IP失败", zap.String("ip", ip), zap.String("reason", reason), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *DailyRateLimitMiddleware) shouldRecordNearLimit(current, max int) bool {
|
||||
if max <= 0 {
|
||||
return false
|
||||
}
|
||||
threshold := int(math.Ceil(float64(max) * 0.8))
|
||||
if threshold < 1 {
|
||||
threshold = 1
|
||||
}
|
||||
return current >= threshold
|
||||
}
|
||||
|
||||
// isExcludedDomain 检查域名是否在排除列表中
|
||||
func (m *DailyRateLimitMiddleware) isExcludedDomain(host string) bool {
|
||||
for _, excludeDomain := range m.limitConfig.ExcludeDomains {
|
||||
// 支持通配符匹配
|
||||
if strings.HasPrefix(excludeDomain, "*") {
|
||||
// 后缀匹配,如 "*.api.example.com" 匹配 "api.example.com"
|
||||
if strings.HasSuffix(host, excludeDomain[1:]) {
|
||||
return true
|
||||
}
|
||||
} else if strings.HasSuffix(excludeDomain, "*") {
|
||||
// 前缀匹配,如 "api.*" 匹配 "api.example.com"
|
||||
if strings.HasPrefix(host, excludeDomain[:len(excludeDomain)-1]) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
// 精确匹配
|
||||
if host == excludeDomain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isExcludedPath 检查路径是否在排除列表中
|
||||
func (m *DailyRateLimitMiddleware) isExcludedPath(path string) bool {
|
||||
for _, excludePath := range m.limitConfig.ExcludePaths {
|
||||
// 支持多种匹配模式
|
||||
if strings.HasPrefix(excludePath, "*") {
|
||||
// 前缀匹配,如 "*api_name" 匹配 "/api/v1/any_api_name"
|
||||
if strings.Contains(path, excludePath[1:]) {
|
||||
return true
|
||||
}
|
||||
} else if strings.HasSuffix(excludePath, "*") {
|
||||
// 后缀匹配,如 "/api/v1/*" 匹配 "/api/v1/any_api_name"
|
||||
if strings.HasPrefix(path, excludePath[:len(excludePath)-1]) {
|
||||
return true
|
||||
}
|
||||
} else if strings.Contains(excludePath, "*") {
|
||||
// 中间通配符匹配,如 "/api/v1/*api_name" 匹配 "/api/v1/any_api_name"
|
||||
parts := strings.Split(excludePath, "*")
|
||||
if len(parts) == 2 {
|
||||
prefix := parts[0]
|
||||
suffix := parts[1]
|
||||
if strings.HasPrefix(path, prefix) && strings.HasSuffix(path, suffix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 精确匹配
|
||||
if path == excludePath {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *DailyRateLimitMiddleware) IsGlobal() bool {
|
||||
return false // 不是全局中间件,需要手动应用到特定路由
|
||||
}
|
||||
|
||||
// checkIPAccess 检查IP访问权限
|
||||
func (m *DailyRateLimitMiddleware) checkIPAccess(clientIP string) error {
|
||||
// 检查黑名单
|
||||
if m.limitConfig.EnableIPBlacklist {
|
||||
for _, blockedIP := range m.limitConfig.IPBlacklist {
|
||||
if m.isIPMatch(clientIP, blockedIP) {
|
||||
return fmt.Errorf("IP %s 在黑名单中", clientIP)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查白名单(如果启用)
|
||||
if m.limitConfig.EnableIPWhitelist {
|
||||
allowed := false
|
||||
for _, allowedIP := range m.limitConfig.IPWhitelist {
|
||||
if m.isIPMatch(clientIP, allowedIP) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
return fmt.Errorf("IP %s 不在白名单中", clientIP)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isIPMatch 检查IP是否匹配(支持CIDR和通配符)
|
||||
func (m *DailyRateLimitMiddleware) isIPMatch(clientIP, pattern string) bool {
|
||||
// 简单的通配符匹配
|
||||
if strings.Contains(pattern, "*") {
|
||||
parts := strings.Split(pattern, ".")
|
||||
clientParts := strings.Split(clientIP, ".")
|
||||
if len(parts) != len(clientParts) {
|
||||
return false
|
||||
}
|
||||
for i, part := range parts {
|
||||
if part != "*" && part != clientParts[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// 精确匹配
|
||||
return clientIP == pattern
|
||||
}
|
||||
|
||||
// checkUserAgent 检查User-Agent
|
||||
func (m *DailyRateLimitMiddleware) checkUserAgent(c *gin.Context) error {
|
||||
if !m.limitConfig.EnableUserAgent {
|
||||
return nil
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
if userAgent == "" {
|
||||
return fmt.Errorf("缺少User-Agent")
|
||||
}
|
||||
|
||||
// 检查被阻止的User-Agent
|
||||
for _, blocked := range m.limitConfig.BlockedUserAgents {
|
||||
if strings.Contains(strings.ToLower(userAgent), strings.ToLower(blocked)) {
|
||||
return fmt.Errorf("User-Agent被阻止: %s", blocked)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkReferer 检查Referer
|
||||
func (m *DailyRateLimitMiddleware) checkReferer(c *gin.Context) error {
|
||||
if !m.limitConfig.EnableReferer {
|
||||
return nil
|
||||
}
|
||||
|
||||
referer := c.GetHeader("Referer")
|
||||
if referer == "" {
|
||||
return fmt.Errorf("缺少Referer")
|
||||
}
|
||||
|
||||
// 检查允许的Referer
|
||||
if len(m.limitConfig.AllowedReferers) > 0 {
|
||||
allowed := false
|
||||
for _, allowedRef := range m.limitConfig.AllowedReferers {
|
||||
if strings.Contains(referer, allowedRef) {
|
||||
allowed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !allowed {
|
||||
return fmt.Errorf("Referer不被允许: %s", referer)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkConcurrentLimit 检查并发限制
|
||||
func (m *DailyRateLimitMiddleware) checkConcurrentLimit(ctx context.Context, clientIP string) (int, error) {
|
||||
key := fmt.Sprintf("%s:concurrent:%s", m.limitConfig.KeyPrefix, clientIP)
|
||||
|
||||
// 获取当前并发数
|
||||
current, err := m.redis.Get(ctx, key).Result()
|
||||
if err != nil && err != redis.Nil {
|
||||
return 0, fmt.Errorf("获取并发计数失败: %w", err)
|
||||
}
|
||||
|
||||
currentCount := 0
|
||||
if current != "" {
|
||||
if count, err := strconv.Atoi(current); err == nil {
|
||||
currentCount = count
|
||||
}
|
||||
}
|
||||
|
||||
if currentCount >= m.limitConfig.MaxConcurrent {
|
||||
return currentCount, fmt.Errorf("并发请求超限: %d", currentCount)
|
||||
}
|
||||
|
||||
// 增加并发计数
|
||||
pipe := m.redis.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, 30*time.Second) // 30秒过期
|
||||
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
m.logger.Error("增加并发计数失败", zap.String("key", key), zap.Error(err))
|
||||
}
|
||||
|
||||
return currentCount + 1, nil
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP地址(增强版)
|
||||
func (m *DailyRateLimitMiddleware) getClientIP(c *gin.Context) string {
|
||||
// 检查是否为代理IP
|
||||
if m.limitConfig.EnableProxyCheck {
|
||||
// 检查常见的代理头部
|
||||
proxyHeaders := []string{
|
||||
"CF-Connecting-IP", // Cloudflare
|
||||
"X-Forwarded-For", // 标准代理头
|
||||
"X-Real-IP", // Nginx
|
||||
"X-Client-IP", // Apache
|
||||
"X-Forwarded", // 其他代理
|
||||
"Forwarded-For", // RFC 7239
|
||||
"Forwarded", // RFC 7239
|
||||
}
|
||||
|
||||
for _, header := range proxyHeaders {
|
||||
if ip := c.GetHeader(header); ip != "" {
|
||||
// 如果X-Forwarded-For包含多个IP,取第一个
|
||||
if header == "X-Forwarded-For" && strings.Contains(ip, ",") {
|
||||
ip = strings.TrimSpace(strings.Split(ip, ",")[0])
|
||||
}
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到标准方法
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
if strings.Contains(xff, ",") {
|
||||
return strings.TrimSpace(strings.Split(xff, ",")[0])
|
||||
}
|
||||
return xff
|
||||
}
|
||||
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// checkTotalLimit 检查接口总请求次数限制
|
||||
func (m *DailyRateLimitMiddleware) checkTotalLimit(ctx context.Context) (int, error) {
|
||||
key := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
|
||||
|
||||
count, err := m.getCounter(ctx, key)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("获取总请求计数失败: %w", err)
|
||||
}
|
||||
|
||||
if count >= m.limitConfig.MaxRequestsPerDay {
|
||||
return count, fmt.Errorf("接口今日总请求次数已达上限 %d", m.limitConfig.MaxRequestsPerDay)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// checkIPLimit 检查IP限制
|
||||
func (m *DailyRateLimitMiddleware) checkIPLimit(ctx context.Context, clientIP string) (int, error) {
|
||||
key := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
|
||||
|
||||
count, err := m.getCounter(ctx, key)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("获取IP计数失败: %w", err)
|
||||
}
|
||||
|
||||
if count >= m.limitConfig.MaxRequestsPerIP {
|
||||
return count, fmt.Errorf("IP %s 今日请求次数已达上限 %d", clientIP, m.limitConfig.MaxRequestsPerIP)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// incrementCounters 增加计数器
|
||||
func (m *DailyRateLimitMiddleware) incrementCounters(ctx context.Context, clientIP string) {
|
||||
// 增加总请求计数
|
||||
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
|
||||
m.incrementCounter(ctx, totalKey)
|
||||
|
||||
// 增加IP计数
|
||||
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
|
||||
m.incrementCounter(ctx, ipKey)
|
||||
}
|
||||
|
||||
// getCounter 获取计数器值
|
||||
func (m *DailyRateLimitMiddleware) getCounter(ctx context.Context, key string) (int, error) {
|
||||
val, err := m.redis.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return 0, nil // 键不存在,计数为0
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
count, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("解析计数失败: %w", err)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// incrementCounter 增加计数器
|
||||
func (m *DailyRateLimitMiddleware) incrementCounter(ctx context.Context, key string) {
|
||||
// 使用Redis的INCR命令增加计数
|
||||
pipe := m.redis.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, m.limitConfig.TTL)
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
m.logger.Error("增加计数器失败", zap.String("key", key), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// getDateKey 获取日期键(格式:2024-01-01)
|
||||
func (m *DailyRateLimitMiddleware) getDateKey() string {
|
||||
return time.Now().Format("2006-01-02")
|
||||
}
|
||||
|
||||
// addHiddenHeaders 添加隐藏的响应头(仅用于内部监控)
|
||||
func (m *DailyRateLimitMiddleware) addHiddenHeaders(c *gin.Context, clientIP string) {
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 添加隐藏的监控头(客户端看不到)
|
||||
totalKey := fmt.Sprintf("%s:total:%s", m.limitConfig.KeyPrefix, m.getDateKey())
|
||||
totalCount, _ := m.getCounter(ctx, totalKey)
|
||||
|
||||
ipKey := fmt.Sprintf("%s:ip:%s:%s", m.limitConfig.KeyPrefix, clientIP, m.getDateKey())
|
||||
ipCount, _ := m.getCounter(ctx, ipKey)
|
||||
|
||||
// 使用非标准的头部名称,避免被客户端识别
|
||||
c.Header("X-System-Status", "normal")
|
||||
c.Header("X-Total-Count", strconv.Itoa(totalCount))
|
||||
c.Header("X-IP-Count", strconv.Itoa(ipCount))
|
||||
c.Header("X-Reset-Time", m.getResetTime().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// getResetTime 获取重置时间(明天0点)
|
||||
func (m *DailyRateLimitMiddleware) getResetTime() time.Time {
|
||||
now := time.Now()
|
||||
tomorrow := now.Add(24 * time.Hour)
|
||||
return time.Date(tomorrow.Year(), tomorrow.Month(), tomorrow.Day(), 0, 0, 0, 0, tomorrow.Location())
|
||||
}
|
||||
|
||||
// GetStats 获取限流统计
|
||||
func (m *DailyRateLimitMiddleware) GetStats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"max_requests_per_day": m.limitConfig.MaxRequestsPerDay,
|
||||
"max_requests_per_ip": m.limitConfig.MaxRequestsPerIP,
|
||||
"max_concurrent": m.limitConfig.MaxConcurrent,
|
||||
"key_prefix": m.limitConfig.KeyPrefix,
|
||||
"ttl": m.limitConfig.TTL.String(),
|
||||
"security_features": map[string]interface{}{
|
||||
"ip_whitelist_enabled": m.limitConfig.EnableIPWhitelist,
|
||||
"ip_blacklist_enabled": m.limitConfig.EnableIPBlacklist,
|
||||
"user_agent_check": m.limitConfig.EnableUserAgent,
|
||||
"referer_check": m.limitConfig.EnableReferer,
|
||||
"proxy_check": m.limitConfig.EnableProxyCheck,
|
||||
},
|
||||
}
|
||||
}
|
||||
72
internal/shared/middleware/domain.go
Normal file
72
internal/shared/middleware/domain.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"hyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DomainAuthMiddleware 域名认证中间件
|
||||
type DomainAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewDomainAuthMiddleware 创建域名认证中间件
|
||||
func NewDomainAuthMiddleware(cfg *config.Config, logger *zap.Logger) *DomainAuthMiddleware {
|
||||
return &DomainAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *DomainAuthMiddleware) GetName() string {
|
||||
return "domain_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *DomainAuthMiddleware) GetPriority() int {
|
||||
return 100
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *DomainAuthMiddleware) Handle(domain string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
|
||||
// 开发环境下跳过外部验证
|
||||
if m.config.App.IsDevelopment() {
|
||||
m.logger.Info("开发环境:跳过域名验证",
|
||||
zap.String("domain", domain))
|
||||
c.Next()
|
||||
}
|
||||
if domain == "" {
|
||||
domain = m.config.API.Domain
|
||||
}
|
||||
host := c.Request.Host
|
||||
|
||||
// 移除端口部分
|
||||
if idx := strings.Index(host, ":"); idx != -1 {
|
||||
host = host[:idx]
|
||||
}
|
||||
m.logger.Info("域名认证中间件检查", zap.String("host", host), zap.String("domain", domain))
|
||||
if host == domain || host == "api.haiyudata.com" || host == "apitest.haiyudata.com" {
|
||||
// 设置域名匹配标记
|
||||
c.Set("domainMatched", domain)
|
||||
c.Next()
|
||||
} else {
|
||||
// 不匹配域名,跳过当前组处理,继续执行其他路由
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *DomainAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
104
internal/shared/middleware/panic_recovery.go
Normal file
104
internal/shared/middleware/panic_recovery.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PanicRecoveryMiddleware Panic恢复中间件
|
||||
type PanicRecoveryMiddleware struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewPanicRecoveryMiddleware 创建Panic恢复中间件
|
||||
func NewPanicRecoveryMiddleware(logger *zap.Logger) *PanicRecoveryMiddleware {
|
||||
return &PanicRecoveryMiddleware{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *PanicRecoveryMiddleware) GetName() string {
|
||||
return "panic_recovery"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *PanicRecoveryMiddleware) GetPriority() int {
|
||||
return 100 // 最高优先级,第一个执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *PanicRecoveryMiddleware) Handle() gin.HandlerFunc {
|
||||
return gin.RecoveryWithWriter(&panicLogger{logger: m.logger})
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *PanicRecoveryMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// panicLogger 实现io.Writer接口,用于记录panic信息
|
||||
type panicLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// Write 实现io.Writer接口
|
||||
func (pl *panicLogger) Write(p []byte) (n int, err error) {
|
||||
pl.logger.Error("系统发生严重错误",
|
||||
zap.String("error_type", "panic"),
|
||||
zap.String("stack_trace", string(p)),
|
||||
zap.String("timestamp", time.Now().Format("2006-01-02 15:04:05")),
|
||||
)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
// CustomPanicRecovery 自定义panic恢复中间件
|
||||
func (m *PanicRecoveryMiddleware) CustomPanicRecovery() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
// 获取请求信息
|
||||
requestID := c.GetString("request_id")
|
||||
traceID := c.GetString("trace_id")
|
||||
userID := c.GetString("user_id")
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
path := c.Request.URL.Path
|
||||
userAgent := c.Request.UserAgent()
|
||||
|
||||
// 记录详细的panic信息
|
||||
m.logger.Error("系统发生严重错误",
|
||||
zap.Any("panic_error", err),
|
||||
zap.String("error_type", "panic"),
|
||||
zap.String("request_id", requestID),
|
||||
zap.String("trace_id", traceID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.String("user_agent", userAgent),
|
||||
zap.String("stack_trace", string(debug.Stack())),
|
||||
zap.String("timestamp", time.Now().Format("2006-01-02 15:04:05")),
|
||||
)
|
||||
|
||||
// 返回500错误响应
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"success": false,
|
||||
"message": "服务器内部错误",
|
||||
"error_code": "INTERNAL_SERVER_ERROR",
|
||||
"request_id": requestID,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
|
||||
// 中止请求处理
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
196
internal/shared/middleware/ratelimit.go
Normal file
196
internal/shared/middleware/ratelimit.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
"hyapi-server/internal/config"
|
||||
securityEntities "hyapi-server/internal/domains/security/entities"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/time/rate"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
response interfaces.ResponseBuilder
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder, db *gorm.DB, logger *zap.Logger) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
response: response,
|
||||
db: db,
|
||||
logger: logger,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RateLimitMiddleware) GetName() string {
|
||||
return "ratelimit"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RateLimitMiddleware) GetPriority() int {
|
||||
return 90 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取客户端标识(IP地址)
|
||||
clientID := m.getClientID(c)
|
||||
|
||||
// 获取或创建限流器
|
||||
limiter := m.getLimiter(clientID)
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
m.recordSuspiciousRequest(c, clientID, "rate_limit")
|
||||
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
c.Header("Retry-After", "60")
|
||||
|
||||
// 使用统一的响应格式
|
||||
m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) recordSuspiciousRequest(c *gin.Context, ip, reason string) {
|
||||
if m.db == nil {
|
||||
return
|
||||
}
|
||||
windowSeconds := int(m.config.RateLimit.Window.Seconds())
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 1
|
||||
}
|
||||
record := securityEntities.SuspiciousIPRecord{
|
||||
IP: ip,
|
||||
Path: c.Request.URL.Path,
|
||||
Method: c.Request.Method,
|
||||
RequestCount: 1,
|
||||
WindowSeconds: windowSeconds,
|
||||
TriggerReason: reason,
|
||||
UserAgent: c.GetHeader("User-Agent"),
|
||||
}
|
||||
if err := m.db.Create(&record).Error; err != nil && m.logger != nil {
|
||||
m.logger.Warn("记录可疑IP失败", zap.String("ip", ip), zap.String("path", record.Path), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RateLimitMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getClientID 获取客户端标识
|
||||
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
|
||||
// 优先使用X-Forwarded-For头部
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
return xff
|
||||
}
|
||||
|
||||
// 使用X-Real-IP头部
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// 使用RemoteAddr
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// getLimiter 获取或创建限流器
|
||||
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
|
||||
m.mutex.RLock()
|
||||
limiter, exists := m.limiters[clientID]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists := m.limiters[clientID]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 创建新的限流器
|
||||
// rate.Every计算每个请求之间的间隔
|
||||
rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
|
||||
limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst)
|
||||
|
||||
m.limiters[clientID] = limiter
|
||||
|
||||
// 启动清理协程(仅第一次创建时)
|
||||
if len(m.limiters) == 1 {
|
||||
go m.cleanupRoutine()
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupRoutine 定期清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanupRoutine() {
|
||||
ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanup() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for clientID, limiter := range m.limiters {
|
||||
// 如果限流器在过去1小时内没有被使用,则删除它
|
||||
if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour {
|
||||
delete(m.limiters, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取限流统计
|
||||
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"active_limiters": len(m.limiters),
|
||||
"rate_limit": map[string]interface{}{
|
||||
"requests": m.config.RateLimit.Requests,
|
||||
"window": m.config.RateLimit.Window,
|
||||
"burst": m.config.RateLimit.Burst,
|
||||
},
|
||||
}
|
||||
}
|
||||
498
internal/shared/middleware/request_logger.go
Normal file
498
internal/shared/middleware/request_logger.go
Normal file
@@ -0,0 +1,498 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/shared/tracing"
|
||||
)
|
||||
|
||||
// RequestLoggerMiddleware 请求日志中间件
|
||||
type RequestLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
useColoredLog bool
|
||||
isDevelopment bool
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewRequestLoggerMiddleware 创建请求日志中间件
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
|
||||
return &RequestLoggerMiddleware{
|
||||
logger: logger,
|
||||
useColoredLog: isDevelopment, // 开发环境使用彩色日志
|
||||
isDevelopment: isDevelopment,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestLoggerMiddleware) GetName() string {
|
||||
return "request_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestLoggerMiddleware) GetPriority() int {
|
||||
return 80 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
if m.useColoredLog {
|
||||
// 开发环境:使用Gin默认的彩色日志格式
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
var statusColor, methodColor, resetColor string
|
||||
if param.IsOutputColor() {
|
||||
statusColor = param.StatusCodeColor()
|
||||
methodColor = param.MethodColor()
|
||||
resetColor = param.ResetColor()
|
||||
}
|
||||
|
||||
if param.Latency > time.Minute {
|
||||
param.Latency = param.Latency.Truncate(time.Second)
|
||||
}
|
||||
|
||||
// 获取TraceID
|
||||
traceID := param.Request.Header.Get("X-Trace-ID")
|
||||
if traceID == "" && m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(param.Request.Context())
|
||||
}
|
||||
|
||||
// 检查是否为错误响应
|
||||
if param.StatusCode >= 400 && m.tracer != nil {
|
||||
span := trace.SpanFromContext(param.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
traceInfo := ""
|
||||
if traceID != "" {
|
||||
traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
statusColor, param.StatusCode, resetColor,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
methodColor, param.Method, resetColor,
|
||||
param.Path,
|
||||
traceInfo,
|
||||
param.ErrorMessage,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
// 生产环境:使用结构化JSON日志
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 获取TraceID
|
||||
traceID := param.Request.Header.Get("X-Trace-ID")
|
||||
if traceID == "" && m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(param.Request.Context())
|
||||
}
|
||||
|
||||
// 检查是否为错误响应
|
||||
if param.StatusCode >= 400 && m.tracer != nil {
|
||||
span := trace.SpanFromContext(param.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 对于服务器错误,记录更详细的日志
|
||||
if param.StatusCode >= 500 {
|
||||
m.logger.Error("服务器错误",
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("trace_id", traceID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 记录请求日志
|
||||
logFields := []zap.Field{
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
}
|
||||
|
||||
// 添加TraceID
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
m.logger.Info("HTTP请求", logFields...)
|
||||
return ""
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestLoggerMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestIDMiddleware 请求ID中间件
|
||||
type RequestIDMiddleware struct{}
|
||||
|
||||
// NewRequestIDMiddleware 创建请求ID中间件
|
||||
func NewRequestIDMiddleware() *RequestIDMiddleware {
|
||||
return &RequestIDMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestIDMiddleware) GetName() string {
|
||||
return "request_id"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestIDMiddleware) GetPriority() int {
|
||||
return 95 // 最高优先级,第一个执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取或生成请求ID
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 设置请求ID到上下文和响应头
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
// 添加到响应头,方便客户端追踪
|
||||
c.Writer.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TraceIDMiddleware 追踪ID中间件
|
||||
type TraceIDMiddleware struct {
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewTraceIDMiddleware 创建追踪ID中间件
|
||||
func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
|
||||
return &TraceIDMiddleware{
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *TraceIDMiddleware) GetName() string {
|
||||
return "trace_id"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *TraceIDMiddleware) GetPriority() int {
|
||||
return 94 // 仅次于请求ID中间件
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取或生成追踪ID
|
||||
traceID := m.tracer.GetTraceID(c.Request.Context())
|
||||
if traceID != "" {
|
||||
// 设置追踪ID到响应头
|
||||
c.Header("X-Trace-ID", traceID)
|
||||
// 添加到上下文
|
||||
c.Set("trace_id", traceID)
|
||||
}
|
||||
|
||||
// 检查是否为错误请求(例如URL不存在)
|
||||
c.Next()
|
||||
|
||||
// 请求完成后检查状态码
|
||||
if c.Writer.Status() >= 400 {
|
||||
// 获取当前span
|
||||
span := trace.SpanFromContext(c.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 设置错误上下文,以便后续span可以识别
|
||||
c.Request = c.Request.WithContext(context.WithValue(
|
||||
c.Request.Context(),
|
||||
"otel_error_request",
|
||||
true,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *TraceIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware 安全头部中间件
|
||||
type SecurityHeadersMiddleware struct{}
|
||||
|
||||
// NewSecurityHeadersMiddleware 创建安全头部中间件
|
||||
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
|
||||
return &SecurityHeadersMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *SecurityHeadersMiddleware) GetName() string {
|
||||
return "security_headers"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *SecurityHeadersMiddleware) GetPriority() int {
|
||||
return 85 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置安全头部
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Content-Security-Policy", "default-src 'self'")
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ResponseTimeMiddleware 响应时间中间件
|
||||
type ResponseTimeMiddleware struct{}
|
||||
|
||||
// NewResponseTimeMiddleware 创建响应时间中间件
|
||||
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
|
||||
return &ResponseTimeMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ResponseTimeMiddleware) GetName() string {
|
||||
return "response_time"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ResponseTimeMiddleware) GetPriority() int {
|
||||
return 75 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
// 计算响应时间并添加到头部
|
||||
duration := time.Since(start)
|
||||
c.Header("X-Response-Time", duration.String())
|
||||
|
||||
// 记录到上下文中,供其他中间件使用
|
||||
c.Set("response_time", duration)
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ResponseTimeMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestBodyLoggerMiddleware 请求体日志中间件(用于调试)
|
||||
type RequestBodyLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
enable bool
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
|
||||
return &RequestBodyLoggerMiddleware{
|
||||
logger: logger,
|
||||
enable: enable,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestBodyLoggerMiddleware) GetName() string {
|
||||
return "request_body_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
|
||||
return 70 // 较低优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.enable {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 只记录POST, PUT, PATCH请求的body
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err == nil {
|
||||
// 重新设置body供后续处理使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 获取追踪ID
|
||||
traceID := ""
|
||||
if m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(c.Request.Context())
|
||||
}
|
||||
|
||||
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
|
||||
logFields := []zap.Field{
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("body", string(bodyBytes)),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
}
|
||||
|
||||
// 添加追踪ID
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
m.logger.Debug("请求体", logFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
|
||||
return false // 可选中间件,不是全局的
|
||||
}
|
||||
|
||||
// ErrorTrackingMiddleware 错误追踪中间件
|
||||
type ErrorTrackingMiddleware struct {
|
||||
logger *zap.Logger
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewErrorTrackingMiddleware 创建错误追踪中间件
|
||||
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
|
||||
return &ErrorTrackingMiddleware{
|
||||
logger: logger,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ErrorTrackingMiddleware) GetName() string {
|
||||
return "error_tracking"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ErrorTrackingMiddleware) GetPriority() int {
|
||||
return 60 // 低优先级,在大多数中间件之后执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// 检查是否有错误
|
||||
if len(c.Errors) > 0 || c.Writer.Status() >= 400 {
|
||||
// 获取当前span
|
||||
span := trace.SpanFromContext(c.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 记录错误日志
|
||||
traceID := m.tracer.GetTraceID(c.Request.Context())
|
||||
spanID := m.tracer.GetSpanID(c.Request.Context())
|
||||
|
||||
logFields := []zap.Field{
|
||||
zap.Int("status_code", c.Writer.Status()),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.FullPath()),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
}
|
||||
|
||||
// 添加追踪信息
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
if spanID != "" {
|
||||
logFields = append(logFields, zap.String("span_id", spanID))
|
||||
}
|
||||
|
||||
// 添加错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
logFields = append(logFields, zap.String("errors", c.Errors.String()))
|
||||
}
|
||||
|
||||
// 根据状态码决定日志级别
|
||||
if c.Writer.Status() >= 500 {
|
||||
m.logger.Error("服务器错误", logFields...)
|
||||
} else {
|
||||
m.logger.Warn("客户端错误", logFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ErrorTrackingMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user