302 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			302 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package database
 | ||
| 
 | ||
| import (
 | ||
| 	"context"
 | ||
| 	"errors"
 | ||
| 	"strings"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"go.uber.org/zap"
 | ||
| 	"gorm.io/gorm"
 | ||
| )
 | ||
| 
 | ||
| // 自定义错误类型
 | ||
| var (
 | ||
| 	ErrTransactionRollback = errors.New("事务回滚失败")
 | ||
| 	ErrTransactionCommit   = errors.New("事务提交失败")
 | ||
| )
 | ||
| 
 | ||
| // 定义context key
 | ||
| type txKey struct{}
 | ||
| 
 | ||
| // WithTx 将事务对象存储到context中
 | ||
| func WithTx(ctx context.Context, tx *gorm.DB) context.Context {
 | ||
| 	return context.WithValue(ctx, txKey{}, tx)
 | ||
| }
 | ||
| 
 | ||
| // GetTx 从context中获取事务对象
 | ||
| func GetTx(ctx context.Context) (*gorm.DB, bool) {
 | ||
| 	tx, ok := ctx.Value(txKey{}).(*gorm.DB)
 | ||
| 	return tx, ok
 | ||
| }
 | ||
| 
 | ||
| // TransactionManager 事务管理器
 | ||
| type TransactionManager struct {
 | ||
| 	db     *gorm.DB
 | ||
| 	logger *zap.Logger
 | ||
| }
 | ||
| 
 | ||
| // NewTransactionManager 创建事务管理器
 | ||
| func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager {
 | ||
| 	return &TransactionManager{
 | ||
| 		db:     db,
 | ||
| 		logger: logger,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // ExecuteInTx 在事务中执行函数(推荐使用)
 | ||
| // 自动处理事务的开启、提交和回滚
 | ||
| func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error {
 | ||
| 	// 检查是否已经在事务中
 | ||
| 	if _, ok := GetTx(ctx); ok {
 | ||
| 		// 如果已经在事务中,直接执行函数,避免嵌套事务
 | ||
| 		return fn(ctx)
 | ||
| 	}
 | ||
| 
 | ||
| 	tx := tm.db.Begin()
 | ||
| 	if tx.Error != nil {
 | ||
| 		return tx.Error
 | ||
| 	}
 | ||
| 
 | ||
| 	// 创建带事务的context
 | ||
| 	txCtx := WithTx(ctx, tx)
 | ||
| 
 | ||
| 	// 执行函数
 | ||
| 	if err := fn(txCtx); err != nil {
 | ||
| 		// 回滚事务
 | ||
| 		if rbErr := tx.Rollback().Error; rbErr != nil {
 | ||
| 			tm.logger.Error("事务回滚失败",
 | ||
| 				zap.Error(err),
 | ||
| 				zap.Error(rbErr),
 | ||
| 			)
 | ||
| 			return errors.Join(err, ErrTransactionRollback, rbErr)
 | ||
| 		}
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	// 提交事务
 | ||
| 	if err := tx.Commit().Error; err != nil {
 | ||
| 		tm.logger.Error("事务提交失败", zap.Error(err))
 | ||
| 		return errors.Join(ErrTransactionCommit, err)
 | ||
| 	}
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // ExecuteInTxWithTimeout 在事务中执行函数(带超时)
 | ||
| func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error {
 | ||
| 	ctx, cancel := context.WithTimeout(ctx, timeout)
 | ||
| 	defer cancel()
 | ||
| 
 | ||
| 	return tm.ExecuteInTx(ctx, fn)
 | ||
| }
 | ||
| 
 | ||
| // BeginTx 开始事务(手动管理)
 | ||
| func (tm *TransactionManager) BeginTx() *gorm.DB {
 | ||
| 	return tm.db.Begin()
 | ||
| }
 | ||
| 
 | ||
| // TxWrapper 事务包装器(手动管理)
 | ||
| type TxWrapper struct {
 | ||
| 	tx *gorm.DB
 | ||
| }
 | ||
| 
 | ||
| // NewTxWrapper 创建事务包装器
 | ||
| func (tm *TransactionManager) NewTxWrapper() *TxWrapper {
 | ||
| 	return &TxWrapper{
 | ||
| 		tx: tm.BeginTx(),
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // Commit 提交事务
 | ||
| func (tx *TxWrapper) Commit() error {
 | ||
| 	return tx.tx.Commit().Error
 | ||
| }
 | ||
| 
 | ||
| // Rollback 回滚事务
 | ||
| func (tx *TxWrapper) Rollback() error {
 | ||
| 	return tx.tx.Rollback().Error
 | ||
| }
 | ||
| 
 | ||
| // GetDB 获取事务数据库实例
 | ||
| func (tx *TxWrapper) GetDB() *gorm.DB {
 | ||
| 	return tx.tx
 | ||
| }
 | ||
| 
 | ||
| // WithTx 在事务中执行函数(兼容旧接口)
 | ||
| func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error {
 | ||
| 	tx := tm.BeginTx()
 | ||
| 	defer func() {
 | ||
| 		if r := recover(); r != nil {
 | ||
| 			tx.Rollback()
 | ||
| 			panic(r)
 | ||
| 		}
 | ||
| 	}()
 | ||
| 
 | ||
| 	if err := fn(tx); err != nil {
 | ||
| 		tx.Rollback()
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	return tx.Commit().Error
 | ||
| }
 | ||
| 
 | ||
| // TransactionOptions 事务选项
 | ||
| type TransactionOptions struct {
 | ||
| 	Timeout  time.Duration
 | ||
| 	ReadOnly bool // 是否只读事务
 | ||
| }
 | ||
| 
 | ||
| // ExecuteInTxWithOptions 在事务中执行函数(带选项)
 | ||
| func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error {
 | ||
| 	// 设置事务选项
 | ||
| 	tx := tm.db.Begin()
 | ||
| 	if tx.Error != nil {
 | ||
| 		return tx.Error
 | ||
| 	}
 | ||
| 
 | ||
| 	// 设置只读事务
 | ||
| 	if options != nil && options.ReadOnly {
 | ||
| 		tx = tx.Session(&gorm.Session{})
 | ||
| 		// 注意:GORM的只读事务需要数据库支持,这里只是标记
 | ||
| 	}
 | ||
| 
 | ||
| 	// 创建带事务的context
 | ||
| 	txCtx := WithTx(ctx, tx)
 | ||
| 
 | ||
| 	// 设置超时
 | ||
| 	if options != nil && options.Timeout > 0 {
 | ||
| 		var cancel context.CancelFunc
 | ||
| 		txCtx, cancel = context.WithTimeout(txCtx, options.Timeout)
 | ||
| 		defer cancel()
 | ||
| 	}
 | ||
| 
 | ||
| 	// 执行函数
 | ||
| 	if err := fn(txCtx); err != nil {
 | ||
| 		// 回滚事务
 | ||
| 		if rbErr := tx.Rollback().Error; rbErr != nil {
 | ||
| 			return err
 | ||
| 		}
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	// 提交事务
 | ||
| 	return tx.Commit().Error
 | ||
| }
 | ||
| 
 | ||
| // TransactionStats 事务统计信息
 | ||
| type TransactionStats struct {
 | ||
| 	TotalTransactions      int64
 | ||
| 	SuccessfulTransactions int64
 | ||
| 	FailedTransactions     int64
 | ||
| 	AverageDuration        time.Duration
 | ||
| }
 | ||
| 
 | ||
| // GetStats 获取事务统计信息(预留接口)
 | ||
| func (tm *TransactionManager) GetStats() *TransactionStats {
 | ||
| 	// TODO: 实现事务统计
 | ||
| 	return &TransactionStats{}
 | ||
| }
 | ||
| 
 | ||
| // RetryableTransactionOptions 可重试事务选项
 | ||
| type RetryableTransactionOptions struct {
 | ||
| 	MaxRetries   int           // 最大重试次数
 | ||
| 	RetryDelay   time.Duration // 重试延迟
 | ||
| 	RetryBackoff float64       // 退避倍数
 | ||
| }
 | ||
| 
 | ||
| // DefaultRetryableOptions 默认重试选项
 | ||
| func DefaultRetryableOptions() *RetryableTransactionOptions {
 | ||
| 	return &RetryableTransactionOptions{
 | ||
| 		MaxRetries:   3,
 | ||
| 		RetryDelay:   100 * time.Millisecond,
 | ||
| 		RetryBackoff: 2.0,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // ExecuteInTxWithRetry 在事务中执行函数(支持重试)
 | ||
| // 适用于处理死锁等临时性错误
 | ||
| func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error {
 | ||
| 	if options == nil {
 | ||
| 		options = DefaultRetryableOptions()
 | ||
| 	}
 | ||
| 
 | ||
| 	var lastErr error
 | ||
| 	delay := options.RetryDelay
 | ||
| 
 | ||
| 	for attempt := 0; attempt <= options.MaxRetries; attempt++ {
 | ||
| 		// 检查上下文是否已取消
 | ||
| 		if ctx.Err() != nil {
 | ||
| 			return ctx.Err()
 | ||
| 		}
 | ||
| 
 | ||
| 		err := tm.ExecuteInTx(ctx, fn)
 | ||
| 		if err == nil {
 | ||
| 			return nil
 | ||
| 		}
 | ||
| 
 | ||
| 		// 检查是否是可重试的错误(死锁、连接错误等)
 | ||
| 		if !isRetryableError(err) {
 | ||
| 			return err
 | ||
| 		}
 | ||
| 
 | ||
| 		lastErr = err
 | ||
| 
 | ||
| 		// 如果不是最后一次尝试,等待后重试
 | ||
| 		if attempt < options.MaxRetries {
 | ||
| 			tm.logger.Warn("事务执行失败,准备重试",
 | ||
| 				zap.Int("attempt", attempt+1),
 | ||
| 				zap.Int("max_retries", options.MaxRetries),
 | ||
| 				zap.Duration("delay", delay),
 | ||
| 				zap.Error(err),
 | ||
| 			)
 | ||
| 
 | ||
| 			select {
 | ||
| 			case <-time.After(delay):
 | ||
| 				delay = time.Duration(float64(delay) * options.RetryBackoff)
 | ||
| 			case <-ctx.Done():
 | ||
| 				return ctx.Err()
 | ||
| 			}
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Error("事务执行失败,已超过最大重试次数",
 | ||
| 		zap.Int("max_retries", options.MaxRetries),
 | ||
| 		zap.Error(lastErr),
 | ||
| 	)
 | ||
| 
 | ||
| 	return lastErr
 | ||
| }
 | ||
| 
 | ||
| // isRetryableError 判断是否是可重试的错误
 | ||
| func isRetryableError(err error) bool {
 | ||
| 	if err == nil {
 | ||
| 		return false
 | ||
| 	}
 | ||
| 
 | ||
| 	errStr := err.Error()
 | ||
| 
 | ||
| 	// MySQL 死锁错误
 | ||
| 	if contains(errStr, "Deadlock found") {
 | ||
| 		return true
 | ||
| 	}
 | ||
| 
 | ||
| 	// MySQL 锁等待超时
 | ||
| 	if contains(errStr, "Lock wait timeout exceeded") {
 | ||
| 		return true
 | ||
| 	}
 | ||
| 
 | ||
| 	// 连接错误
 | ||
| 	if contains(errStr, "connection") {
 | ||
| 		return true
 | ||
| 	}
 | ||
| 
 | ||
| 	// 可以根据需要添加更多的可重试错误类型
 | ||
| 	return false
 | ||
| }
 | ||
| 
 | ||
| // contains 检查字符串是否包含子字符串(不区分大小写)
 | ||
| func contains(s, substr string) bool {
 | ||
| 	return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
 | ||
| }
 |