321 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			321 lines
		
	
	
		
			8.5 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package tracing | |||
|  | 
 | |||
|  | import ( | |||
|  | 	"context" | |||
|  | 	"fmt" | |||
|  | 	"strings" | |||
|  | 	"time" | |||
|  | 
 | |||
|  | 	"go.opentelemetry.io/otel/attribute" | |||
|  | 	"go.opentelemetry.io/otel/codes" | |||
|  | 	"go.opentelemetry.io/otel/trace" | |||
|  | 	"go.uber.org/zap" | |||
|  | 	"gorm.io/gorm" | |||
|  | ) | |||
|  | 
 | |||
// Keys under which beforeOperation stores per-operation tracing state on the
// GORM statement (via db.Set) for afterOperation to read back (via db.Get).
const (
	// gormSpanKey holds the trace.Span started for the current operation.
	gormSpanKey = "otel:span"
	// gormOperationKey holds the operation name passed to StartDBSpan.
	gormOperationKey = "otel:operation"
	// gormTableNameKey holds the table name resolved for the operation.
	gormTableNameKey = "otel:table_name"
	// gormStartTimeKey holds the time.Time at which the span was started.
	gormStartTimeKey = "otel:start_time"
)
|  | 
 | |||
|  | // GormTracingPlugin GORM链路追踪插件 | |||
|  | type GormTracingPlugin struct { | |||
|  | 	tracer *Tracer | |||
|  | 	logger *zap.Logger | |||
|  | 	config GormPluginConfig | |||
|  | } | |||
|  | 
 | |||
// GormPluginConfig configures what GormTracingPlugin records on its spans and
// in its logs.
type GormPluginConfig struct {
	// IncludeSQL records the statement text as the "db.statement" span
	// attribute.
	IncludeSQL bool
	// IncludeValues controls whether bound parameter values may be recorded.
	// DefaultGormPluginConfig leaves it false because values can contain
	// sensitive data.
	IncludeValues bool
	// SlowThreshold: successful operations that take longer than this are
	// marked with the "db.slow_query" attribute and logged as slow.
	SlowThreshold time.Duration
	// ExcludeTables lists table names (exact, case-sensitive match) for which
	// no span is started.
	ExcludeTables []string
	// SanitizeSQL passes the statement text through the plugin's SQL
	// sanitizer before it is recorded.
	SanitizeSQL bool
}
|  | 
 | |||
|  | // DefaultGormPluginConfig 默认GORM插件配置 | |||
|  | func DefaultGormPluginConfig() GormPluginConfig { | |||
|  | 	return GormPluginConfig{ | |||
|  | 		IncludeSQL:    true, | |||
|  | 		IncludeValues: false, // 生产环境建议设为false避免记录敏感数据 | |||
|  | 		SlowThreshold: 200 * time.Millisecond, | |||
|  | 		ExcludeTables: []string{"migrations", "schema_migrations"}, | |||
|  | 		SanitizeSQL:   true, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // NewGormTracingPlugin 创建GORM追踪插件 | |||
|  | func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin { | |||
|  | 	return &GormTracingPlugin{ | |||
|  | 		tracer: tracer, | |||
|  | 		logger: logger, | |||
|  | 		config: DefaultGormPluginConfig(), | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // Name 返回插件名称 | |||
|  | func (p *GormTracingPlugin) Name() string { | |||
|  | 	return "gorm-otel-tracing" | |||
|  | } | |||
|  | 
 | |||
|  | // Initialize 初始化插件 | |||
|  | func (p *GormTracingPlugin) Initialize(db *gorm.DB) error { | |||
|  | 	// 注册各种操作的回调 | |||
|  | 	callbacks := []string{"create", "query", "update", "delete", "raw"} | |||
|  | 
 | |||
|  | 	for _, operation := range callbacks { | |||
|  | 		switch operation { | |||
|  | 		case "create": | |||
|  | 			err := db.Callback().Create().Before("gorm:create"). | |||
|  | 				Register(p.Name()+":before_create", p.beforeOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register before create callback: %w", err) | |||
|  | 			} | |||
|  | 			err = db.Callback().Create().After("gorm:create"). | |||
|  | 				Register(p.Name()+":after_create", p.afterOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register after create callback: %w", err) | |||
|  | 			} | |||
|  | 		case "query": | |||
|  | 			err := db.Callback().Query().Before("gorm:query"). | |||
|  | 				Register(p.Name()+":before_query", p.beforeOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register before query callback: %w", err) | |||
|  | 			} | |||
|  | 			err = db.Callback().Query().After("gorm:query"). | |||
|  | 				Register(p.Name()+":after_query", p.afterOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register after query callback: %w", err) | |||
|  | 			} | |||
|  | 		case "update": | |||
|  | 			err := db.Callback().Update().Before("gorm:update"). | |||
|  | 				Register(p.Name()+":before_update", p.beforeOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register before update callback: %w", err) | |||
|  | 			} | |||
|  | 			err = db.Callback().Update().After("gorm:update"). | |||
|  | 				Register(p.Name()+":after_update", p.afterOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register after update callback: %w", err) | |||
|  | 			} | |||
|  | 		case "delete": | |||
|  | 			err := db.Callback().Delete().Before("gorm:delete"). | |||
|  | 				Register(p.Name()+":before_delete", p.beforeOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register before delete callback: %w", err) | |||
|  | 			} | |||
|  | 			err = db.Callback().Delete().After("gorm:delete"). | |||
|  | 				Register(p.Name()+":after_delete", p.afterOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register after delete callback: %w", err) | |||
|  | 			} | |||
|  | 		case "raw": | |||
|  | 			err := db.Callback().Raw().Before("gorm:raw"). | |||
|  | 				Register(p.Name()+":before_raw", p.beforeOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register before raw callback: %w", err) | |||
|  | 			} | |||
|  | 			err = db.Callback().Raw().After("gorm:raw"). | |||
|  | 				Register(p.Name()+":after_raw", p.afterOperation) | |||
|  | 			if err != nil { | |||
|  | 				return fmt.Errorf("failed to register after raw callback: %w", err) | |||
|  | 			} | |||
|  | 		} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	p.logger.Info("GORM追踪插件已初始化") | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // beforeOperation 操作前回调 | |||
|  | func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) { | |||
|  | 	// 检查是否应该跳过追踪 | |||
|  | 	if p.shouldSkipTracing(db) { | |||
|  | 		return | |||
|  | 	} | |||
|  | 
 | |||
|  | 	ctx := db.Statement.Context | |||
|  | 	if ctx == nil { | |||
|  | 		ctx = context.Background() | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 获取操作信息 | |||
|  | 	operation := p.getOperationType(db) | |||
|  | 	tableName := p.getTableName(db) | |||
|  | 
 | |||
|  | 	// 检查是否应该排除此表 | |||
|  | 	if p.isExcludedTable(tableName) { | |||
|  | 		return | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 开始追踪 | |||
|  | 	ctx, span := p.tracer.StartDBSpan(ctx, operation, tableName) | |||
|  | 
 | |||
|  | 	// 添加基础属性 | |||
|  | 	p.tracer.AddSpanAttributes(span, | |||
|  | 		attribute.String("db.system", "postgresql"), | |||
|  | 		attribute.String("db.operation", operation), | |||
|  | 	) | |||
|  | 
 | |||
|  | 	if tableName != "" { | |||
|  | 		p.tracer.AddSpanAttributes(span, attribute.String("db.table", tableName)) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 保存追踪信息到GORM context | |||
|  | 	db.Set(gormSpanKey, span) | |||
|  | 	db.Set(gormOperationKey, operation) | |||
|  | 	db.Set(gormTableNameKey, tableName) | |||
|  | 	db.Set(gormStartTimeKey, time.Now()) | |||
|  | 
 | |||
|  | 	// 更新statement context | |||
|  | 	db.Statement.Context = ctx | |||
|  | } | |||
|  | 
 | |||
|  | // afterOperation 操作后回调 | |||
|  | func (p *GormTracingPlugin) afterOperation(db *gorm.DB) { | |||
|  | 	// 获取span | |||
|  | 	spanValue, exists := db.Get(gormSpanKey) | |||
|  | 	if !exists { | |||
|  | 		return | |||
|  | 	} | |||
|  | 
 | |||
|  | 	span, ok := spanValue.(trace.Span) | |||
|  | 	if !ok { | |||
|  | 		return | |||
|  | 	} | |||
|  | 	defer span.End() | |||
|  | 
 | |||
|  | 	// 获取操作信息 | |||
|  | 	operation, _ := db.Get(gormOperationKey) | |||
|  | 	tableName, _ := db.Get(gormTableNameKey) | |||
|  | 	startTime, _ := db.Get(gormStartTimeKey) | |||
|  | 
 | |||
|  | 	// 计算执行时间 | |||
|  | 	var duration time.Duration | |||
|  | 	if st, ok := startTime.(time.Time); ok { | |||
|  | 		duration = time.Since(st) | |||
|  | 		p.tracer.AddSpanAttributes(span, | |||
|  | 			attribute.Int64("db.duration_ms", duration.Milliseconds()), | |||
|  | 		) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 添加SQL信息 | |||
|  | 	if p.config.IncludeSQL && db.Statement.SQL.String() != "" { | |||
|  | 		sql := db.Statement.SQL.String() | |||
|  | 		if p.config.SanitizeSQL { | |||
|  | 			sql = p.sanitizeSQL(sql) | |||
|  | 		} | |||
|  | 		p.tracer.AddSpanAttributes(span, attribute.String("db.statement", sql)) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 添加影响行数 | |||
|  | 	if db.Statement.RowsAffected >= 0 { | |||
|  | 		p.tracer.AddSpanAttributes(span, | |||
|  | 			attribute.Int64("db.rows_affected", db.Statement.RowsAffected), | |||
|  | 		) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 处理错误 | |||
|  | 	if db.Error != nil { | |||
|  | 		p.tracer.SetSpanError(span, db.Error) | |||
|  | 		span.SetStatus(codes.Error, db.Error.Error()) | |||
|  | 
 | |||
|  | 		p.logger.Error("数据库操作失败", | |||
|  | 			zap.String("operation", fmt.Sprintf("%v", operation)), | |||
|  | 			zap.String("table", fmt.Sprintf("%v", tableName)), | |||
|  | 			zap.Error(db.Error), | |||
|  | 			zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)), | |||
|  | 		) | |||
|  | 	} else { | |||
|  | 		p.tracer.SetSpanSuccess(span) | |||
|  | 		span.SetStatus(codes.Ok, "success") | |||
|  | 
 | |||
|  | 		// 检查慢查询 | |||
|  | 		if duration > p.config.SlowThreshold { | |||
|  | 			p.tracer.AddSpanAttributes(span, | |||
|  | 				attribute.Bool("db.slow_query", true), | |||
|  | 			) | |||
|  | 
 | |||
|  | 			p.logger.Warn("慢SQL查询检测", | |||
|  | 				zap.String("operation", fmt.Sprintf("%v", operation)), | |||
|  | 				zap.String("table", fmt.Sprintf("%v", tableName)), | |||
|  | 				zap.Duration("duration", duration), | |||
|  | 				zap.String("sql", db.Statement.SQL.String()), | |||
|  | 				zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)), | |||
|  | 			) | |||
|  | 		} | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // shouldSkipTracing 检查是否应该跳过追踪 | |||
|  | func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool { | |||
|  | 	// 检查是否已有span(避免重复追踪) | |||
|  | 	if _, exists := db.Get(gormSpanKey); exists { | |||
|  | 		return true | |||
|  | 	} | |||
|  | 
 | |||
|  | 	return false | |||
|  | } | |||
|  | 
 | |||
|  | // getOperationType 获取操作类型 | |||
|  | func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string { | |||
|  | 	switch db.Statement.ReflectValue.Kind() { | |||
|  | 	default: | |||
|  | 		sql := strings.ToUpper(strings.TrimSpace(db.Statement.SQL.String())) | |||
|  | 		if sql == "" { | |||
|  | 			return "unknown" | |||
|  | 		} | |||
|  | 
 | |||
|  | 		if strings.HasPrefix(sql, "SELECT") { | |||
|  | 			return "select" | |||
|  | 		} else if strings.HasPrefix(sql, "INSERT") { | |||
|  | 			return "insert" | |||
|  | 		} else if strings.HasPrefix(sql, "UPDATE") { | |||
|  | 			return "update" | |||
|  | 		} else if strings.HasPrefix(sql, "DELETE") { | |||
|  | 			return "delete" | |||
|  | 		} else if strings.HasPrefix(sql, "CREATE") { | |||
|  | 			return "create" | |||
|  | 		} else if strings.HasPrefix(sql, "DROP") { | |||
|  | 			return "drop" | |||
|  | 		} else if strings.HasPrefix(sql, "ALTER") { | |||
|  | 			return "alter" | |||
|  | 		} | |||
|  | 
 | |||
|  | 		return "query" | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // getTableName 获取表名 | |||
|  | func (p *GormTracingPlugin) getTableName(db *gorm.DB) string { | |||
|  | 	if db.Statement.Table != "" { | |||
|  | 		return db.Statement.Table | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if db.Statement.Schema != nil && db.Statement.Schema.Table != "" { | |||
|  | 		return db.Statement.Schema.Table | |||
|  | 	} | |||
|  | 
 | |||
|  | 	return "" | |||
|  | } | |||
|  | 
 | |||
|  | // isExcludedTable 检查是否为排除的表 | |||
|  | func (p *GormTracingPlugin) isExcludedTable(tableName string) bool { | |||
|  | 	for _, excluded := range p.config.ExcludeTables { | |||
|  | 		if tableName == excluded { | |||
|  | 			return true | |||
|  | 		} | |||
|  | 	} | |||
|  | 	return false | |||
|  | } | |||
|  | 
 | |||
|  | // sanitizeSQL 清理SQL语句,移除敏感信息 | |||
|  | func (p *GormTracingPlugin) sanitizeSQL(sql string) string { | |||
|  | 	// 简单的SQL清理,将参数替换为占位符 | |||
|  | 	// 在生产环境中,您可能需要更复杂的清理逻辑 | |||
|  | 	return strings.ReplaceAll(sql, "'", "?") | |||
|  | } |