package middleware

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"

	"tyapi-server/internal/shared/tracing"
)

// markRequestSpanAsError tags the span carried by ctx with the
// "error.operation" and "operation.type" attributes and reports whether the
// span was recording (when it is not, nothing is changed).
//
// NOTE(review): the original comments say these attributes guarantee 100%
// sampling of failed requests; the sampler that would enforce that is not
// part of this file — confirm against the tracing package.
func markRequestSpanAsError(ctx context.Context) bool {
	span := trace.SpanFromContext(ctx)
	if !span.IsRecording() {
		return false
	}
	span.SetAttributes(
		attribute.String("error.operation", "true"),
		attribute.String("operation.type", "error"),
	)
	return true
}

// RequestLoggerMiddleware logs every HTTP request, either as a colored
// single-line message (development) or as a structured zap entry (production).
type RequestLoggerMiddleware struct {
	logger        *zap.Logger
	useColoredLog bool
	isDevelopment bool
	tracer        *tracing.Tracer
}

// NewRequestLoggerMiddleware builds a request logger. Colored output is
// enabled exactly when isDevelopment is true. tracer may be nil.
func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
	return &RequestLoggerMiddleware{
		logger:        logger,
		useColoredLog: isDevelopment,
		isDevelopment: isDevelopment,
		tracer:        tracer,
	}
}

// GetName returns the middleware name.
func (m *RequestLoggerMiddleware) GetName() string {
	return "request_logger"
}

// GetPriority returns the middleware priority (80).
func (m *RequestLoggerMiddleware) GetPriority() int {
	return 80
}

// Handle returns the gin handler: the colored formatter when useColoredLog is
// set, the structured zap formatter otherwise.
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
	if m.useColoredLog {
		return gin.LoggerWithFormatter(m.formatColored)
	}
	return gin.LoggerWithFormatter(m.formatStructured)
}

// traceIDFor resolves the trace ID for a finished request: the incoming
// X-Trace-ID header wins, otherwise the tracer (if any) is asked for the ID
// bound to the request context. Returns "" when neither source yields one.
func (m *RequestLoggerMiddleware) traceIDFor(param gin.LogFormatterParams) string {
	if id := param.Request.Header.Get("X-Trace-ID"); id != "" {
		return id
	}
	if m.tracer == nil {
		return ""
	}
	return m.tracer.GetTraceID(param.Request.Context())
}

// formatColored renders the gin-style colored log line used in development.
func (m *RequestLoggerMiddleware) formatColored(param gin.LogFormatterParams) string {
	var statusColor, methodColor, resetColor string
	if param.IsOutputColor() {
		statusColor = param.StatusCodeColor()
		methodColor = param.MethodColor()
		resetColor = param.ResetColor()
	}

	// Sub-second precision is noise once a request has taken over a minute.
	if param.Latency > time.Minute {
		param.Latency = param.Latency.Truncate(time.Second)
	}

	traceID := m.traceIDFor(param)

	// Error responses are flagged on the span, but only when tracing is configured.
	if param.StatusCode >= 400 && m.tracer != nil {
		markRequestSpanAsError(param.Request.Context())
	}

	traceInfo := ""
	if traceID != "" {
		traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
	}

	return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
		param.TimeStamp.Format("2006/01/02 - 15:04:05"),
		statusColor, param.StatusCode, resetColor,
		param.Latency,
		param.ClientIP,
		methodColor, param.Method, resetColor,
		param.Path,
		traceInfo,
		param.ErrorMessage,
	)
}

// formatStructured emits the request as a structured zap entry and returns an
// empty string, so gin itself writes nothing for the request.
func (m *RequestLoggerMiddleware) formatStructured(param gin.LogFormatterParams) string {
	traceID := m.traceIDFor(param)

	// The extra server-error entry is only written when the span is recording,
	// which also requires a configured tracer.
	if param.StatusCode >= 400 && m.tracer != nil {
		if markRequestSpanAsError(param.Request.Context()) && param.StatusCode >= 500 {
			m.logger.Error("服务器错误",
				zap.Int("status_code", param.StatusCode),
				zap.String("method", param.Method),
				zap.String("path", param.Path),
				zap.Duration("latency", param.Latency),
				zap.String("client_ip", param.ClientIP),
				zap.String("trace_id", traceID),
			)
		}
	}

	logFields := []zap.Field{
		zap.String("client_ip", param.ClientIP),
		zap.String("method", param.Method),
		zap.String("path", param.Path),
		zap.String("protocol", param.Request.Proto),
		zap.Int("status_code", param.StatusCode),
		zap.Duration("latency", param.Latency),
		zap.String("user_agent", param.Request.UserAgent()),
		zap.Int("body_size", param.BodySize),
		zap.String("referer", param.Request.Referer()),
		zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
	}
	if traceID != "" {
		logFields = append(logFields, zap.String("trace_id", traceID))
	}

	m.logger.Info("HTTP请求", logFields...)
	return ""
}

// IsGlobal reports that this middleware applies to every route.
func (m *RequestLoggerMiddleware) IsGlobal() bool {
	return true
}

// RequestIDMiddleware assigns an X-Request-ID to every request.
type RequestIDMiddleware struct{}

// NewRequestIDMiddleware creates the request-ID middleware.
func NewRequestIDMiddleware() *RequestIDMiddleware {
	return &RequestIDMiddleware{}
}

// GetName returns the middleware name.
func (m *RequestIDMiddleware) GetName() string {
	return "request_id"
}

// GetPriority returns the middleware priority (95, the highest value used in
// this file).
func (m *RequestIDMiddleware) GetPriority() int {
	return 95
}

// Handle returns a handler that reuses the incoming X-Request-ID header or
// generates a UUID, stores it in the gin context under "request_id" and echoes
// it in the X-Request-ID response header.
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		requestID := c.GetHeader("X-Request-ID")
		if requestID == "" {
			requestID = uuid.New().String()
		}

		c.Set("request_id", requestID)
		// requestID is never empty here, so c.Header sets the response header;
		// a second direct write to c.Writer.Header() would be redundant.
		c.Header("X-Request-ID", requestID)

		c.Next()
	}
}

// IsGlobal reports that this middleware applies to every route.
func (m *RequestIDMiddleware) IsGlobal() bool {
	return true
}

// TraceIDMiddleware exposes the current trace ID to clients and handlers and
// flags the span of failed requests.
type TraceIDMiddleware struct {
	tracer *tracing.Tracer
}

// NewTraceIDMiddleware creates the trace-ID middleware. tracer may be nil, in
// which case no trace ID is published.
func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
	return &TraceIDMiddleware{
		tracer: tracer,
	}
}

// GetName returns the middleware name.
func (m *TraceIDMiddleware) GetName() string {
	return "trace_id"
}

// GetPriority returns the middleware priority (94, just below the request-ID
// middleware).
func (m *TraceIDMiddleware) GetPriority() int {
	return 94
}

// Handle returns a handler that publishes the trace ID (response header
// X-Trace-ID and context key "trace_id") before the request runs and, once it
// has finished with a status >= 400, marks the recording span as an error.
func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		// A nil tracer is tolerated, matching the other middlewares in this file.
		traceID := ""
		if m.tracer != nil {
			traceID = m.tracer.GetTraceID(c.Request.Context())
		}
		if traceID != "" {
			c.Header("X-Trace-ID", traceID)
			c.Set("trace_id", traceID)
		}

		c.Next()

		if c.Writer.Status() < 400 {
			return
		}
		if !markRequestSpanAsError(c.Request.Context()) {
			return
		}

		// The plain string key is kept on purpose: readers of this value are not
		// visible in this file and may look it up with the same string. Linters
		// flag string keys in context.WithValue; switch to a typed key only
		// together with those readers.
		c.Request = c.Request.WithContext(context.WithValue(
			c.Request.Context(),
			"otel_error_request",
			true,
		))
	}
}

// IsGlobal reports that this middleware applies to every route.
func (m *TraceIDMiddleware) IsGlobal() bool {
	return true
}

// SecurityHeadersMiddleware adds a fixed set of security response headers.
type SecurityHeadersMiddleware struct{}

// NewSecurityHeadersMiddleware creates the security-headers middleware.
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
	return &SecurityHeadersMiddleware{}
}

// GetName returns the middleware name.
func (m *SecurityHeadersMiddleware) GetName() string {
	return "security_headers"
}

// GetPriority returns the middleware priority (85).
func (m *SecurityHeadersMiddleware) GetPriority() int {
	return 85
}

// Handle returns a handler that sets the security headers before the request
// is processed.
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Header("X-Content-Type-Options", "nosniff")
		c.Header("X-Frame-Options", "DENY")
		c.Header("X-XSS-Protection", "1; mode=block")
		c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
		c.Header("Content-Security-Policy", "default-src 'self'")
		c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")

		c.Next()
	}
}

// IsGlobal reports that this middleware applies to every route.
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
	return true
}

// ResponseTimeMiddleware measures how long the rest of the handler chain takes.
type ResponseTimeMiddleware struct{}

// NewResponseTimeMiddleware creates the response-time middleware.
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
	return &ResponseTimeMiddleware{}
}

// GetName returns the middleware name.
func (m *ResponseTimeMiddleware) GetName() string {
	return "response_time"
}

// GetPriority returns the middleware priority (75).
func (m *ResponseTimeMiddleware) GetPriority() int {
	return 75
}

// Handle returns a handler that times c.Next(), then stores the duration in
// the X-Response-Time header and under the "response_time" context key.
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		start := time.Now()

		c.Next()

		duration := time.Since(start)
		// NOTE(review): this runs after the downstream handlers. If they have
		// already written the response, the header set here can no longer reach
		// the client; only the context value below stays useful to outer
		// middlewares.
		c.Header("X-Response-Time", duration.String())
		c.Set("response_time", duration)
	}
}

// IsGlobal reports that this middleware applies to every route.
func (m *ResponseTimeMiddleware) IsGlobal() bool {
	return true
}

// RequestBodyLoggerMiddleware logs request bodies at debug level (debugging aid).
type RequestBodyLoggerMiddleware struct {
	logger *zap.Logger
	enable bool
	tracer *tracing.Tracer
}

// NewRequestBodyLoggerMiddleware creates the body logger. When enable is false
// the resulting handler is a pass-through. tracer may be nil.
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
	return &RequestBodyLoggerMiddleware{
		logger: logger,
		enable: enable,
		tracer: tracer,
	}
}

// GetName returns the middleware name.
func (m *RequestBodyLoggerMiddleware) GetName() string {
	return "request_body_logger"
}

// GetPriority returns the middleware priority (70).
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
	return 70
}

// Handle returns a handler that logs the body of POST, PUT and PATCH requests
// and then continues the chain. Other methods, and every request when the
// middleware is disabled, pass through untouched.
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
	if !m.enable {
		return func(c *gin.Context) {
			c.Next()
		}
	}

	return func(c *gin.Context) {
		switch c.Request.Method {
		case "POST", "PUT", "PATCH":
			m.logBody(c)
		}
		c.Next()
	}
}

// logBody reads the whole request body into memory, restores it on the request
// so later handlers can read it again, and writes it to the debug log.
//
// The body is logged verbatim; be careful with sensitive payloads in
// production.
func (m *RequestBodyLoggerMiddleware) logBody(c *gin.Context) {
	original := c.Request.Body
	if original == nil {
		return
	}

	bodyBytes, err := io.ReadAll(original)
	if err != nil {
		// Put back what was already consumed, followed by the original reader,
		// so downstream handlers see the same bytes (and any further read
		// result) they would have seen without this middleware.
		c.Request.Body = io.NopCloser(io.MultiReader(bytes.NewReader(bodyBytes), original))
		m.logger.Warn("读取请求体失败",
			zap.String("method", c.Request.Method),
			zap.String("path", c.Request.URL.Path),
			zap.String("request_id", c.GetString("request_id")),
			zap.Error(err),
		)
		return
	}

	// Hand a fresh reader over the captured bytes to the rest of the chain.
	c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

	traceID := ""
	if m.tracer != nil {
		traceID = m.tracer.GetTraceID(c.Request.Context())
	}

	logFields := []zap.Field{
		zap.String("method", c.Request.Method),
		zap.String("path", c.Request.URL.Path),
		zap.String("body", string(bodyBytes)),
		zap.String("request_id", c.GetString("request_id")),
	}
	if traceID != "" {
		logFields = append(logFields, zap.String("trace_id", traceID))
	}

	m.logger.Debug("请求体", logFields...)
}

// IsGlobal reports that this middleware is optional and not applied globally.
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
	return false
}

// ErrorTrackingMiddleware flags the span of failed requests and logs them.
type ErrorTrackingMiddleware struct {
	logger *zap.Logger
	tracer *tracing.Tracer
}

// NewErrorTrackingMiddleware creates the error-tracking middleware. tracer may
// be nil, in which case log entries carry no trace or span ID.
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
	return &ErrorTrackingMiddleware{
		logger: logger,
		tracer: tracer,
	}
}

// GetName returns the middleware name.
func (m *ErrorTrackingMiddleware) GetName() string {
	return "error_tracking"
}

// GetPriority returns the middleware priority (60).
func (m *ErrorTrackingMiddleware) GetPriority() int {
	return 60
}

// Handle returns a handler that, after the chain has run, inspects gin errors
// and the response status. For a failed request whose span is recording it
// marks the span and logs at Error level (status >= 500) or Warn level
// (anything else).
//
// NOTE(review): nothing is logged when the span is not recording — confirm
// this coupling between logging and tracing is intended.
func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Next()

		status := c.Writer.Status()
		if len(c.Errors) == 0 && status < 400 {
			return
		}

		ctx := c.Request.Context()
		if !markRequestSpanAsError(ctx) {
			return
		}

		// A nil tracer is tolerated, matching the other middlewares in this file.
		var traceID, spanID string
		if m.tracer != nil {
			traceID = m.tracer.GetTraceID(ctx)
			spanID = m.tracer.GetSpanID(ctx)
		}

		logFields := []zap.Field{
			zap.Int("status_code", status),
			zap.String("method", c.Request.Method),
			zap.String("path", c.FullPath()),
			zap.String("client_ip", c.ClientIP()),
		}
		if traceID != "" {
			logFields = append(logFields, zap.String("trace_id", traceID))
		}
		if spanID != "" {
			logFields = append(logFields, zap.String("span_id", spanID))
		}
		if len(c.Errors) > 0 {
			logFields = append(logFields, zap.String("errors", c.Errors.String()))
		}

		if status >= 500 {
			m.logger.Error("服务器错误", logFields...)
			return
		}
		m.logger.Warn("客户端错误", logFields...)
	}
}

// IsGlobal reports that this middleware applies to every route.
func (m *ErrorTrackingMiddleware) IsGlobal() bool {
	return true
}