Files
tyapi-server/internal/shared/middleware/request_logger.go

499 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"tyapi-server/internal/shared/tracing"
)
// RequestLoggerMiddleware logs every HTTP request after it has completed.
// When useColoredLog is set it prints Gin-style coloured lines; otherwise it
// writes structured zap records.
type RequestLoggerMiddleware struct {
	logger        *zap.Logger
	useColoredLog bool
	isDevelopment bool
	tracer        *tracing.Tracer
}

// NewRequestLoggerMiddleware builds a RequestLoggerMiddleware.
//
// Coloured output is enabled exactly when isDevelopment is true. tracer may
// be nil; the handler checks for that before asking it for a trace ID.
func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
	m := new(RequestLoggerMiddleware)
	m.logger = logger
	m.tracer = tracer
	m.isDevelopment = isDevelopment
	// Development builds get the human-friendly coloured format.
	m.useColoredLog = isDevelopment
	return m
}

// GetName reports the identifier this middleware is registered under.
func (m *RequestLoggerMiddleware) GetName() string {
	const name = "request_logger"
	return name
}

// GetPriority reports this middleware's ordering priority (medium, 80).
// NOTE(review): ordering is decided by a registry outside this file; the
// request-ID middleware documents 95 as "runs first", suggesting higher runs earlier.
func (m *RequestLoggerMiddleware) GetPriority() int {
	const priority = 80
	return priority
}
// Handle returns the gin handler that logs each completed request.
//
// When useColoredLog is set, each request becomes one line in Gin's default
// coloured layout, with " | TraceID: <id>" appended when a trace ID is known.
// Otherwise a structured "HTTP请求" zap record is written. The formatter then
// returns "" so Gin's own writer prints nothing extra.
//
// In both modes, a response with status >= 400 tags the request's span as an
// error operation, provided a tracer is configured and the span is recording.
// In structured mode, a recording span with status >= 500 also produces an
// extra Error-level record.
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
	if m.useColoredLog {
		return gin.LoggerWithFormatter(m.formatColoredLog)
	}
	return gin.LoggerWithFormatter(m.formatStructuredLog)
}

// formatColoredLog renders one request in Gin's default coloured line format.
func (m *RequestLoggerMiddleware) formatColoredLog(param gin.LogFormatterParams) string {
	var statusColor, methodColor, resetColor string
	if param.IsOutputColor() {
		statusColor = param.StatusCodeColor()
		methodColor = param.MethodColor()
		resetColor = param.ResetColor()
	}
	// Same readability tweak as Gin's default formatter for very slow requests.
	if param.Latency > time.Minute {
		param.Latency = param.Latency.Truncate(time.Second)
	}

	traceID := m.resolveTraceID(&param)
	m.markErrorSpan(&param)

	traceInfo := ""
	if traceID != "" {
		traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
	}
	return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
		param.TimeStamp.Format("2006/01/02 - 15:04:05"),
		statusColor, param.StatusCode, resetColor,
		param.Latency,
		param.ClientIP,
		methodColor, param.Method, resetColor,
		param.Path,
		traceInfo,
		param.ErrorMessage,
	)
}

// formatStructuredLog writes the request as a structured zap record and
// returns "" so Gin prints nothing itself.
func (m *RequestLoggerMiddleware) formatStructuredLog(param gin.LogFormatterParams) string {
	traceID := m.resolveTraceID(&param)

	// Server errors on a recording span get an extra Error-level record.
	if m.markErrorSpan(&param) && param.StatusCode >= 500 {
		m.logger.Error("服务器错误",
			zap.Int("status_code", param.StatusCode),
			zap.String("method", param.Method),
			zap.String("path", param.Path),
			zap.Duration("latency", param.Latency),
			zap.String("client_ip", param.ClientIP),
			zap.String("trace_id", traceID),
		)
	}

	logFields := []zap.Field{
		zap.String("client_ip", param.ClientIP),
		zap.String("method", param.Method),
		zap.String("path", param.Path),
		zap.String("protocol", param.Request.Proto),
		zap.Int("status_code", param.StatusCode),
		zap.Duration("latency", param.Latency),
		zap.String("user_agent", param.Request.UserAgent()),
		zap.Int("body_size", param.BodySize),
		zap.String("referer", param.Request.Referer()),
		zap.String("request_id", m.resolveRequestID(&param)),
	}
	if traceID != "" {
		logFields = append(logFields, zap.String("trace_id", traceID))
	}
	m.logger.Info("HTTP请求", logFields...)
	return ""
}

// resolveTraceID returns the X-Trace-ID request header if present. Otherwise
// it returns the tracer's trace ID for the request context, or "" when no
// tracer is configured.
func (m *RequestLoggerMiddleware) resolveTraceID(param *gin.LogFormatterParams) string {
	if traceID := param.Request.Header.Get("X-Trace-ID"); traceID != "" {
		return traceID
	}
	if m.tracer != nil {
		return m.tracer.GetTraceID(param.Request.Context())
	}
	return ""
}

// resolveRequestID returns the ID stored under the gin key "request_id".
// RequestIDMiddleware stores it there, generating one when the client sent
// none, but does not write it back into the request header. The function
// falls back to the X-Request-ID request header when the key is absent.
func (m *RequestLoggerMiddleware) resolveRequestID(param *gin.LogFormatterParams) string {
	if id, ok := param.Keys["request_id"].(string); ok && id != "" {
		return id
	}
	return param.Request.Header.Get("X-Request-ID")
}

// markErrorSpan tags the request's span as an error operation when the status
// is >= 400, a tracer is configured, and the span is recording. It reports
// whether the span was tagged.
func (m *RequestLoggerMiddleware) markErrorSpan(param *gin.LogFormatterParams) bool {
	if param.StatusCode < 400 || m.tracer == nil {
		return false
	}
	span := trace.SpanFromContext(param.Request.Context())
	if !span.IsRecording() {
		return false
	}
	// Marked as an error operation so the sampler keeps 100% of error traces.
	span.SetAttributes(
		attribute.String("error.operation", "true"),
		attribute.String("operation.type", "error"),
	)
	return true
}
// IsGlobal reports that the request logger is registered as a global middleware.
func (m *RequestLoggerMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// RequestIDMiddleware gives every request an ID, stores it under the gin key
// "request_id", and echoes it in the X-Request-ID response header.
type RequestIDMiddleware struct{}

// NewRequestIDMiddleware returns a RequestIDMiddleware. The type carries no
// state, so the zero value is ready to use.
func NewRequestIDMiddleware() *RequestIDMiddleware {
	return new(RequestIDMiddleware)
}

// GetName reports the identifier this middleware is registered under.
func (m *RequestIDMiddleware) GetName() string {
	const name = "request_id"
	return name
}

// GetPriority reports this middleware's ordering priority. At 95 it is the
// highest in this file, and it is intended to run first.
func (m *RequestIDMiddleware) GetPriority() int {
	const priority = 95
	return priority
}
// Handle returns the gin handler that assigns the request ID.
//
// A client-supplied X-Request-ID is reused only if it is non-empty, at most
// maxRequestIDLen bytes, and made of visible ASCII characters. Anything else
// is replaced with a random UUID, so untrusted input cannot bloat or corrupt
// logs and response headers. The ID is stored under the gin key "request_id"
// and echoed in the X-Request-ID response header for client-side correlation.
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
	const maxRequestIDLen = 128

	acceptable := func(id string) bool {
		if id == "" || len(id) > maxRequestIDLen {
			return false
		}
		for i := 0; i < len(id); i++ {
			// Visible ASCII only: rejects spaces, control characters and non-ASCII bytes.
			if id[i] < 0x21 || id[i] > 0x7e {
				return false
			}
		}
		return true
	}

	return func(c *gin.Context) {
		requestID := c.GetHeader("X-Request-ID")
		if !acceptable(requestID) {
			requestID = uuid.New().String()
		}

		c.Set("request_id", requestID)
		// c.Header already sets c.Writer.Header(); a second explicit Set was redundant.
		c.Header("X-Request-ID", requestID)

		c.Next()
	}
}
// IsGlobal reports that the request-ID middleware is registered as a global middleware.
func (m *RequestIDMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// TraceIDMiddleware exposes the current trace ID to clients through the
// X-Trace-ID response header and to handlers through the gin key "trace_id".
type TraceIDMiddleware struct {
	tracer *tracing.Tracer
}

// NewTraceIDMiddleware builds a TraceIDMiddleware that reads trace IDs from tracer.
func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
	m := new(TraceIDMiddleware)
	m.tracer = tracer
	return m
}

// GetName reports the identifier this middleware is registered under.
func (m *TraceIDMiddleware) GetName() string {
	const name = "trace_id"
	return name
}

// GetPriority reports this middleware's ordering priority (94). That is just
// below the request-ID middleware's 95.
func (m *TraceIDMiddleware) GetPriority() int {
	const priority = 94
	return priority
}
// Handle returns the gin handler that publishes the trace ID and flags error traces.
//
// Before the handlers run, the tracer's trace ID for the request context is
// written to the X-Trace-ID response header and the gin key "trace_id". This
// is skipped when no tracer is configured or the ID is empty. The original
// code called the tracer unconditionally; this version guards against a nil
// tracer, as the other middlewares in this file do. The middleware only
// reads an existing ID; it does not create one.
//
// After the handlers run, a status >= 400 on a recording span tags the span
// as an error operation. The request context also gets the value
// "otel_error_request"=true.
func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		if m.tracer != nil {
			if traceID := m.tracer.GetTraceID(c.Request.Context()); traceID != "" {
				c.Header("X-Trace-ID", traceID)
				c.Set("trace_id", traceID)
			}
		}

		c.Next()

		// Error responses include unmatched URLs (404).
		if c.Writer.Status() < 400 {
			return
		}
		span := trace.SpanFromContext(c.Request.Context())
		if !span.IsRecording() {
			return
		}
		// Marked as an error operation so the sampler keeps 100% of error traces.
		span.SetAttributes(
			attribute.String("error.operation", "true"),
			attribute.String("operation.type", "error"),
		)
		// Inner handlers have already returned at this point, so only code that
		// runs after this middleware returns can observe this flag.
		// NOTE(review): a plain string context key is flagged by staticcheck
		// (SA1029). It is kept because other code may look up "otel_error_request".
		c.Request = c.Request.WithContext(context.WithValue(
			c.Request.Context(),
			"otel_error_request",
			true,
		))
	}
}
// IsGlobal reports that the trace-ID middleware is registered as a global middleware.
func (m *TraceIDMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// SecurityHeadersMiddleware adds a fixed set of security-related response
// headers to every request it handles.
type SecurityHeadersMiddleware struct{}

// NewSecurityHeadersMiddleware returns a SecurityHeadersMiddleware. The type
// is stateless.
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
	return new(SecurityHeadersMiddleware)
}

// GetName reports the identifier this middleware is registered under.
func (m *SecurityHeadersMiddleware) GetName() string {
	const name = "security_headers"
	return name
}

// GetPriority reports this middleware's ordering priority (high, 85).
func (m *SecurityHeadersMiddleware) GetPriority() int {
	const priority = 85
	return priority
}
// Handle returns the gin handler that sets the security headers before
// passing control to the next handler.
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
	// Header name/value pairs applied to every response.
	// NOTE(review): OWASP now recommends "X-XSS-Protection: 0"; the legacy
	// "1; mode=block" value is kept here unchanged.
	headers := [...][2]string{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self'"},
		{"Strict-Transport-Security", "max-age=31536000; includeSubDomains"},
	}

	return func(c *gin.Context) {
		for _, kv := range headers {
			c.Header(kv[0], kv[1])
		}
		c.Next()
	}
}
// IsGlobal reports that the security-headers middleware is registered as a global middleware.
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// ResponseTimeMiddleware measures how long the downstream handlers take and
// reports it in the X-Response-Time header and the gin key "response_time".
type ResponseTimeMiddleware struct{}

// NewResponseTimeMiddleware returns a ResponseTimeMiddleware. The type is
// stateless.
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
	return new(ResponseTimeMiddleware)
}

// GetName reports the identifier this middleware is registered under.
func (m *ResponseTimeMiddleware) GetName() string {
	const name = "response_time"
	return name
}

// GetPriority reports this middleware's ordering priority (medium, 75).
func (m *ResponseTimeMiddleware) GetPriority() int {
	const priority = 75
	return priority
}
// Handle 返回中间件处理函数
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
// 计算响应时间并添加到头部
duration := time.Since(start)
c.Header("X-Response-Time", duration.String())
// 记录到上下文中,供其他中间件使用
c.Set("response_time", duration)
}
}
// IsGlobal reports that the response-time middleware is registered as a global middleware.
func (m *ResponseTimeMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// RequestBodyLoggerMiddleware logs request bodies at debug level and is
// intended for debugging. When enable is false the handler is a pass-through.
type RequestBodyLoggerMiddleware struct {
	logger *zap.Logger
	enable bool
	tracer *tracing.Tracer
}

// NewRequestBodyLoggerMiddleware builds a RequestBodyLoggerMiddleware.
// tracer may be nil; the handler checks for that before asking for a trace ID.
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
	m := new(RequestBodyLoggerMiddleware)
	m.logger = logger
	m.tracer = tracer
	m.enable = enable
	return m
}

// GetName reports the identifier this middleware is registered under.
func (m *RequestBodyLoggerMiddleware) GetName() string {
	const name = "request_body_logger"
	return name
}

// GetPriority reports this middleware's ordering priority (lower, 70).
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
	const priority = 70
	return priority
}
// Handle returns the gin handler that logs request bodies.
//
// When the middleware is disabled the handler only calls c.Next(). Otherwise,
// for POST, PUT and PATCH requests, it emits a "请求体" debug record. The
// record holds method, path, body, request ID and, when available, the trace ID.
//
// Only the first maxLoggedBodyBytes of a body are buffered and logged; a
// longer body is marked body_truncated=true. Downstream handlers still
// receive the complete body: the buffered prefix is replayed, then the
// untouched remainder is streamed from the original reader, whose Close is
// preserved. The body is left alone entirely when the logger does not have
// debug level enabled, since the record would be discarded anyway.
//
// NOTE: bodies may contain credentials or personal data; keep this
// middleware disabled outside debugging.
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
	if !m.enable {
		return func(c *gin.Context) {
			c.Next()
		}
	}

	// Cap so a large (untrusted) upload is never pulled fully into memory
	// just to produce a debug log line.
	const maxLoggedBodyBytes = 64 << 10

	return func(c *gin.Context) {
		method := c.Request.Method
		hasLoggableBody := method == "POST" || method == "PUT" || method == "PATCH"

		if hasLoggableBody && c.Request.Body != nil && m.logger.Core().Enabled(zap.DebugLevel) {
			original := c.Request.Body
			head, err := io.ReadAll(io.LimitReader(original, maxLoggedBodyBytes+1))

			// Restore the body even on a read error, so downstream handlers see
			// exactly what was consumed plus whatever remains.
			c.Request.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(bytes.NewReader(head), original), original}

			if err == nil {
				logged := head
				truncated := len(logged) > maxLoggedBodyBytes
				if truncated {
					logged = logged[:maxLoggedBodyBytes]
				}

				logFields := []zap.Field{
					zap.String("method", method),
					zap.String("path", c.Request.URL.Path),
					zap.String("body", string(logged)),
					zap.String("request_id", c.GetString("request_id")),
				}
				if truncated {
					logFields = append(logFields, zap.Bool("body_truncated", true))
				}
				if m.tracer != nil {
					if traceID := m.tracer.GetTraceID(c.Request.Context()); traceID != "" {
						logFields = append(logFields, zap.String("trace_id", traceID))
					}
				}
				m.logger.Debug("请求体", logFields...)
			}
		}

		c.Next()
	}
}
// IsGlobal reports that the body logger is optional, not a global middleware.
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
	const global = false
	return global
}
// ErrorTrackingMiddleware inspects each finished request and logs it, and
// tags its span, when the request carried gin errors or returned status >= 400.
type ErrorTrackingMiddleware struct {
	logger *zap.Logger
	tracer *tracing.Tracer
}

// NewErrorTrackingMiddleware builds an ErrorTrackingMiddleware from a logger
// and a tracer.
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
	m := new(ErrorTrackingMiddleware)
	m.logger = logger
	m.tracer = tracer
	return m
}

// GetName reports the identifier this middleware is registered under.
func (m *ErrorTrackingMiddleware) GetName() string {
	const name = "error_tracking"
	return name
}

// GetPriority reports this middleware's ordering priority (low, 60). It is
// intended to run after most other middlewares.
func (m *ErrorTrackingMiddleware) GetPriority() int {
	const priority = 60
	return priority
}
// Handle returns the gin handler that tracks failed requests.
//
// After the handlers run, a request with gin errors or a status >= 400 is
// handled in two ways:
//   - its span, if recording, is tagged as an error operation;
//   - it is logged at Error level (status >= 500) or Warn level (otherwise).
//
// Logging no longer depends on the span recording. Previously, errors on
// unsampled or untraced requests were silently never logged. Trace and span
// IDs are added only when a tracer is configured and returns them. The
// "path" field is the matched route template, falling back to the raw URL
// path for unmatched routes such as 404s.
func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Next()

		status := c.Writer.Status()
		if len(c.Errors) == 0 && status < 400 {
			return
		}

		ctx := c.Request.Context()
		// Marked as an error operation so the sampler keeps 100% of error traces.
		if span := trace.SpanFromContext(ctx); span.IsRecording() {
			span.SetAttributes(
				attribute.String("error.operation", "true"),
				attribute.String("operation.type", "error"),
			)
		}

		path := c.FullPath()
		if path == "" {
			// No route matched (e.g. 404), so there is no template to report.
			path = c.Request.URL.Path
		}

		logFields := []zap.Field{
			zap.Int("status_code", status),
			zap.String("method", c.Request.Method),
			zap.String("path", path),
			zap.String("client_ip", c.ClientIP()),
		}
		if m.tracer != nil {
			if traceID := m.tracer.GetTraceID(ctx); traceID != "" {
				logFields = append(logFields, zap.String("trace_id", traceID))
			}
			if spanID := m.tracer.GetSpanID(ctx); spanID != "" {
				logFields = append(logFields, zap.String("span_id", spanID))
			}
		}
		if len(c.Errors) > 0 {
			logFields = append(logFields, zap.String("errors", c.Errors.String()))
		}

		if status >= 500 {
			m.logger.Error("服务器错误", logFields...)
			return
		}
		m.logger.Warn("客户端错误", logFields...)
	}
}
// IsGlobal reports that the error-tracking middleware is registered as a global middleware.
func (m *ErrorTrackingMiddleware) IsGlobal() bool {
	const global = true
	return global
}