321 lines
8.5 KiB
Go
321 lines
8.5 KiB
Go
package tracing
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
"time"
|
||
|
||
"go.opentelemetry.io/otel/attribute"
|
||
"go.opentelemetry.io/otel/codes"
|
||
"go.opentelemetry.io/otel/trace"
|
||
"go.uber.org/zap"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// Keys under which beforeOperation stores per-operation tracing state in the
// GORM statement settings (db.Set), to be read back by afterOperation (db.Get).
const (
	// gormSpanKey holds the trace.Span of the in-flight operation.
	gormSpanKey = "otel:span"
	// gormOperationKey holds the operation name string (e.g. "select").
	gormOperationKey = "otel:operation"
	// gormTableNameKey holds the target table name ("" when unknown).
	gormTableNameKey = "otel:table_name"
	// gormStartTimeKey holds the time.Time recorded when the operation began.
	gormStartTimeKey = "otel:start_time"
)
|
||
|
||
// GormTracingPlugin GORM链路追踪插件
|
||
type GormTracingPlugin struct {
|
||
tracer *Tracer
|
||
logger *zap.Logger
|
||
config GormPluginConfig
|
||
}
|
||
|
||
// GormPluginConfig controls what GormTracingPlugin records.
type GormPluginConfig struct {
	// IncludeSQL records the statement text as the "db.statement" span attribute.
	IncludeSQL bool
	// IncludeValues is meant to control recording of bound parameter values;
	// keep it false in production to avoid capturing sensitive data.
	// NOTE(review): nothing in the plugin file reads this field yet.
	IncludeValues bool
	// SlowThreshold is the duration above which a statement is reported as a
	// slow query.
	SlowThreshold time.Duration
	// ExcludeTables lists table names (exact, case-sensitive match) that are
	// never traced.
	ExcludeTables []string
	// SanitizeSQL passes the statement text through sanitizeSQL before it is
	// recorded.
	SanitizeSQL bool
}
|
||
|
||
// DefaultGormPluginConfig 默认GORM插件配置
|
||
func DefaultGormPluginConfig() GormPluginConfig {
|
||
return GormPluginConfig{
|
||
IncludeSQL: true,
|
||
IncludeValues: false, // 生产环境建议设为false避免记录敏感数据
|
||
SlowThreshold: 200 * time.Millisecond,
|
||
ExcludeTables: []string{"migrations", "schema_migrations"},
|
||
SanitizeSQL: true,
|
||
}
|
||
}
|
||
|
||
// NewGormTracingPlugin 创建GORM追踪插件
|
||
func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin {
|
||
return &GormTracingPlugin{
|
||
tracer: tracer,
|
||
logger: logger,
|
||
config: DefaultGormPluginConfig(),
|
||
}
|
||
}
|
||
|
||
// Name 返回插件名称
|
||
func (p *GormTracingPlugin) Name() string {
|
||
return "gorm-otel-tracing"
|
||
}
|
||
|
||
// Initialize 初始化插件
|
||
func (p *GormTracingPlugin) Initialize(db *gorm.DB) error {
|
||
// 注册各种操作的回调
|
||
callbacks := []string{"create", "query", "update", "delete", "raw"}
|
||
|
||
for _, operation := range callbacks {
|
||
switch operation {
|
||
case "create":
|
||
err := db.Callback().Create().Before("gorm:create").
|
||
Register(p.Name()+":before_create", p.beforeOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register before create callback: %w", err)
|
||
}
|
||
err = db.Callback().Create().After("gorm:create").
|
||
Register(p.Name()+":after_create", p.afterOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register after create callback: %w", err)
|
||
}
|
||
case "query":
|
||
err := db.Callback().Query().Before("gorm:query").
|
||
Register(p.Name()+":before_query", p.beforeOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register before query callback: %w", err)
|
||
}
|
||
err = db.Callback().Query().After("gorm:query").
|
||
Register(p.Name()+":after_query", p.afterOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register after query callback: %w", err)
|
||
}
|
||
case "update":
|
||
err := db.Callback().Update().Before("gorm:update").
|
||
Register(p.Name()+":before_update", p.beforeOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register before update callback: %w", err)
|
||
}
|
||
err = db.Callback().Update().After("gorm:update").
|
||
Register(p.Name()+":after_update", p.afterOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register after update callback: %w", err)
|
||
}
|
||
case "delete":
|
||
err := db.Callback().Delete().Before("gorm:delete").
|
||
Register(p.Name()+":before_delete", p.beforeOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register before delete callback: %w", err)
|
||
}
|
||
err = db.Callback().Delete().After("gorm:delete").
|
||
Register(p.Name()+":after_delete", p.afterOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register after delete callback: %w", err)
|
||
}
|
||
case "raw":
|
||
err := db.Callback().Raw().Before("gorm:raw").
|
||
Register(p.Name()+":before_raw", p.beforeOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register before raw callback: %w", err)
|
||
}
|
||
err = db.Callback().Raw().After("gorm:raw").
|
||
Register(p.Name()+":after_raw", p.afterOperation)
|
||
if err != nil {
|
||
return fmt.Errorf("failed to register after raw callback: %w", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
p.logger.Info("GORM追踪插件已初始化")
|
||
return nil
|
||
}
|
||
|
||
// beforeOperation 操作前回调
|
||
func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) {
|
||
// 检查是否应该跳过追踪
|
||
if p.shouldSkipTracing(db) {
|
||
return
|
||
}
|
||
|
||
ctx := db.Statement.Context
|
||
if ctx == nil {
|
||
ctx = context.Background()
|
||
}
|
||
|
||
// 获取操作信息
|
||
operation := p.getOperationType(db)
|
||
tableName := p.getTableName(db)
|
||
|
||
// 检查是否应该排除此表
|
||
if p.isExcludedTable(tableName) {
|
||
return
|
||
}
|
||
|
||
// 开始追踪
|
||
ctx, span := p.tracer.StartDBSpan(ctx, operation, tableName)
|
||
|
||
// 添加基础属性
|
||
p.tracer.AddSpanAttributes(span,
|
||
attribute.String("db.system", "postgresql"),
|
||
attribute.String("db.operation", operation),
|
||
)
|
||
|
||
if tableName != "" {
|
||
p.tracer.AddSpanAttributes(span, attribute.String("db.table", tableName))
|
||
}
|
||
|
||
// 保存追踪信息到GORM context
|
||
db.Set(gormSpanKey, span)
|
||
db.Set(gormOperationKey, operation)
|
||
db.Set(gormTableNameKey, tableName)
|
||
db.Set(gormStartTimeKey, time.Now())
|
||
|
||
// 更新statement context
|
||
db.Statement.Context = ctx
|
||
}
|
||
|
||
// afterOperation 操作后回调
|
||
func (p *GormTracingPlugin) afterOperation(db *gorm.DB) {
|
||
// 获取span
|
||
spanValue, exists := db.Get(gormSpanKey)
|
||
if !exists {
|
||
return
|
||
}
|
||
|
||
span, ok := spanValue.(trace.Span)
|
||
if !ok {
|
||
return
|
||
}
|
||
defer span.End()
|
||
|
||
// 获取操作信息
|
||
operation, _ := db.Get(gormOperationKey)
|
||
tableName, _ := db.Get(gormTableNameKey)
|
||
startTime, _ := db.Get(gormStartTimeKey)
|
||
|
||
// 计算执行时间
|
||
var duration time.Duration
|
||
if st, ok := startTime.(time.Time); ok {
|
||
duration = time.Since(st)
|
||
p.tracer.AddSpanAttributes(span,
|
||
attribute.Int64("db.duration_ms", duration.Milliseconds()),
|
||
)
|
||
}
|
||
|
||
// 添加SQL信息
|
||
if p.config.IncludeSQL && db.Statement.SQL.String() != "" {
|
||
sql := db.Statement.SQL.String()
|
||
if p.config.SanitizeSQL {
|
||
sql = p.sanitizeSQL(sql)
|
||
}
|
||
p.tracer.AddSpanAttributes(span, attribute.String("db.statement", sql))
|
||
}
|
||
|
||
// 添加影响行数
|
||
if db.Statement.RowsAffected >= 0 {
|
||
p.tracer.AddSpanAttributes(span,
|
||
attribute.Int64("db.rows_affected", db.Statement.RowsAffected),
|
||
)
|
||
}
|
||
|
||
// 处理错误
|
||
if db.Error != nil {
|
||
p.tracer.SetSpanError(span, db.Error)
|
||
span.SetStatus(codes.Error, db.Error.Error())
|
||
|
||
p.logger.Error("数据库操作失败",
|
||
zap.String("operation", fmt.Sprintf("%v", operation)),
|
||
zap.String("table", fmt.Sprintf("%v", tableName)),
|
||
zap.Error(db.Error),
|
||
zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
|
||
)
|
||
} else {
|
||
p.tracer.SetSpanSuccess(span)
|
||
span.SetStatus(codes.Ok, "success")
|
||
|
||
// 检查慢查询
|
||
if duration > p.config.SlowThreshold {
|
||
p.tracer.AddSpanAttributes(span,
|
||
attribute.Bool("db.slow_query", true),
|
||
)
|
||
|
||
p.logger.Warn("慢SQL查询检测",
|
||
zap.String("operation", fmt.Sprintf("%v", operation)),
|
||
zap.String("table", fmt.Sprintf("%v", tableName)),
|
||
zap.Duration("duration", duration),
|
||
zap.String("sql", db.Statement.SQL.String()),
|
||
zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
|
||
)
|
||
}
|
||
}
|
||
}
|
||
|
||
// shouldSkipTracing 检查是否应该跳过追踪
|
||
func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool {
|
||
// 检查是否已有span(避免重复追踪)
|
||
if _, exists := db.Get(gormSpanKey); exists {
|
||
return true
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// getOperationType 获取操作类型
|
||
func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string {
|
||
switch db.Statement.ReflectValue.Kind() {
|
||
default:
|
||
sql := strings.ToUpper(strings.TrimSpace(db.Statement.SQL.String()))
|
||
if sql == "" {
|
||
return "unknown"
|
||
}
|
||
|
||
if strings.HasPrefix(sql, "SELECT") {
|
||
return "select"
|
||
} else if strings.HasPrefix(sql, "INSERT") {
|
||
return "insert"
|
||
} else if strings.HasPrefix(sql, "UPDATE") {
|
||
return "update"
|
||
} else if strings.HasPrefix(sql, "DELETE") {
|
||
return "delete"
|
||
} else if strings.HasPrefix(sql, "CREATE") {
|
||
return "create"
|
||
} else if strings.HasPrefix(sql, "DROP") {
|
||
return "drop"
|
||
} else if strings.HasPrefix(sql, "ALTER") {
|
||
return "alter"
|
||
}
|
||
|
||
return "query"
|
||
}
|
||
}
|
||
|
||
// getTableName 获取表名
|
||
func (p *GormTracingPlugin) getTableName(db *gorm.DB) string {
|
||
if db.Statement.Table != "" {
|
||
return db.Statement.Table
|
||
}
|
||
|
||
if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
|
||
return db.Statement.Schema.Table
|
||
}
|
||
|
||
return ""
|
||
}
|
||
|
||
// isExcludedTable 检查是否为排除的表
|
||
func (p *GormTracingPlugin) isExcludedTable(tableName string) bool {
|
||
for _, excluded := range p.config.ExcludeTables {
|
||
if tableName == excluded {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// sanitizeSQL 清理SQL语句,移除敏感信息
|
||
func (p *GormTracingPlugin) sanitizeSQL(sql string) string {
|
||
// 简单的SQL清理,将参数替换为占位符
|
||
// 在生产环境中,您可能需要更复杂的清理逻辑
|
||
return strings.ReplaceAll(sql, "'", "?")
|
||
}
|