package tracing

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"
	"gorm.io/gorm"
)

// Keys under which per-operation tracing state is kept in
// gorm.Statement.Settings between the before and after callbacks.
const (
	gormSpanKey      = "otel:span"
	gormOperationKey = "otel:operation"
	gormTableNameKey = "otel:table_name"
	gormStartTimeKey = "otel:start_time"
	// gormParentCtxKey holds the Statement context that was active before the
	// span started, so the after callback can restore it.
	gormParentCtxKey = "otel:parent_ctx"
	// gormStatementKey records which *gorm.Statement started the span. GORM
	// copies Settings when it clones a Statement (Session, WithContext, ...),
	// so the presence of gormSpanKey alone does not prove ownership.
	gormStatementKey = "otel:statement"
)

// sqlOperationPrefixes maps a leading SQL keyword (matched case-insensitively,
// in this order) to the operation name reported on spans.
var sqlOperationPrefixes = []struct {
	keyword   string
	operation string
}{
	{"SELECT", "select"},
	{"INSERT", "insert"},
	{"UPDATE", "update"},
	{"DELETE", "delete"},
	{"CREATE", "create"},
	{"DROP", "drop"},
	{"ALTER", "alter"},
}

// GormTracingPlugin is a GORM plugin that wraps create, query, update, delete
// and raw operations in spans started through the package's Tracer.
type GormTracingPlugin struct {
	tracer *Tracer
	logger *zap.Logger
	config GormPluginConfig
}

// GormPluginConfig controls what the GORM tracing plugin records.
type GormPluginConfig struct {
	// IncludeSQL records the executed SQL as the db.statement span attribute.
	IncludeSQL bool
	// IncludeValues records the bound parameters as db.statement.values.
	// They may contain sensitive data; keep this off in production.
	IncludeValues bool
	// SlowThreshold flags successful operations slower than this as slow
	// queries and logs them at Warn level. Zero or negative disables the check.
	SlowThreshold time.Duration
	// ExcludeTables lists table names that are never traced.
	ExcludeTables []string
	// SanitizeSQL masks single-quoted string literals in recorded/logged SQL.
	SanitizeSQL bool
}

// DefaultGormPluginConfig returns the default configuration: sanitized SQL is
// recorded, bound values are not, operations slower than 200ms are flagged,
// and common migration tables are excluded.
func DefaultGormPluginConfig() GormPluginConfig {
	return GormPluginConfig{
		IncludeSQL:    true,
		IncludeValues: false, // keep false in production to avoid recording sensitive data
		SlowThreshold: 200 * time.Millisecond,
		ExcludeTables: []string{"migrations", "schema_migrations"},
		SanitizeSQL:   true,
	}
}

// NewGormTracingPlugin creates a tracing plugin using DefaultGormPluginConfig.
func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin {
	return NewGormTracingPluginWithConfig(tracer, logger, DefaultGormPluginConfig())
}

// NewGormTracingPluginWithConfig creates a tracing plugin with a
// caller-supplied configuration. A nil logger is replaced by a no-op logger.
func NewGormTracingPluginWithConfig(tracer *Tracer, logger *zap.Logger, config GormPluginConfig) *GormTracingPlugin {
	if logger == nil {
		logger = zap.NewNop()
	}
	// Detach from the caller's slice so later mutations do not leak in.
	config.ExcludeTables = append([]string(nil), config.ExcludeTables...)
	return &GormTracingPlugin{
		tracer: tracer,
		logger: logger,
		config: config,
	}
}

// Name returns the plugin name; it also prefixes every registered callback.
func (p *GormTracingPlugin) Name() string {
	return "gorm-otel-tracing"
}

// Initialize registers a before and an after callback around each of GORM's
// built-in gorm:create, gorm:query, gorm:update, gorm:delete and gorm:raw
// callbacks. It stops at, and returns, the first registration error.
func (p *GormTracingPlugin) Initialize(db *gorm.DB) error {
	// registerFunc matches the Register method of the values returned by a
	// GORM processor's Before/After helpers (used here as method values).
	type registerFunc func(name string, fn func(*gorm.DB)) error

	cb := db.Callback()
	hooks := []struct {
		op       string // callback kind; used in callback names and errors
		fallback string // operation reported when SQL is not built yet
		before   registerFunc
		after    registerFunc
	}{
		{"create", "insert", cb.Create().Before("gorm:create").Register, cb.Create().After("gorm:create").Register},
		{"query", "select", cb.Query().Before("gorm:query").Register, cb.Query().After("gorm:query").Register},
		{"update", "update", cb.Update().Before("gorm:update").Register, cb.Update().After("gorm:update").Register},
		{"delete", "delete", cb.Delete().Before("gorm:delete").Register, cb.Delete().After("gorm:delete").Register},
		{"raw", "raw", cb.Raw().Before("gorm:raw").Register, cb.Raw().After("gorm:raw").Register},
	}

	for _, h := range hooks {
		if err := h.before(p.Name()+":before_"+h.op, p.beforeCallback(h.fallback)); err != nil {
			return fmt.Errorf("failed to register before %s callback: %w", h.op, err)
		}
		if err := h.after(p.Name()+":after_"+h.op, p.afterOperation); err != nil {
			return fmt.Errorf("failed to register after %s callback: %w", h.op, err)
		}
	}

	p.logger.Info("GORM追踪插件已初始化")
	return nil
}

// beforeCallback returns a before-callback that starts a span. fallback names
// the operation when the statement's SQL has not been built yet, which is the
// normal case for model-based create/query/update/delete calls.
func (p *GormTracingPlugin) beforeCallback(fallback string) func(*gorm.DB) {
	return func(db *gorm.DB) {
		p.startSpan(db, fallback)
	}
}

// beforeOperation starts a span whose operation is derived only from the SQL
// text ("unknown" if none is built). Initialize registers beforeCallback
// instead; this is kept for existing callers.
func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) {
	p.startSpan(db, "unknown")
}

// startSpan opens a DB span for the statement, stores the state needed by
// afterOperation in Statement.Settings, and makes the span's context the
// Statement context. It does nothing for excluded tables or when this
// statement already owns an active span.
func (p *GormTracingPlugin) startSpan(db *gorm.DB, fallback string) {
	if p.shouldSkipTracing(db) {
		return
	}

	tableName := p.getTableName(db)
	if p.isExcludedTable(tableName) {
		return
	}

	operation := p.getOperationType(db)
	if operation == "unknown" && fallback != "" {
		operation = fallback
	}

	parent := db.Statement.Context
	if parent == nil {
		parent = context.Background()
	}

	ctx, span := p.tracer.StartDBSpan(parent, operation, tableName)
	p.tracer.AddSpanAttributes(span,
		attribute.String("db.system", p.dbSystem(db)),
		attribute.String("db.operation", operation),
	)
	if tableName != "" {
		p.tracer.AddSpanAttributes(span, attribute.String("db.table", tableName))
	}

	settings := &db.Statement.Settings
	settings.Store(gormSpanKey, span)
	settings.Store(gormOperationKey, operation)
	settings.Store(gormTableNameKey, tableName)
	settings.Store(gormStartTimeKey, time.Now())
	settings.Store(gormParentCtxKey, parent)
	settings.Store(gormStatementKey, db.Statement)

	db.Statement.Context = ctx
}

// afterOperation finishes the span this statement started. It records the
// duration, SQL, bound values (if enabled), affected rows and error status.
// It logs failures and slow operations, clears the tracing state from
// Statement.Settings, and restores the pre-span context.
func (p *GormTracingPlugin) afterOperation(db *gorm.DB) {
	// Only finish spans this statement started; keys copied from a parent
	// statement belong to the parent's still-running span.
	if !p.ownsSpan(db) {
		return
	}

	// Remove all state so a reused Statement (e.g. tx.Find(...) followed by
	// tx.Count(...)) is traced again instead of being skipped.
	settings := &db.Statement.Settings
	spanValue, _ := settings.LoadAndDelete(gormSpanKey)
	operationValue, _ := settings.LoadAndDelete(gormOperationKey)
	tableValue, _ := settings.LoadAndDelete(gormTableNameKey)
	startValue, _ := settings.LoadAndDelete(gormStartTimeKey)
	parentValue, _ := settings.LoadAndDelete(gormParentCtxKey)
	settings.Delete(gormStatementKey)

	// Restore the caller's context (deferred first, so it runs after
	// span.End) so later operations are not parented to an ended span.
	defer func() {
		if parent, ok := parentValue.(context.Context); ok {
			db.Statement.Context = parent
		}
	}()

	span, ok := spanValue.(trace.Span)
	if !ok {
		return
	}
	defer span.End()

	operation, _ := operationValue.(string)
	tableName, _ := tableValue.(string)

	var duration time.Duration
	if start, ok := startValue.(time.Time); ok {
		duration = time.Since(start)
		p.tracer.AddSpanAttributes(span,
			attribute.Int64("db.duration_ms", duration.Milliseconds()),
		)
	}

	if p.config.IncludeSQL {
		if sql := p.recordableSQL(db); sql != "" {
			p.tracer.AddSpanAttributes(span, attribute.String("db.statement", sql))
		}
	}

	if p.config.IncludeValues && len(db.Statement.Vars) > 0 {
		p.tracer.AddSpanAttributes(span,
			attribute.String("db.statement.values", fmt.Sprintf("%v", db.Statement.Vars)),
		)
	}

	if db.Statement.RowsAffected >= 0 {
		p.tracer.AddSpanAttributes(span,
			attribute.Int64("db.rows_affected", db.Statement.RowsAffected),
		)
	}

	if db.Error != nil && !errors.Is(db.Error, gorm.ErrRecordNotFound) {
		p.tracer.SetSpanError(span, db.Error)
		span.SetStatus(codes.Error, db.Error.Error())
		p.logger.Error("数据库操作失败",
			zap.String("operation", operation),
			zap.String("table", tableName),
			zap.Error(db.Error),
			zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
		)
		return
	}

	// gorm.ErrRecordNotFound (First/Take/Last with no match) is an expected
	// outcome rather than a failure: flag it without marking the span failed.
	if db.Error != nil {
		p.tracer.AddSpanAttributes(span, attribute.Bool("db.record_not_found", true))
	}
	p.tracer.SetSpanSuccess(span)
	// OpenTelemetry ignores the description for codes.Ok.
	span.SetStatus(codes.Ok, "")

	if p.config.SlowThreshold > 0 && duration > p.config.SlowThreshold {
		p.tracer.AddSpanAttributes(span,
			attribute.Bool("db.slow_query", true),
		)
		p.logger.Warn("慢SQL查询检测",
			zap.String("operation", operation),
			zap.String("table", tableName),
			zap.Duration("duration", duration),
			zap.String("sql", p.recordableSQL(db)),
			zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
		)
	}
}

// shouldSkipTracing reports whether this statement already owns an active
// span, which prevents starting a second span for the same statement. Keys
// copied from a parent statement by GORM's Statement cloning do not count.
func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool {
	_, exists := db.Statement.Settings.Load(gormSpanKey)
	return exists && p.ownsSpan(db)
}

// ownsSpan reports whether the tracing state in Statement.Settings was
// stored by this very statement (rather than copied from a parent).
func (p *GormTracingPlugin) ownsSpan(db *gorm.DB) bool {
	v, ok := db.Statement.Settings.Load(gormStatementKey)
	if !ok {
		return false
	}
	owner, ok := v.(*gorm.Statement)
	return ok && owner == db.Statement
}

// getOperationType derives the operation from the leading keyword of the
// statement's SQL. It returns "unknown" when no SQL is built, and "query"
// for any other leading keyword.
func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string {
	sql := strings.TrimSpace(db.Statement.SQL.String())
	if sql == "" {
		return "unknown"
	}
	for _, candidate := range sqlOperationPrefixes {
		n := len(candidate.keyword)
		if len(sql) >= n && strings.EqualFold(sql[:n], candidate.keyword) {
			return candidate.operation
		}
	}
	return "query"
}

// getTableName returns the explicit Statement.Table, else the parsed schema's
// table, else "".
func (p *GormTracingPlugin) getTableName(db *gorm.DB) string {
	if db.Statement.Table != "" {
		return db.Statement.Table
	}
	if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
		return db.Statement.Schema.Table
	}
	return ""
}

// isExcludedTable reports whether tableName is listed in ExcludeTables.
func (p *GormTracingPlugin) isExcludedTable(tableName string) bool {
	for _, excluded := range p.config.ExcludeTables {
		if tableName == excluded {
			return true
		}
	}
	return false
}

// dbSystem maps the GORM dialector name to the OpenTelemetry db.system value.
// It falls back to "postgresql" when no dialector is configured.
func (p *GormTracingPlugin) dbSystem(db *gorm.DB) string {
	if db.Config == nil || db.Dialector == nil {
		return "postgresql"
	}
	switch name := db.Dialector.Name(); name {
	case "", "postgres":
		return "postgresql"
	case "sqlserver":
		return "mssql"
	default:
		return name
	}
}

// recordableSQL returns the statement's SQL text, with string literals masked
// when SanitizeSQL is enabled. It returns "" when no SQL has been built.
func (p *GormTracingPlugin) recordableSQL(db *gorm.DB) string {
	sql := db.Statement.SQL.String()
	if sql != "" && p.config.SanitizeSQL {
		sql = p.sanitizeSQL(sql)
	}
	return sql
}

// sanitizeSQL replaces every single-quoted string literal in sql (including
// doubled '' escapes) with a single "?" so literal values are not recorded.
// An unterminated literal masks the rest of the string. Backslash escapes and
// PostgreSQL dollar quoting are not recognized; bound parameters ($1, ?) are
// left untouched because they carry no values.
func (p *GormTracingPlugin) sanitizeSQL(sql string) string {
	if !strings.Contains(sql, "'") {
		return sql
	}

	var b strings.Builder
	b.Grow(len(sql))
	inLiteral := false
	// Byte-wise scanning is UTF-8 safe: the quote is ASCII and bytes outside
	// literals are copied verbatim.
	for i := 0; i < len(sql); i++ {
		c := sql[i]
		if !inLiteral {
			if c == '\'' {
				inLiteral = true
				b.WriteByte('?')
				continue
			}
			b.WriteByte(c)
			continue
		}
		if c == '\'' {
			if i+1 < len(sql) && sql[i+1] == '\'' {
				i++ // doubled quote is an escaped quote inside the literal
				continue
			}
			inLiteral = false
		}
	}
	return b.String()
}