Files
tyapi-server/internal/shared/tracing/gorm_plugin.go

321 lines
8.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package tracing
import (
"context"
"fmt"
"strings"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"gorm.io/gorm"
)
// Keys under which the plugin stores per-statement tracing state with db.Set
// in beforeOperation and reads it back with db.Get in afterOperation.
const (
// gormSpanKey holds the span started for the current operation.
gormSpanKey = "otel:span"
// gormOperationKey holds the operation name computed by getOperationType.
gormOperationKey = "otel:operation"
// gormTableNameKey holds the table name computed by getTableName.
gormTableNameKey = "otel:table_name"
// gormStartTimeKey holds the time.Time captured when the span was started.
gormStartTimeKey = "otel:start_time"
)
// GormTracingPlugin is a GORM plugin that traces database operations by
// opening a span before each operation and finishing it afterwards.
type GormTracingPlugin struct {
// tracer starts the database spans and attaches attributes/status to them.
tracer *Tracer
// logger receives the initialization, failure and slow-query messages.
logger *zap.Logger
// config holds the plugin settings; see GormPluginConfig.
config GormPluginConfig
}
// GormPluginConfig holds the settings of the GORM tracing plugin.
type GormPluginConfig struct {
// IncludeSQL controls whether the statement text is attached to the span
// as the "db.statement" attribute.
IncludeSQL bool
// IncludeValues is not read by any code visible in this file.
// NOTE(review): presumably meant to control recording of bound values — confirm.
IncludeValues bool
// SlowThreshold is the duration above which a successful operation is
// flagged with "db.slow_query" and logged as a warning.
SlowThreshold time.Duration
// ExcludeTables lists table names (exact match) that are not traced.
ExcludeTables []string
// SanitizeSQL controls whether the statement text is passed through
// sanitizeSQL before being attached to the span.
SanitizeSQL bool
}
// DefaultGormPluginConfig returns the default plugin configuration: SQL text
// is recorded and sanitized, bound values are not recorded, operations slower
// than 200ms count as slow, and the migration bookkeeping tables are excluded.
func DefaultGormPluginConfig() GormPluginConfig {
	var cfg GormPluginConfig

	cfg.IncludeSQL = true
	// Recommended to stay false in production so sensitive data is not recorded.
	cfg.IncludeValues = false
	cfg.SlowThreshold = 200 * time.Millisecond
	cfg.ExcludeTables = []string{"migrations", "schema_migrations"}
	cfg.SanitizeSQL = true

	return cfg
}
// NewGormTracingPlugin builds a tracing plugin that uses the given tracer and
// logger and starts from DefaultGormPluginConfig.
func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin {
	plugin := new(GormTracingPlugin)
	plugin.tracer = tracer
	plugin.logger = logger
	plugin.config = DefaultGormPluginConfig()
	return plugin
}
// Name returns the plugin name. The same name is used as the prefix of every
// callback the plugin registers in Initialize.
func (p *GormTracingPlugin) Name() string {
	const pluginName = "gorm-otel-tracing"
	return pluginName
}
// Initialize hooks the plugin into db. For each of GORM's create, query,
// update, delete and raw processors it registers beforeOperation ahead of the
// built-in "gorm:<op>" callback and afterOperation behind it, under the names
// "<plugin name>:before_<op>" and "<plugin name>:after_<op>".
//
// It returns the first registration error, wrapped with the stage and
// operation that failed; on success it logs that the plugin is initialized.
func (p *GormTracingPlugin) Initialize(db *gorm.DB) error {
	// registerFunc has the shape of the Register method on GORM's callback
	// builder, so every table row can carry the bound method for its slot.
	type registerFunc func(name string, fn func(*gorm.DB)) error

	callbacks := db.Callback()
	hooks := []struct {
		operation string
		before    registerFunc
		after     registerFunc
	}{
		{"create", callbacks.Create().Before("gorm:create").Register, callbacks.Create().After("gorm:create").Register},
		{"query", callbacks.Query().Before("gorm:query").Register, callbacks.Query().After("gorm:query").Register},
		{"update", callbacks.Update().Before("gorm:update").Register, callbacks.Update().After("gorm:update").Register},
		{"delete", callbacks.Delete().Before("gorm:delete").Register, callbacks.Delete().After("gorm:delete").Register},
		{"raw", callbacks.Raw().Before("gorm:raw").Register, callbacks.Raw().After("gorm:raw").Register},
	}

	// Registration order (before, then after, per operation) is kept stable so
	// a failure always reports the same first offending callback.
	for _, h := range hooks {
		if err := h.before(p.Name()+":before_"+h.operation, p.beforeOperation); err != nil {
			return fmt.Errorf("failed to register before %s callback: %w", h.operation, err)
		}
		if err := h.after(p.Name()+":after_"+h.operation, p.afterOperation); err != nil {
			return fmt.Errorf("failed to register after %s callback: %w", h.operation, err)
		}
	}

	p.logger.Info("GORM追踪插件已初始化")
	return nil
}
// beforeOperation is the "before" callback registered for every traced
// processor. Unless the statement already carries a span or targets an
// excluded table, it starts a DB span, tags it with the basic attributes,
// stores the span, operation, table name and start time in the statement
// settings for afterOperation, and makes the span's context the statement
// context.
func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) {
	// A span is already attached to this statement: avoid tracing it twice.
	if p.shouldSkipTracing(db) {
		return
	}

	parent := db.Statement.Context
	if parent == nil {
		parent = context.Background()
	}

	op := p.getOperationType(db)
	table := p.getTableName(db)
	if p.isExcludedTable(table) {
		return
	}

	spanCtx, span := p.tracer.StartDBSpan(parent, op, table)

	// Basic attributes; the table attribute is only added when a name is known.
	p.tracer.AddSpanAttributes(span,
		attribute.String("db.system", "postgresql"),
		attribute.String("db.operation", op),
	)
	if table != "" {
		p.tracer.AddSpanAttributes(span, attribute.String("db.table", table))
	}

	// Hand the tracing state over to afterOperation.
	db.Set(gormSpanKey, span)
	db.Set(gormOperationKey, op)
	db.Set(gormTableNameKey, table)
	db.Set(gormStartTimeKey, time.Now())

	// The statement runs under the span's context from here on.
	db.Statement.Context = spanCtx
}
// afterOperation is the "after" callback registered for every traced
// processor. It finishes the span that beforeOperation attached to the
// statement: it records duration, statement text, affected rows and the
// outcome on the span, logs failed operations as errors, and logs successful
// operations slower than config.SlowThreshold as warnings.
//
// It does nothing when no span is attached to the statement (beforeOperation
// returned early).
func (p *GormTracingPlugin) afterOperation(db *gorm.DB) {
	spanValue, exists := db.Get(gormSpanKey)
	if !exists {
		return
	}
	operation, _ := db.Get(gormOperationKey)
	tableName, _ := db.Get(gormTableNameKey)
	startTime, _ := db.Get(gormStartTimeKey)

	// db.Set keeps values in the statement's Settings map, which outlives this
	// operation when the caller reuses the same *gorm.DB (e.g. q.Count followed
	// by q.Find). Remove the tracing state now so the next operation on the
	// statement is traced again instead of being skipped by shouldSkipTracing
	// and re-ending this already finished span.
	for _, key := range [...]string{gormSpanKey, gormOperationKey, gormTableNameKey, gormStartTimeKey} {
		db.Statement.Settings.Delete(key)
	}

	span, ok := spanValue.(trace.Span)
	if !ok {
		return
	}
	defer span.End()

	// Execution time, measured from the moment the span was started.
	var duration time.Duration
	if st, ok := startTime.(time.Time); ok {
		duration = time.Since(st)
		p.tracer.AddSpanAttributes(span,
			attribute.Int64("db.duration_ms", duration.Milliseconds()),
		)
	}

	// The "before" callbacks run ahead of GORM's own SQL-building callbacks, so
	// for ORM calls the operation could not be derived from the SQL yet and was
	// stored as "unknown". The SQL exists now: derive it again and correct the
	// span attribute (assumes AddSpanAttributes overwrites an existing key, as
	// OpenTelemetry's SetAttributes does — TODO confirm).
	opName := fmt.Sprintf("%v", operation)
	if opName == "unknown" {
		if derived := p.getOperationType(db); derived != "unknown" {
			opName = derived
			p.tracer.AddSpanAttributes(span, attribute.String("db.operation", opName))
		}
	}
	table := fmt.Sprintf("%v", tableName)

	// One sanitized-if-configured copy of the SQL is used for both the span
	// attribute and the slow-query log, so the log cannot leak what the span
	// was configured to hide.
	rawSQL := db.Statement.SQL.String()
	reportedSQL := rawSQL
	if p.config.SanitizeSQL {
		reportedSQL = p.sanitizeSQL(rawSQL)
	}
	if p.config.IncludeSQL && rawSQL != "" {
		p.tracer.AddSpanAttributes(span, attribute.String("db.statement", reportedSQL))
	}

	// Affected row count.
	if db.Statement.RowsAffected >= 0 {
		p.tracer.AddSpanAttributes(span,
			attribute.Int64("db.rows_affected", db.Statement.RowsAffected),
		)
	}

	// Failure: mark the span and log the error.
	if db.Error != nil {
		p.tracer.SetSpanError(span, db.Error)
		span.SetStatus(codes.Error, db.Error.Error())
		p.logger.Error("数据库操作失败",
			zap.String("operation", opName),
			zap.String("table", table),
			zap.Error(db.Error),
			zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
		)
		return
	}

	p.tracer.SetSpanSuccess(span)
	span.SetStatus(codes.Ok, "success")

	// Slow-query detection applies to successful operations only.
	if duration > p.config.SlowThreshold {
		p.tracer.AddSpanAttributes(span,
			attribute.Bool("db.slow_query", true),
		)
		p.logger.Warn("慢SQL查询检测",
			zap.String("operation", opName),
			zap.String("table", table),
			zap.Duration("duration", duration),
			zap.String("sql", reportedSQL),
			zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
		)
	}
}
// shouldSkipTracing reports whether tracing must be skipped for db because a
// span is already stored on its statement, which would otherwise lead to the
// same operation being traced twice.
func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool {
	_, alreadyTraced := db.Get(gormSpanKey)
	return alreadyTraced
}
// getOperationType derives the operation name from the leading keyword of the
// statement's SQL text, ignoring case and surrounding whitespace.
//
// It returns "unknown" when no SQL has been built yet, the lower-cased keyword
// for SELECT, INSERT, UPDATE, DELETE, CREATE, DROP and ALTER, and "query" for
// any other statement.
func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string {
	sql := strings.ToUpper(strings.TrimSpace(db.Statement.SQL.String()))
	if sql == "" {
		return "unknown"
	}

	// Keywords are checked in a fixed order; none is a prefix of another, so
	// the order does not affect the result.
	for _, keyword := range [...]string{"SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER"} {
		if strings.HasPrefix(sql, keyword) {
			return strings.ToLower(keyword)
		}
	}
	return "query"
}
// getTableName returns the table the statement targets: the statement's own
// table name when set, otherwise the table of its parsed schema, otherwise "".
func (p *GormTracingPlugin) getTableName(db *gorm.DB) string {
	stmt := db.Statement
	switch {
	case stmt.Table != "":
		return stmt.Table
	case stmt.Schema != nil && stmt.Schema.Table != "":
		return stmt.Schema.Table
	default:
		return ""
	}
}
// isExcludedTable reports whether tableName exactly matches one of the
// entries in config.ExcludeTables.
func (p *GormTracingPlugin) isExcludedTable(tableName string) bool {
	skipList := p.config.ExcludeTables
	for i := 0; i < len(skipList); i++ {
		if skipList[i] == tableName {
			return true
		}
	}
	return false
}
// sanitizeSQL removes inline string values from sql by replacing every
// single-quoted literal, quotes and contents included, with one "?"
// placeholder. A doubled quote ('') inside a literal is treated as an escaped
// quote, and an unterminated literal is masked up to the end of the input.
//
// Only single-quoted literals are handled: numeric literals are left as they
// are and backslash escapes are not interpreted.
func (p *GormTracingPlugin) sanitizeSQL(sql string) string {
	if !strings.Contains(sql, "'") {
		return sql
	}

	var out strings.Builder
	out.Grow(len(sql))

	// Byte-wise scanning is safe: the quote is ASCII and can never occur inside
	// a multi-byte UTF-8 sequence.
	for i := 0; i < len(sql); {
		if sql[i] != '\'' {
			out.WriteByte(sql[i])
			i++
			continue
		}

		// Skip the whole literal, starting after its opening quote.
		i++
		for i < len(sql) {
			if sql[i] != '\'' {
				i++
				continue
			}
			if i+1 < len(sql) && sql[i+1] == '\'' {
				// Escaped quote inside the literal; keep scanning.
				i += 2
				continue
			}
			// Closing quote.
			i++
			break
		}
		out.WriteByte('?')
	}
	return out.String()
}