package database import ( "context" "errors" "strings" "time" "go.uber.org/zap" "gorm.io/gorm" ) // 自定义错误类型 var ( ErrTransactionRollback = errors.New("事务回滚失败") ErrTransactionCommit = errors.New("事务提交失败") ) // 定义context key type txKey struct{} // WithTx 将事务对象存储到context中 func WithTx(ctx context.Context, tx *gorm.DB) context.Context { return context.WithValue(ctx, txKey{}, tx) } // GetTx 从context中获取事务对象 func GetTx(ctx context.Context) (*gorm.DB, bool) { tx, ok := ctx.Value(txKey{}).(*gorm.DB) return tx, ok } // TransactionManager 事务管理器 type TransactionManager struct { db *gorm.DB logger *zap.Logger } // NewTransactionManager 创建事务管理器 func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager { return &TransactionManager{ db: db, logger: logger, } } // ExecuteInTx 在事务中执行函数(推荐使用) // 自动处理事务的开启、提交和回滚 func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error { // 检查是否已经在事务中 if _, ok := GetTx(ctx); ok { // 如果已经在事务中,直接执行函数,避免嵌套事务 return fn(ctx) } tx := tm.db.Begin() if tx.Error != nil { return tx.Error } // 创建带事务的context txCtx := WithTx(ctx, tx) // 执行函数 if err := fn(txCtx); err != nil { // 回滚事务 if rbErr := tx.Rollback().Error; rbErr != nil { tm.logger.Error("事务回滚失败", zap.Error(err), zap.Error(rbErr), ) return errors.Join(err, ErrTransactionRollback, rbErr) } return err } // 提交事务 if err := tx.Commit().Error; err != nil { tm.logger.Error("事务提交失败", zap.Error(err)) return errors.Join(ErrTransactionCommit, err) } return nil } // ExecuteInTxWithTimeout 在事务中执行函数(带超时) func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error { ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() return tm.ExecuteInTx(ctx, fn) } // BeginTx 开始事务(手动管理) func (tm *TransactionManager) BeginTx() *gorm.DB { return tm.db.Begin() } // TxWrapper 事务包装器(手动管理) type TxWrapper struct { tx *gorm.DB } // NewTxWrapper 创建事务包装器 func (tm *TransactionManager) NewTxWrapper() *TxWrapper { return &TxWrapper{ tx: tm.BeginTx(), } } // Commit 提交事务 func (tx *TxWrapper) Commit() error { return tx.tx.Commit().Error } // Rollback 回滚事务 func (tx *TxWrapper) Rollback() error { return tx.tx.Rollback().Error } // GetDB 获取事务数据库实例 func (tx *TxWrapper) GetDB() *gorm.DB { return tx.tx } // WithTx 在事务中执行函数(兼容旧接口) func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error { tx := tm.BeginTx() defer func() { if r := recover(); r != nil { tx.Rollback() panic(r) } }() if err := fn(tx); err != nil { tx.Rollback() return err } return tx.Commit().Error } // TransactionOptions 事务选项 type TransactionOptions struct { Timeout time.Duration ReadOnly bool // 是否只读事务 } // ExecuteInTxWithOptions 在事务中执行函数(带选项) func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error { // 设置事务选项 tx := tm.db.Begin() if tx.Error != nil { return tx.Error } // 设置只读事务 if options != nil && options.ReadOnly { tx = tx.Session(&gorm.Session{}) // 注意:GORM的只读事务需要数据库支持,这里只是标记 } // 创建带事务的context txCtx := WithTx(ctx, tx) // 设置超时 if options != nil && options.Timeout > 0 { var cancel context.CancelFunc txCtx, cancel = context.WithTimeout(txCtx, options.Timeout) defer cancel() } // 执行函数 if err := fn(txCtx); err != nil { // 回滚事务 if rbErr := tx.Rollback().Error; rbErr != nil { return err } return err } // 提交事务 return tx.Commit().Error } // TransactionStats 事务统计信息 type TransactionStats struct { TotalTransactions int64 SuccessfulTransactions int64 FailedTransactions int64 AverageDuration time.Duration } // GetStats 获取事务统计信息(预留接口) func (tm *TransactionManager) GetStats() *TransactionStats { // TODO: 实现事务统计 return &TransactionStats{} } // RetryableTransactionOptions 可重试事务选项 type RetryableTransactionOptions struct { MaxRetries int // 最大重试次数 RetryDelay time.Duration // 重试延迟 RetryBackoff float64 // 退避倍数 } // DefaultRetryableOptions 默认重试选项 func DefaultRetryableOptions() *RetryableTransactionOptions { return &RetryableTransactionOptions{ MaxRetries: 3, RetryDelay: 100 * time.Millisecond, RetryBackoff: 2.0, } } // ExecuteInTxWithRetry 在事务中执行函数(支持重试) // 适用于处理死锁等临时性错误 func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error { if options == nil { options = DefaultRetryableOptions() } var lastErr error delay := options.RetryDelay for attempt := 0; attempt <= options.MaxRetries; attempt++ { // 检查上下文是否已取消 if ctx.Err() != nil { return ctx.Err() } err := tm.ExecuteInTx(ctx, fn) if err == nil { return nil } // 检查是否是可重试的错误(死锁、连接错误等) if !isRetryableError(err) { return err } lastErr = err // 如果不是最后一次尝试,等待后重试 if attempt < options.MaxRetries { tm.logger.Warn("事务执行失败,准备重试", zap.Int("attempt", attempt+1), zap.Int("max_retries", options.MaxRetries), zap.Duration("delay", delay), zap.Error(err), ) select { case <-time.After(delay): delay = time.Duration(float64(delay) * options.RetryBackoff) case <-ctx.Done(): return ctx.Err() } } } tm.logger.Error("事务执行失败,已超过最大重试次数", zap.Int("max_retries", options.MaxRetries), zap.Error(lastErr), ) return lastErr } // isRetryableError 判断是否是可重试的错误 func isRetryableError(err error) bool { if err == nil { return false } errStr := err.Error() // MySQL 死锁错误 if contains(errStr, "Deadlock found") { return true } // MySQL 锁等待超时 if contains(errStr, "Lock wait timeout exceeded") { return true } // 连接错误 if contains(errStr, "connection") { return true } // 可以根据需要添加更多的可重试错误类型 return false } // contains 检查字符串是否包含子字符串(不区分大小写) func contains(s, substr string) bool { return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) }