499 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			499 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package middleware
 | ||
| 
 | ||
| import (
 | ||
| 	"bytes"
 | ||
| 	"context"
 | ||
| 	"fmt"
 | ||
| 	"io"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"github.com/gin-gonic/gin"
 | ||
| 	"github.com/google/uuid"
 | ||
| 	"go.opentelemetry.io/otel/attribute"
 | ||
| 	"go.opentelemetry.io/otel/trace"
 | ||
| 	"go.uber.org/zap"
 | ||
| 
 | ||
| 	"tyapi-server/internal/shared/tracing"
 | ||
| )
 | ||
| 
 | ||
| // RequestLoggerMiddleware 请求日志中间件
 | ||
| type RequestLoggerMiddleware struct {
 | ||
| 	logger        *zap.Logger
 | ||
| 	useColoredLog bool
 | ||
| 	isDevelopment bool
 | ||
| 	tracer        *tracing.Tracer
 | ||
| }
 | ||
| 
 | ||
| // NewRequestLoggerMiddleware 创建请求日志中间件
 | ||
| func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
 | ||
| 	return &RequestLoggerMiddleware{
 | ||
| 		logger:        logger,
 | ||
| 		useColoredLog: isDevelopment, // 开发环境使用彩色日志
 | ||
| 		isDevelopment: isDevelopment,
 | ||
| 		tracer:        tracer,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *RequestLoggerMiddleware) GetName() string {
 | ||
| 	return "request_logger"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *RequestLoggerMiddleware) GetPriority() int {
 | ||
| 	return 80 // 中等优先级
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	if m.useColoredLog {
 | ||
| 		// 开发环境:使用Gin默认的彩色日志格式
 | ||
| 		return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 | ||
| 			var statusColor, methodColor, resetColor string
 | ||
| 			if param.IsOutputColor() {
 | ||
| 				statusColor = param.StatusCodeColor()
 | ||
| 				methodColor = param.MethodColor()
 | ||
| 				resetColor = param.ResetColor()
 | ||
| 			}
 | ||
| 
 | ||
| 			if param.Latency > time.Minute {
 | ||
| 				param.Latency = param.Latency.Truncate(time.Second)
 | ||
| 			}
 | ||
| 
 | ||
| 			// 获取TraceID
 | ||
| 			traceID := param.Request.Header.Get("X-Trace-ID")
 | ||
| 			if traceID == "" && m.tracer != nil {
 | ||
| 				traceID = m.tracer.GetTraceID(param.Request.Context())
 | ||
| 			}
 | ||
| 
 | ||
| 			// 检查是否为错误响应
 | ||
| 			if param.StatusCode >= 400 && m.tracer != nil {
 | ||
| 				span := trace.SpanFromContext(param.Request.Context())
 | ||
| 				if span.IsRecording() {
 | ||
| 					// 标记为错误操作,确保100%采样
 | ||
| 					span.SetAttributes(
 | ||
| 						attribute.String("error.operation", "true"),
 | ||
| 						attribute.String("operation.type", "error"),
 | ||
| 					)
 | ||
| 				}
 | ||
| 			}
 | ||
| 
 | ||
| 			traceInfo := ""
 | ||
| 			if traceID != "" {
 | ||
| 				traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
 | ||
| 			}
 | ||
| 
 | ||
| 			return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
 | ||
| 				param.TimeStamp.Format("2006/01/02 - 15:04:05"),
 | ||
| 				statusColor, param.StatusCode, resetColor,
 | ||
| 				param.Latency,
 | ||
| 				param.ClientIP,
 | ||
| 				methodColor, param.Method, resetColor,
 | ||
| 				param.Path,
 | ||
| 				traceInfo,
 | ||
| 				param.ErrorMessage,
 | ||
| 			)
 | ||
| 		})
 | ||
| 	} else {
 | ||
| 		// 生产环境:使用结构化JSON日志
 | ||
| 		return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
 | ||
| 			// 获取TraceID
 | ||
| 			traceID := param.Request.Header.Get("X-Trace-ID")
 | ||
| 			if traceID == "" && m.tracer != nil {
 | ||
| 				traceID = m.tracer.GetTraceID(param.Request.Context())
 | ||
| 			}
 | ||
| 
 | ||
| 			// 检查是否为错误响应
 | ||
| 			if param.StatusCode >= 400 && m.tracer != nil {
 | ||
| 				span := trace.SpanFromContext(param.Request.Context())
 | ||
| 				if span.IsRecording() {
 | ||
| 					// 标记为错误操作,确保100%采样
 | ||
| 					span.SetAttributes(
 | ||
| 						attribute.String("error.operation", "true"),
 | ||
| 						attribute.String("operation.type", "error"),
 | ||
| 					)
 | ||
| 
 | ||
| 					// 对于服务器错误,记录更详细的日志
 | ||
| 					if param.StatusCode >= 500 {
 | ||
| 						m.logger.Error("服务器错误",
 | ||
| 							zap.Int("status_code", param.StatusCode),
 | ||
| 							zap.String("method", param.Method),
 | ||
| 							zap.String("path", param.Path),
 | ||
| 							zap.Duration("latency", param.Latency),
 | ||
| 							zap.String("client_ip", param.ClientIP),
 | ||
| 							zap.String("trace_id", traceID),
 | ||
| 						)
 | ||
| 					}
 | ||
| 				}
 | ||
| 			}
 | ||
| 
 | ||
| 			// 记录请求日志
 | ||
| 			logFields := []zap.Field{
 | ||
| 				zap.String("client_ip", param.ClientIP),
 | ||
| 				zap.String("method", param.Method),
 | ||
| 				zap.String("path", param.Path),
 | ||
| 				zap.String("protocol", param.Request.Proto),
 | ||
| 				zap.Int("status_code", param.StatusCode),
 | ||
| 				zap.Duration("latency", param.Latency),
 | ||
| 				zap.String("user_agent", param.Request.UserAgent()),
 | ||
| 				zap.Int("body_size", param.BodySize),
 | ||
| 				zap.String("referer", param.Request.Referer()),
 | ||
| 				zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
 | ||
| 			}
 | ||
| 
 | ||
| 			// 添加TraceID
 | ||
| 			if traceID != "" {
 | ||
| 				logFields = append(logFields, zap.String("trace_id", traceID))
 | ||
| 			}
 | ||
| 
 | ||
| 			m.logger.Info("HTTP请求", logFields...)
 | ||
| 			return ""
 | ||
| 		})
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *RequestLoggerMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 | ||
| 
 | ||
| // RequestIDMiddleware 请求ID中间件
 | ||
| type RequestIDMiddleware struct{}
 | ||
| 
 | ||
| // NewRequestIDMiddleware 创建请求ID中间件
 | ||
| func NewRequestIDMiddleware() *RequestIDMiddleware {
 | ||
| 	return &RequestIDMiddleware{}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *RequestIDMiddleware) GetName() string {
 | ||
| 	return "request_id"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *RequestIDMiddleware) GetPriority() int {
 | ||
| 	return 95 // 最高优先级,第一个执行
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 获取或生成请求ID
 | ||
| 		requestID := c.GetHeader("X-Request-ID")
 | ||
| 		if requestID == "" {
 | ||
| 			requestID = uuid.New().String()
 | ||
| 		}
 | ||
| 
 | ||
| 		// 设置请求ID到上下文和响应头
 | ||
| 		c.Set("request_id", requestID)
 | ||
| 		c.Header("X-Request-ID", requestID)
 | ||
| 
 | ||
| 		// 添加到响应头,方便客户端追踪
 | ||
| 		c.Writer.Header().Set("X-Request-ID", requestID)
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *RequestIDMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 | ||
| 
 | ||
| // TraceIDMiddleware 追踪ID中间件
 | ||
| type TraceIDMiddleware struct {
 | ||
| 	tracer *tracing.Tracer
 | ||
| }
 | ||
| 
 | ||
| // NewTraceIDMiddleware 创建追踪ID中间件
 | ||
| func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
 | ||
| 	return &TraceIDMiddleware{
 | ||
| 		tracer: tracer,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *TraceIDMiddleware) GetName() string {
 | ||
| 	return "trace_id"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *TraceIDMiddleware) GetPriority() int {
 | ||
| 	return 94 // 仅次于请求ID中间件
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 获取或生成追踪ID
 | ||
| 		traceID := m.tracer.GetTraceID(c.Request.Context())
 | ||
| 		if traceID != "" {
 | ||
| 			// 设置追踪ID到响应头
 | ||
| 			c.Header("X-Trace-ID", traceID)
 | ||
| 			// 添加到上下文
 | ||
| 			c.Set("trace_id", traceID)
 | ||
| 		}
 | ||
| 
 | ||
| 		// 检查是否为错误请求(例如URL不存在)
 | ||
| 		c.Next()
 | ||
| 
 | ||
| 		// 请求完成后检查状态码
 | ||
| 		if c.Writer.Status() >= 400 {
 | ||
| 			// 获取当前span
 | ||
| 			span := trace.SpanFromContext(c.Request.Context())
 | ||
| 			if span.IsRecording() {
 | ||
| 				// 标记为错误操作,确保100%采样
 | ||
| 				span.SetAttributes(
 | ||
| 					attribute.String("error.operation", "true"),
 | ||
| 					attribute.String("operation.type", "error"),
 | ||
| 				)
 | ||
| 
 | ||
| 				// 设置错误上下文,以便后续span可以识别
 | ||
| 				c.Request = c.Request.WithContext(context.WithValue(
 | ||
| 					c.Request.Context(),
 | ||
| 					"otel_error_request",
 | ||
| 					true,
 | ||
| 				))
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *TraceIDMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 | ||
| 
 | ||
| // SecurityHeadersMiddleware 安全头部中间件
 | ||
| type SecurityHeadersMiddleware struct{}
 | ||
| 
 | ||
| // NewSecurityHeadersMiddleware 创建安全头部中间件
 | ||
| func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
 | ||
| 	return &SecurityHeadersMiddleware{}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *SecurityHeadersMiddleware) GetName() string {
 | ||
| 	return "security_headers"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *SecurityHeadersMiddleware) GetPriority() int {
 | ||
| 	return 85 // 高优先级
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 设置安全头部
 | ||
| 		c.Header("X-Content-Type-Options", "nosniff")
 | ||
| 		c.Header("X-Frame-Options", "DENY")
 | ||
| 		c.Header("X-XSS-Protection", "1; mode=block")
 | ||
| 		c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
 | ||
| 		c.Header("Content-Security-Policy", "default-src 'self'")
 | ||
| 		c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *SecurityHeadersMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 | ||
| 
 | ||
| // ResponseTimeMiddleware 响应时间中间件
 | ||
| type ResponseTimeMiddleware struct{}
 | ||
| 
 | ||
| // NewResponseTimeMiddleware 创建响应时间中间件
 | ||
| func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
 | ||
| 	return &ResponseTimeMiddleware{}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *ResponseTimeMiddleware) GetName() string {
 | ||
| 	return "response_time"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *ResponseTimeMiddleware) GetPriority() int {
 | ||
| 	return 75 // 中等优先级
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		start := time.Now()
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 
 | ||
| 		// 计算响应时间并添加到头部
 | ||
| 		duration := time.Since(start)
 | ||
| 		c.Header("X-Response-Time", duration.String())
 | ||
| 
 | ||
| 		// 记录到上下文中,供其他中间件使用
 | ||
| 		c.Set("response_time", duration)
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *ResponseTimeMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 | ||
| 
 | ||
| // RequestBodyLoggerMiddleware 请求体日志中间件(用于调试)
 | ||
| type RequestBodyLoggerMiddleware struct {
 | ||
| 	logger *zap.Logger
 | ||
| 	enable bool
 | ||
| 	tracer *tracing.Tracer
 | ||
| }
 | ||
| 
 | ||
| // NewRequestBodyLoggerMiddleware 创建请求体日志中间件
 | ||
| func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
 | ||
| 	return &RequestBodyLoggerMiddleware{
 | ||
| 		logger: logger,
 | ||
| 		enable: enable,
 | ||
| 		tracer: tracer,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *RequestBodyLoggerMiddleware) GetName() string {
 | ||
| 	return "request_body_logger"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *RequestBodyLoggerMiddleware) GetPriority() int {
 | ||
| 	return 70 // 较低优先级
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	if !m.enable {
 | ||
| 		return func(c *gin.Context) {
 | ||
| 			c.Next()
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		// 只记录POST, PUT, PATCH请求的body
 | ||
| 		if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
 | ||
| 			if c.Request.Body != nil {
 | ||
| 				bodyBytes, err := io.ReadAll(c.Request.Body)
 | ||
| 				if err == nil {
 | ||
| 					// 重新设置body供后续处理使用
 | ||
| 					c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
 | ||
| 
 | ||
| 					// 获取追踪ID
 | ||
| 					traceID := ""
 | ||
| 					if m.tracer != nil {
 | ||
| 						traceID = m.tracer.GetTraceID(c.Request.Context())
 | ||
| 					}
 | ||
| 
 | ||
| 					// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
 | ||
| 					logFields := []zap.Field{
 | ||
| 						zap.String("method", c.Request.Method),
 | ||
| 						zap.String("path", c.Request.URL.Path),
 | ||
| 						zap.String("body", string(bodyBytes)),
 | ||
| 						zap.String("request_id", c.GetString("request_id")),
 | ||
| 					}
 | ||
| 
 | ||
| 					// 添加追踪ID
 | ||
| 					if traceID != "" {
 | ||
| 						logFields = append(logFields, zap.String("trace_id", traceID))
 | ||
| 					}
 | ||
| 
 | ||
| 					m.logger.Debug("请求体", logFields...)
 | ||
| 				}
 | ||
| 			}
 | ||
| 		}
 | ||
| 
 | ||
| 		c.Next()
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
 | ||
| 	return false // 可选中间件,不是全局的
 | ||
| }
 | ||
| 
 | ||
| // ErrorTrackingMiddleware 错误追踪中间件
 | ||
| type ErrorTrackingMiddleware struct {
 | ||
| 	logger *zap.Logger
 | ||
| 	tracer *tracing.Tracer
 | ||
| }
 | ||
| 
 | ||
| // NewErrorTrackingMiddleware 创建错误追踪中间件
 | ||
| func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
 | ||
| 	return &ErrorTrackingMiddleware{
 | ||
| 		logger: logger,
 | ||
| 		tracer: tracer,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *ErrorTrackingMiddleware) GetName() string {
 | ||
| 	return "error_tracking"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *ErrorTrackingMiddleware) GetPriority() int {
 | ||
| 	return 60 // 低优先级,在大多数中间件之后执行
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	return func(c *gin.Context) {
 | ||
| 		c.Next()
 | ||
| 
 | ||
| 		// 检查是否有错误
 | ||
| 		if len(c.Errors) > 0 || c.Writer.Status() >= 400 {
 | ||
| 			// 获取当前span
 | ||
| 			span := trace.SpanFromContext(c.Request.Context())
 | ||
| 			if span.IsRecording() {
 | ||
| 				// 标记为错误操作,确保100%采样
 | ||
| 				span.SetAttributes(
 | ||
| 					attribute.String("error.operation", "true"),
 | ||
| 					attribute.String("operation.type", "error"),
 | ||
| 				)
 | ||
| 
 | ||
| 				// 记录错误日志
 | ||
| 				traceID := m.tracer.GetTraceID(c.Request.Context())
 | ||
| 				spanID := m.tracer.GetSpanID(c.Request.Context())
 | ||
| 
 | ||
| 				logFields := []zap.Field{
 | ||
| 					zap.Int("status_code", c.Writer.Status()),
 | ||
| 					zap.String("method", c.Request.Method),
 | ||
| 					zap.String("path", c.FullPath()),
 | ||
| 					zap.String("client_ip", c.ClientIP()),
 | ||
| 				}
 | ||
| 
 | ||
| 				// 添加追踪信息
 | ||
| 				if traceID != "" {
 | ||
| 					logFields = append(logFields, zap.String("trace_id", traceID))
 | ||
| 				}
 | ||
| 				if spanID != "" {
 | ||
| 					logFields = append(logFields, zap.String("span_id", spanID))
 | ||
| 				}
 | ||
| 
 | ||
| 				// 添加错误信息
 | ||
| 				if len(c.Errors) > 0 {
 | ||
| 					logFields = append(logFields, zap.String("errors", c.Errors.String()))
 | ||
| 				}
 | ||
| 
 | ||
| 				// 根据状态码决定日志级别
 | ||
| 				if c.Writer.Status() >= 500 {
 | ||
| 					m.logger.Error("服务器错误", logFields...)
 | ||
| 				} else {
 | ||
| 					m.logger.Warn("客户端错误", logFields...)
 | ||
| 				}
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *ErrorTrackingMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 |