321 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			321 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package tracing
 | ||
| 
 | ||
| import (
 | ||
| 	"context"
 | ||
| 	"fmt"
 | ||
| 	"strings"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"go.opentelemetry.io/otel/attribute"
 | ||
| 	"go.opentelemetry.io/otel/codes"
 | ||
| 	"go.opentelemetry.io/otel/trace"
 | ||
| 	"go.uber.org/zap"
 | ||
| 	"gorm.io/gorm"
 | ||
| )
 | ||
| 
 | ||
// Keys under which the "before" callback stashes per-operation tracing state
// on the *gorm.DB so that the matching "after" callback can read it back.
const (
	// gormSpanKey holds the trace.Span opened for the operation.
	gormSpanKey = "otel:span"
	// gormOperationKey holds the operation name (string) used for the span.
	gormOperationKey = "otel:operation"
	// gormTableNameKey holds the target table name (string), possibly empty.
	gormTableNameKey = "otel:table_name"
	// gormStartTimeKey holds the time.Time at which the span was started.
	gormStartTimeKey = "otel:start_time"
)
 | ||
| 
 | ||
| // GormTracingPlugin GORM链路追踪插件
 | ||
| type GormTracingPlugin struct {
 | ||
| 	tracer *Tracer
 | ||
| 	logger *zap.Logger
 | ||
| 	config GormPluginConfig
 | ||
| }
 | ||
| 
 | ||
// GormPluginConfig tunes what the GORM tracing plugin records.
type GormPluginConfig struct {
	// IncludeSQL attaches the executed statement text to the span as the
	// "db.statement" attribute (only when the text is non-empty).
	IncludeSQL bool

	// IncludeValues is not read anywhere in this file.
	// NOTE(review): presumably intended to control recording of bound
	// parameter values — confirm before relying on it.
	IncludeValues bool

	// SlowThreshold: a successful operation that runs longer than this is
	// flagged with the "db.slow_query" attribute and logged as a warning.
	SlowThreshold time.Duration

	// ExcludeTables lists table names (exact, case-sensitive match) for which
	// no span is created.
	ExcludeTables []string

	// SanitizeSQL passes the statement text through sanitizeSQL before it is
	// recorded.
	SanitizeSQL bool
}

// DefaultGormPluginConfig returns the configuration used by
// NewGormTracingPlugin. SQL is recorded but sanitized, slow operations are
// those over 200ms, and the "migrations" and "schema_migrations" tables are not
// traced. IncludeValues stays false; the original guidance is to keep it false
// in production to avoid recording sensitive data.
func DefaultGormPluginConfig() GormPluginConfig {
	var cfg GormPluginConfig
	cfg.IncludeSQL = true
	cfg.SanitizeSQL = true
	cfg.SlowThreshold = 200 * time.Millisecond
	// A fresh slice per call, so callers may modify their copy freely.
	cfg.ExcludeTables = []string{"migrations", "schema_migrations"}
	return cfg
}
 | ||
| 
 | ||
| // NewGormTracingPlugin 创建GORM追踪插件
 | ||
| func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin {
 | ||
| 	return &GormTracingPlugin{
 | ||
| 		tracer: tracer,
 | ||
| 		logger: logger,
 | ||
| 		config: DefaultGormPluginConfig(),
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // Name 返回插件名称
 | ||
| func (p *GormTracingPlugin) Name() string {
 | ||
| 	return "gorm-otel-tracing"
 | ||
| }
 | ||
| 
 | ||
| // Initialize 初始化插件
 | ||
| func (p *GormTracingPlugin) Initialize(db *gorm.DB) error {
 | ||
| 	// 注册各种操作的回调
 | ||
| 	callbacks := []string{"create", "query", "update", "delete", "raw"}
 | ||
| 
 | ||
| 	for _, operation := range callbacks {
 | ||
| 		switch operation {
 | ||
| 		case "create":
 | ||
| 			err := db.Callback().Create().Before("gorm:create").
 | ||
| 				Register(p.Name()+":before_create", p.beforeOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register before create callback: %w", err)
 | ||
| 			}
 | ||
| 			err = db.Callback().Create().After("gorm:create").
 | ||
| 				Register(p.Name()+":after_create", p.afterOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register after create callback: %w", err)
 | ||
| 			}
 | ||
| 		case "query":
 | ||
| 			err := db.Callback().Query().Before("gorm:query").
 | ||
| 				Register(p.Name()+":before_query", p.beforeOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register before query callback: %w", err)
 | ||
| 			}
 | ||
| 			err = db.Callback().Query().After("gorm:query").
 | ||
| 				Register(p.Name()+":after_query", p.afterOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register after query callback: %w", err)
 | ||
| 			}
 | ||
| 		case "update":
 | ||
| 			err := db.Callback().Update().Before("gorm:update").
 | ||
| 				Register(p.Name()+":before_update", p.beforeOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register before update callback: %w", err)
 | ||
| 			}
 | ||
| 			err = db.Callback().Update().After("gorm:update").
 | ||
| 				Register(p.Name()+":after_update", p.afterOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register after update callback: %w", err)
 | ||
| 			}
 | ||
| 		case "delete":
 | ||
| 			err := db.Callback().Delete().Before("gorm:delete").
 | ||
| 				Register(p.Name()+":before_delete", p.beforeOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register before delete callback: %w", err)
 | ||
| 			}
 | ||
| 			err = db.Callback().Delete().After("gorm:delete").
 | ||
| 				Register(p.Name()+":after_delete", p.afterOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register after delete callback: %w", err)
 | ||
| 			}
 | ||
| 		case "raw":
 | ||
| 			err := db.Callback().Raw().Before("gorm:raw").
 | ||
| 				Register(p.Name()+":before_raw", p.beforeOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register before raw callback: %w", err)
 | ||
| 			}
 | ||
| 			err = db.Callback().Raw().After("gorm:raw").
 | ||
| 				Register(p.Name()+":after_raw", p.afterOperation)
 | ||
| 			if err != nil {
 | ||
| 				return fmt.Errorf("failed to register after raw callback: %w", err)
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	p.logger.Info("GORM追踪插件已初始化")
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // beforeOperation 操作前回调
 | ||
| func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) {
 | ||
| 	// 检查是否应该跳过追踪
 | ||
| 	if p.shouldSkipTracing(db) {
 | ||
| 		return
 | ||
| 	}
 | ||
| 
 | ||
| 	ctx := db.Statement.Context
 | ||
| 	if ctx == nil {
 | ||
| 		ctx = context.Background()
 | ||
| 	}
 | ||
| 
 | ||
| 	// 获取操作信息
 | ||
| 	operation := p.getOperationType(db)
 | ||
| 	tableName := p.getTableName(db)
 | ||
| 
 | ||
| 	// 检查是否应该排除此表
 | ||
| 	if p.isExcludedTable(tableName) {
 | ||
| 		return
 | ||
| 	}
 | ||
| 
 | ||
| 	// 开始追踪
 | ||
| 	ctx, span := p.tracer.StartDBSpan(ctx, operation, tableName)
 | ||
| 
 | ||
| 	// 添加基础属性
 | ||
| 	p.tracer.AddSpanAttributes(span,
 | ||
| 		attribute.String("db.system", "postgresql"),
 | ||
| 		attribute.String("db.operation", operation),
 | ||
| 	)
 | ||
| 
 | ||
| 	if tableName != "" {
 | ||
| 		p.tracer.AddSpanAttributes(span, attribute.String("db.table", tableName))
 | ||
| 	}
 | ||
| 
 | ||
| 	// 保存追踪信息到GORM context
 | ||
| 	db.Set(gormSpanKey, span)
 | ||
| 	db.Set(gormOperationKey, operation)
 | ||
| 	db.Set(gormTableNameKey, tableName)
 | ||
| 	db.Set(gormStartTimeKey, time.Now())
 | ||
| 
 | ||
| 	// 更新statement context
 | ||
| 	db.Statement.Context = ctx
 | ||
| }
 | ||
| 
 | ||
| // afterOperation 操作后回调
 | ||
| func (p *GormTracingPlugin) afterOperation(db *gorm.DB) {
 | ||
| 	// 获取span
 | ||
| 	spanValue, exists := db.Get(gormSpanKey)
 | ||
| 	if !exists {
 | ||
| 		return
 | ||
| 	}
 | ||
| 
 | ||
| 	span, ok := spanValue.(trace.Span)
 | ||
| 	if !ok {
 | ||
| 		return
 | ||
| 	}
 | ||
| 	defer span.End()
 | ||
| 
 | ||
| 	// 获取操作信息
 | ||
| 	operation, _ := db.Get(gormOperationKey)
 | ||
| 	tableName, _ := db.Get(gormTableNameKey)
 | ||
| 	startTime, _ := db.Get(gormStartTimeKey)
 | ||
| 
 | ||
| 	// 计算执行时间
 | ||
| 	var duration time.Duration
 | ||
| 	if st, ok := startTime.(time.Time); ok {
 | ||
| 		duration = time.Since(st)
 | ||
| 		p.tracer.AddSpanAttributes(span,
 | ||
| 			attribute.Int64("db.duration_ms", duration.Milliseconds()),
 | ||
| 		)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 添加SQL信息
 | ||
| 	if p.config.IncludeSQL && db.Statement.SQL.String() != "" {
 | ||
| 		sql := db.Statement.SQL.String()
 | ||
| 		if p.config.SanitizeSQL {
 | ||
| 			sql = p.sanitizeSQL(sql)
 | ||
| 		}
 | ||
| 		p.tracer.AddSpanAttributes(span, attribute.String("db.statement", sql))
 | ||
| 	}
 | ||
| 
 | ||
| 	// 添加影响行数
 | ||
| 	if db.Statement.RowsAffected >= 0 {
 | ||
| 		p.tracer.AddSpanAttributes(span,
 | ||
| 			attribute.Int64("db.rows_affected", db.Statement.RowsAffected),
 | ||
| 		)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 处理错误
 | ||
| 	if db.Error != nil {
 | ||
| 		p.tracer.SetSpanError(span, db.Error)
 | ||
| 		span.SetStatus(codes.Error, db.Error.Error())
 | ||
| 
 | ||
| 		p.logger.Error("数据库操作失败",
 | ||
| 			zap.String("operation", fmt.Sprintf("%v", operation)),
 | ||
| 			zap.String("table", fmt.Sprintf("%v", tableName)),
 | ||
| 			zap.Error(db.Error),
 | ||
| 			zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
 | ||
| 		)
 | ||
| 	} else {
 | ||
| 		p.tracer.SetSpanSuccess(span)
 | ||
| 		span.SetStatus(codes.Ok, "success")
 | ||
| 
 | ||
| 		// 检查慢查询
 | ||
| 		if duration > p.config.SlowThreshold {
 | ||
| 			p.tracer.AddSpanAttributes(span,
 | ||
| 				attribute.Bool("db.slow_query", true),
 | ||
| 			)
 | ||
| 
 | ||
| 			p.logger.Warn("慢SQL查询检测",
 | ||
| 				zap.String("operation", fmt.Sprintf("%v", operation)),
 | ||
| 				zap.String("table", fmt.Sprintf("%v", tableName)),
 | ||
| 				zap.Duration("duration", duration),
 | ||
| 				zap.String("sql", db.Statement.SQL.String()),
 | ||
| 				zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
 | ||
| 			)
 | ||
| 		}
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // shouldSkipTracing 检查是否应该跳过追踪
 | ||
| func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool {
 | ||
| 	// 检查是否已有span(避免重复追踪)
 | ||
| 	if _, exists := db.Get(gormSpanKey); exists {
 | ||
| 		return true
 | ||
| 	}
 | ||
| 
 | ||
| 	return false
 | ||
| }
 | ||
| 
 | ||
| // getOperationType 获取操作类型
 | ||
| func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string {
 | ||
| 	switch db.Statement.ReflectValue.Kind() {
 | ||
| 	default:
 | ||
| 		sql := strings.ToUpper(strings.TrimSpace(db.Statement.SQL.String()))
 | ||
| 		if sql == "" {
 | ||
| 			return "unknown"
 | ||
| 		}
 | ||
| 
 | ||
| 		if strings.HasPrefix(sql, "SELECT") {
 | ||
| 			return "select"
 | ||
| 		} else if strings.HasPrefix(sql, "INSERT") {
 | ||
| 			return "insert"
 | ||
| 		} else if strings.HasPrefix(sql, "UPDATE") {
 | ||
| 			return "update"
 | ||
| 		} else if strings.HasPrefix(sql, "DELETE") {
 | ||
| 			return "delete"
 | ||
| 		} else if strings.HasPrefix(sql, "CREATE") {
 | ||
| 			return "create"
 | ||
| 		} else if strings.HasPrefix(sql, "DROP") {
 | ||
| 			return "drop"
 | ||
| 		} else if strings.HasPrefix(sql, "ALTER") {
 | ||
| 			return "alter"
 | ||
| 		}
 | ||
| 
 | ||
| 		return "query"
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // getTableName 获取表名
 | ||
| func (p *GormTracingPlugin) getTableName(db *gorm.DB) string {
 | ||
| 	if db.Statement.Table != "" {
 | ||
| 		return db.Statement.Table
 | ||
| 	}
 | ||
| 
 | ||
| 	if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
 | ||
| 		return db.Statement.Schema.Table
 | ||
| 	}
 | ||
| 
 | ||
| 	return ""
 | ||
| }
 | ||
| 
 | ||
| // isExcludedTable 检查是否为排除的表
 | ||
| func (p *GormTracingPlugin) isExcludedTable(tableName string) bool {
 | ||
| 	for _, excluded := range p.config.ExcludeTables {
 | ||
| 		if tableName == excluded {
 | ||
| 			return true
 | ||
| 		}
 | ||
| 	}
 | ||
| 	return false
 | ||
| }
 | ||
| 
 | ||
| // sanitizeSQL 清理SQL语句,移除敏感信息
 | ||
| func (p *GormTracingPlugin) sanitizeSQL(sql string) string {
 | ||
| 	// 简单的SQL清理,将参数替换为占位符
 | ||
| 	// 在生产环境中,您可能需要更复杂的清理逻辑
 | ||
| 	return strings.ReplaceAll(sql, "'", "?")
 | ||
| }
 |