feat(架构): 完善基础架构设计
This commit is contained in:
@@ -42,31 +42,31 @@ func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
m.respondUnauthorized(c, "Missing authorization header")
|
||||
m.respondUnauthorized(c, "缺少认证头部")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
m.respondUnauthorized(c, "Invalid authorization header format")
|
||||
m.respondUnauthorized(c, "认证头部格式无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 提取token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
if tokenString == "" {
|
||||
m.respondUnauthorized(c, "Missing token")
|
||||
m.respondUnauthorized(c, "缺少认证令牌")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := m.validateToken(tokenString)
|
||||
if err != nil {
|
||||
m.logger.Warn("Invalid token",
|
||||
m.logger.Warn("无效的认证令牌",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", c.GetString("request_id")))
|
||||
m.respondUnauthorized(c, "Invalid token")
|
||||
m.respondUnauthorized(c, "认证令牌无效")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error
|
||||
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "Unauthorized",
|
||||
"message": "认证失败",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
|
||||
@@ -2,11 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
@@ -15,14 +15,16 @@ import (
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
response interfaces.ResponseBuilder
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
|
||||
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
response: response,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
@@ -48,15 +50,13 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
c.Header("Retry-After", "60")
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"success": false,
|
||||
"message": "Rate limit exceeded",
|
||||
"error": "Too many requests",
|
||||
})
|
||||
// 使用统一的响应格式
|
||||
m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,23 +2,35 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/tracing"
|
||||
)
|
||||
|
||||
// RequestLoggerMiddleware 请求日志中间件
|
||||
type RequestLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
logger *zap.Logger
|
||||
useColoredLog bool
|
||||
isDevelopment bool
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewRequestLoggerMiddleware 创建请求日志中间件
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
|
||||
return &RequestLoggerMiddleware{
|
||||
logger: logger,
|
||||
logger: logger,
|
||||
useColoredLog: isDevelopment, // 开发环境使用彩色日志
|
||||
isDevelopment: isDevelopment,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,24 +46,110 @@ func (m *RequestLoggerMiddleware) GetPriority() int {
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 使用zap logger记录请求信息
|
||||
m.logger.Info("HTTP Request",
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
)
|
||||
if m.useColoredLog {
|
||||
// 开发环境:使用Gin默认的彩色日志格式
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
var statusColor, methodColor, resetColor string
|
||||
if param.IsOutputColor() {
|
||||
statusColor = param.StatusCodeColor()
|
||||
methodColor = param.MethodColor()
|
||||
resetColor = param.ResetColor()
|
||||
}
|
||||
|
||||
// 返回空字符串,因为我们已经用zap记录了
|
||||
return ""
|
||||
})
|
||||
if param.Latency > time.Minute {
|
||||
param.Latency = param.Latency.Truncate(time.Second)
|
||||
}
|
||||
|
||||
// 获取TraceID
|
||||
traceID := param.Request.Header.Get("X-Trace-ID")
|
||||
if traceID == "" && m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(param.Request.Context())
|
||||
}
|
||||
|
||||
// 检查是否为错误响应
|
||||
if param.StatusCode >= 400 && m.tracer != nil {
|
||||
span := trace.SpanFromContext(param.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
traceInfo := ""
|
||||
if traceID != "" {
|
||||
traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
statusColor, param.StatusCode, resetColor,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
methodColor, param.Method, resetColor,
|
||||
param.Path,
|
||||
traceInfo,
|
||||
param.ErrorMessage,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
// 生产环境:使用结构化JSON日志
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 获取TraceID
|
||||
traceID := param.Request.Header.Get("X-Trace-ID")
|
||||
if traceID == "" && m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(param.Request.Context())
|
||||
}
|
||||
|
||||
// 检查是否为错误响应
|
||||
if param.StatusCode >= 400 && m.tracer != nil {
|
||||
span := trace.SpanFromContext(param.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 对于服务器错误,记录更详细的日志
|
||||
if param.StatusCode >= 500 {
|
||||
m.logger.Error("服务器错误",
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("trace_id", traceID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 记录请求日志
|
||||
logFields := []zap.Field{
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
}
|
||||
|
||||
// 添加TraceID
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
m.logger.Info("HTTP请求", logFields...)
|
||||
return ""
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
@@ -102,6 +200,70 @@ func (m *RequestIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TraceIDMiddleware 追踪ID中间件
|
||||
type TraceIDMiddleware struct {
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewTraceIDMiddleware 创建追踪ID中间件
|
||||
func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
|
||||
return &TraceIDMiddleware{
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *TraceIDMiddleware) GetName() string {
|
||||
return "trace_id"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *TraceIDMiddleware) GetPriority() int {
|
||||
return 94 // 仅次于请求ID中间件
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取或生成追踪ID
|
||||
traceID := m.tracer.GetTraceID(c.Request.Context())
|
||||
if traceID != "" {
|
||||
// 设置追踪ID到响应头
|
||||
c.Header("X-Trace-ID", traceID)
|
||||
// 添加到上下文
|
||||
c.Set("trace_id", traceID)
|
||||
}
|
||||
|
||||
// 检查是否为错误请求(例如URL不存在)
|
||||
c.Next()
|
||||
|
||||
// 请求完成后检查状态码
|
||||
if c.Writer.Status() >= 400 {
|
||||
// 获取当前span
|
||||
span := trace.SpanFromContext(c.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 设置错误上下文,以便后续span可以识别
|
||||
c.Request = c.Request.WithContext(context.WithValue(
|
||||
c.Request.Context(),
|
||||
"otel_error_request",
|
||||
true,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *TraceIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware 安全头部中间件
|
||||
type SecurityHeadersMiddleware struct{}
|
||||
|
||||
@@ -183,13 +345,15 @@ func (m *ResponseTimeMiddleware) IsGlobal() bool {
|
||||
type RequestBodyLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
enable bool
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
|
||||
return &RequestBodyLoggerMiddleware{
|
||||
logger: logger,
|
||||
enable: enable,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,13 +384,26 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
// 重新设置body供后续处理使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 获取追踪ID
|
||||
traceID := ""
|
||||
if m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(c.Request.Context())
|
||||
}
|
||||
|
||||
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
|
||||
m.logger.Debug("Request Body",
|
||||
logFields := []zap.Field{
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("body", string(bodyBytes)),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
)
|
||||
}
|
||||
|
||||
// 添加追踪ID
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
m.logger.Debug("请求体", logFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -239,3 +416,83 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
|
||||
return false // 可选中间件,不是全局的
|
||||
}
|
||||
|
||||
// ErrorTrackingMiddleware 错误追踪中间件
|
||||
type ErrorTrackingMiddleware struct {
|
||||
logger *zap.Logger
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewErrorTrackingMiddleware 创建错误追踪中间件
|
||||
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
|
||||
return &ErrorTrackingMiddleware{
|
||||
logger: logger,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ErrorTrackingMiddleware) GetName() string {
|
||||
return "error_tracking"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ErrorTrackingMiddleware) GetPriority() int {
|
||||
return 60 // 低优先级,在大多数中间件之后执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// 检查是否有错误
|
||||
if len(c.Errors) > 0 || c.Writer.Status() >= 400 {
|
||||
// 获取当前span
|
||||
span := trace.SpanFromContext(c.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 记录错误日志
|
||||
traceID := m.tracer.GetTraceID(c.Request.Context())
|
||||
spanID := m.tracer.GetSpanID(c.Request.Context())
|
||||
|
||||
logFields := []zap.Field{
|
||||
zap.Int("status_code", c.Writer.Status()),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.FullPath()),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
}
|
||||
|
||||
// 添加追踪信息
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
if spanID != "" {
|
||||
logFields = append(logFields, zap.String("span_id", spanID))
|
||||
}
|
||||
|
||||
// 添加错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
logFields = append(logFields, zap.String("errors", c.Errors.String()))
|
||||
}
|
||||
|
||||
// 根据状态码决定日志级别
|
||||
if c.Writer.Status() >= 500 {
|
||||
m.logger.Error("服务器错误", logFields...)
|
||||
} else {
|
||||
m.logger.Warn("客户端错误", logFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ErrorTrackingMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user