Initial commit: Basic project structure and dependencies
This commit is contained in:
261
internal/shared/middleware/auth.go
Normal file
261
internal/shared/middleware/auth.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT认证中间件
|
||||
type JWTAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewJWTAuthMiddleware 创建JWT认证中间件
|
||||
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
|
||||
return &JWTAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *JWTAuthMiddleware) GetName() string {
|
||||
return "jwt_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *JWTAuthMiddleware) GetPriority() int {
|
||||
return 60 // 中等优先级,在日志之后,业务处理之前
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
m.respondUnauthorized(c, "Missing authorization header")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
m.respondUnauthorized(c, "Invalid authorization header format")
|
||||
return
|
||||
}
|
||||
|
||||
// 提取token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
if tokenString == "" {
|
||||
m.respondUnauthorized(c, "Missing token")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := m.validateToken(tokenString)
|
||||
if err != nil {
|
||||
m.logger.Warn("Invalid token",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", c.GetString("request_id")))
|
||||
m.respondUnauthorized(c, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息添加到上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *JWTAuthMiddleware) IsGlobal() bool {
|
||||
return false // 不是全局中间件,需要手动应用到需要认证的路由
|
||||
}
|
||||
|
||||
// JWTClaims JWT声明结构
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// validateToken 验证JWT token
|
||||
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// respondUnauthorized 返回未授权响应
|
||||
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "Unauthorized",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "tyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"tyapi-client"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新token
|
||||
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &jwt.RegisteredClaims{
|
||||
Issuer: "tyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"tyapi-refresh"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 验证刷新token
|
||||
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
// 检查是否为刷新token
|
||||
if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims.Subject, nil
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
|
||||
type OptionalAuthMiddleware struct {
|
||||
jwtAuth *JWTAuthMiddleware
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware 创建可选认证中间件
|
||||
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
|
||||
return &OptionalAuthMiddleware{
|
||||
jwtAuth: jwtAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *OptionalAuthMiddleware) GetName() string {
|
||||
return "optional_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *OptionalAuthMiddleware) GetPriority() int {
|
||||
return 60 // 与JWT认证中间件相同
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// 没有认证头部,设置匿名用户标识
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 提取并验证token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
claims, err := m.jwtAuth.validateToken(tokenString)
|
||||
if err != nil {
|
||||
// token无效,但不返回错误,设置为未认证
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// token有效,设置用户信息
|
||||
c.Set("is_authenticated", true)
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *OptionalAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
104
internal/shared/middleware/cors.go
Normal file
104
internal/shared/middleware/cors.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware CORS中间件
|
||||
type CORSMiddleware struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewCORSMiddleware 创建CORS中间件
|
||||
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
|
||||
return &CORSMiddleware{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *CORSMiddleware) GetName() string {
|
||||
return "cors"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *CORSMiddleware) GetPriority() int {
|
||||
return 100 // 高优先级,最先执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.config.Development.EnableCors {
|
||||
// 如果没有启用CORS,返回空处理函数
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
config := cors.Config{
|
||||
AllowAllOrigins: false,
|
||||
AllowOrigins: m.getAllowedOrigins(),
|
||||
AllowMethods: m.getAllowedMethods(),
|
||||
AllowHeaders: m.getAllowedHeaders(),
|
||||
ExposeHeaders: []string{
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 86400, // 24小时
|
||||
}
|
||||
|
||||
return cors.New(config)
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *CORSMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getAllowedOrigins 获取允许的来源
|
||||
func (m *CORSMiddleware) getAllowedOrigins() []string {
|
||||
if m.config.Development.CorsOrigins == "" {
|
||||
return []string{"http://localhost:3000", "http://localhost:8080"}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的origins字符串
|
||||
return []string{m.config.Development.CorsOrigins}
|
||||
}
|
||||
|
||||
// getAllowedMethods 获取允许的方法
|
||||
func (m *CORSMiddleware) getAllowedMethods() []string {
|
||||
if m.config.Development.CorsMethods == "" {
|
||||
return []string{
|
||||
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的methods字符串
|
||||
return []string{m.config.Development.CorsMethods}
|
||||
}
|
||||
|
||||
// getAllowedHeaders 获取允许的头部
|
||||
func (m *CORSMiddleware) getAllowedHeaders() []string {
|
||||
if m.config.Development.CorsHeaders == "" {
|
||||
return []string{
|
||||
"Origin",
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"Accept",
|
||||
"Accept-Encoding",
|
||||
"Accept-Language",
|
||||
"X-Request-ID",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的headers字符串
|
||||
return []string{m.config.Development.CorsHeaders}
|
||||
}
|
||||
166
internal/shared/middleware/ratelimit.go
Normal file
166
internal/shared/middleware/ratelimit.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RateLimitMiddleware) GetName() string {
|
||||
return "ratelimit"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RateLimitMiddleware) GetPriority() int {
|
||||
return 90 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取客户端标识(IP地址)
|
||||
clientID := m.getClientID(c)
|
||||
|
||||
// 获取或创建限流器
|
||||
limiter := m.getLimiter(clientID)
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
c.Header("Retry-After", "60")
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"success": false,
|
||||
"message": "Rate limit exceeded",
|
||||
"error": "Too many requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RateLimitMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getClientID 获取客户端标识
|
||||
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
|
||||
// 优先使用X-Forwarded-For头部
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
return xff
|
||||
}
|
||||
|
||||
// 使用X-Real-IP头部
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// 使用RemoteAddr
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// getLimiter 获取或创建限流器
|
||||
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
|
||||
m.mutex.RLock()
|
||||
limiter, exists := m.limiters[clientID]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists := m.limiters[clientID]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 创建新的限流器
|
||||
// rate.Every计算每个请求之间的间隔
|
||||
rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
|
||||
limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst)
|
||||
|
||||
m.limiters[clientID] = limiter
|
||||
|
||||
// 启动清理协程(仅第一次创建时)
|
||||
if len(m.limiters) == 1 {
|
||||
go m.cleanupRoutine()
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupRoutine 定期清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanupRoutine() {
|
||||
ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanup() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for clientID, limiter := range m.limiters {
|
||||
// 如果限流器在过去1小时内没有被使用,则删除它
|
||||
if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour {
|
||||
delete(m.limiters, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取限流统计
|
||||
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"active_limiters": len(m.limiters),
|
||||
"rate_limit": map[string]interface{}{
|
||||
"requests": m.config.RateLimit.Requests,
|
||||
"window": m.config.RateLimit.Window,
|
||||
"burst": m.config.RateLimit.Burst,
|
||||
},
|
||||
}
|
||||
}
|
||||
241
internal/shared/middleware/request_logger.go
Normal file
241
internal/shared/middleware/request_logger.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RequestLoggerMiddleware 请求日志中间件
|
||||
type RequestLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRequestLoggerMiddleware 创建请求日志中间件
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
|
||||
return &RequestLoggerMiddleware{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestLoggerMiddleware) GetName() string {
|
||||
return "request_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestLoggerMiddleware) GetPriority() int {
|
||||
return 80 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 使用zap logger记录请求信息
|
||||
m.logger.Info("HTTP Request",
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
)
|
||||
|
||||
// 返回空字符串,因为我们已经用zap记录了
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestLoggerMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestIDMiddleware 请求ID中间件
|
||||
type RequestIDMiddleware struct{}
|
||||
|
||||
// NewRequestIDMiddleware 创建请求ID中间件
|
||||
func NewRequestIDMiddleware() *RequestIDMiddleware {
|
||||
return &RequestIDMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestIDMiddleware) GetName() string {
|
||||
return "request_id"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestIDMiddleware) GetPriority() int {
|
||||
return 95 // 最高优先级,第一个执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取或生成请求ID
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 设置请求ID到上下文和响应头
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
// 添加到响应头,方便客户端追踪
|
||||
c.Writer.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware 安全头部中间件
|
||||
type SecurityHeadersMiddleware struct{}
|
||||
|
||||
// NewSecurityHeadersMiddleware 创建安全头部中间件
|
||||
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
|
||||
return &SecurityHeadersMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *SecurityHeadersMiddleware) GetName() string {
|
||||
return "security_headers"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *SecurityHeadersMiddleware) GetPriority() int {
|
||||
return 85 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置安全头部
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Content-Security-Policy", "default-src 'self'")
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ResponseTimeMiddleware 响应时间中间件
|
||||
type ResponseTimeMiddleware struct{}
|
||||
|
||||
// NewResponseTimeMiddleware 创建响应时间中间件
|
||||
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
|
||||
return &ResponseTimeMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ResponseTimeMiddleware) GetName() string {
|
||||
return "response_time"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ResponseTimeMiddleware) GetPriority() int {
|
||||
return 75 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
// 计算响应时间并添加到头部
|
||||
duration := time.Since(start)
|
||||
c.Header("X-Response-Time", duration.String())
|
||||
|
||||
// 记录到上下文中,供其他中间件使用
|
||||
c.Set("response_time", duration)
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ResponseTimeMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestBodyLoggerMiddleware 请求体日志中间件(用于调试)
|
||||
type RequestBodyLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
enable bool
|
||||
}
|
||||
|
||||
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
|
||||
return &RequestBodyLoggerMiddleware{
|
||||
logger: logger,
|
||||
enable: enable,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestBodyLoggerMiddleware) GetName() string {
|
||||
return "request_body_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
|
||||
return 70 // 较低优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.enable {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 只记录POST, PUT, PATCH请求的body
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err == nil {
|
||||
// 重新设置body供后续处理使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
|
||||
m.logger.Debug("Request Body",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("body", string(bodyBytes)),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
|
||||
return false // 可选中间件,不是全局的
|
||||
}
|
||||
Reference in New Issue
Block a user