302 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			302 lines
		
	
	
		
			6.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package database | |||
|  | 
 | |||
|  | import ( | |||
|  | 	"context" | |||
|  | 	"errors" | |||
|  | 	"strings" | |||
|  | 	"time" | |||
|  | 
 | |||
|  | 	"go.uber.org/zap" | |||
|  | 	"gorm.io/gorm" | |||
|  | ) | |||
|  | 
 | |||
|  | // 自定义错误类型 | |||
|  | var ( | |||
|  | 	ErrTransactionRollback = errors.New("事务回滚失败") | |||
|  | 	ErrTransactionCommit   = errors.New("事务提交失败") | |||
|  | ) | |||
|  | 
 | |||
|  | // 定义context key | |||
|  | type txKey struct{} | |||
|  | 
 | |||
|  | // WithTx 将事务对象存储到context中 | |||
|  | func WithTx(ctx context.Context, tx *gorm.DB) context.Context { | |||
|  | 	return context.WithValue(ctx, txKey{}, tx) | |||
|  | } | |||
|  | 
 | |||
|  | // GetTx 从context中获取事务对象 | |||
|  | func GetTx(ctx context.Context) (*gorm.DB, bool) { | |||
|  | 	tx, ok := ctx.Value(txKey{}).(*gorm.DB) | |||
|  | 	return tx, ok | |||
|  | } | |||
|  | 
 | |||
|  | // TransactionManager 事务管理器 | |||
|  | type TransactionManager struct { | |||
|  | 	db     *gorm.DB | |||
|  | 	logger *zap.Logger | |||
|  | } | |||
|  | 
 | |||
|  | // NewTransactionManager 创建事务管理器 | |||
|  | func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager { | |||
|  | 	return &TransactionManager{ | |||
|  | 		db:     db, | |||
|  | 		logger: logger, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // ExecuteInTx 在事务中执行函数(推荐使用) | |||
|  | // 自动处理事务的开启、提交和回滚 | |||
|  | func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error { | |||
|  | 	// 检查是否已经在事务中 | |||
|  | 	if _, ok := GetTx(ctx); ok { | |||
|  | 		// 如果已经在事务中,直接执行函数,避免嵌套事务 | |||
|  | 		return fn(ctx) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tx := tm.db.Begin() | |||
|  | 	if tx.Error != nil { | |||
|  | 		return tx.Error | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 创建带事务的context | |||
|  | 	txCtx := WithTx(ctx, tx) | |||
|  | 
 | |||
|  | 	// 执行函数 | |||
|  | 	if err := fn(txCtx); err != nil { | |||
|  | 		// 回滚事务 | |||
|  | 		if rbErr := tx.Rollback().Error; rbErr != nil { | |||
|  | 			tm.logger.Error("事务回滚失败", | |||
|  | 				zap.Error(err), | |||
|  | 				zap.Error(rbErr), | |||
|  | 			) | |||
|  | 			return errors.Join(err, ErrTransactionRollback, rbErr) | |||
|  | 		} | |||
|  | 		return err | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 提交事务 | |||
|  | 	if err := tx.Commit().Error; err != nil { | |||
|  | 		tm.logger.Error("事务提交失败", zap.Error(err)) | |||
|  | 		return errors.Join(ErrTransactionCommit, err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // ExecuteInTxWithTimeout 在事务中执行函数(带超时) | |||
|  | func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error { | |||
|  | 	ctx, cancel := context.WithTimeout(ctx, timeout) | |||
|  | 	defer cancel() | |||
|  | 
 | |||
|  | 	return tm.ExecuteInTx(ctx, fn) | |||
|  | } | |||
|  | 
 | |||
|  | // BeginTx 开始事务(手动管理) | |||
|  | func (tm *TransactionManager) BeginTx() *gorm.DB { | |||
|  | 	return tm.db.Begin() | |||
|  | } | |||
|  | 
 | |||
|  | // TxWrapper 事务包装器(手动管理) | |||
|  | type TxWrapper struct { | |||
|  | 	tx *gorm.DB | |||
|  | } | |||
|  | 
 | |||
|  | // NewTxWrapper 创建事务包装器 | |||
|  | func (tm *TransactionManager) NewTxWrapper() *TxWrapper { | |||
|  | 	return &TxWrapper{ | |||
|  | 		tx: tm.BeginTx(), | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // Commit 提交事务 | |||
|  | func (tx *TxWrapper) Commit() error { | |||
|  | 	return tx.tx.Commit().Error | |||
|  | } | |||
|  | 
 | |||
|  | // Rollback 回滚事务 | |||
|  | func (tx *TxWrapper) Rollback() error { | |||
|  | 	return tx.tx.Rollback().Error | |||
|  | } | |||
|  | 
 | |||
|  | // GetDB 获取事务数据库实例 | |||
|  | func (tx *TxWrapper) GetDB() *gorm.DB { | |||
|  | 	return tx.tx | |||
|  | } | |||
|  | 
 | |||
|  | // WithTx 在事务中执行函数(兼容旧接口) | |||
|  | func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error { | |||
|  | 	tx := tm.BeginTx() | |||
|  | 	defer func() { | |||
|  | 		if r := recover(); r != nil { | |||
|  | 			tx.Rollback() | |||
|  | 			panic(r) | |||
|  | 		} | |||
|  | 	}() | |||
|  | 
 | |||
|  | 	if err := fn(tx); err != nil { | |||
|  | 		tx.Rollback() | |||
|  | 		return err | |||
|  | 	} | |||
|  | 
 | |||
|  | 	return tx.Commit().Error | |||
|  | } | |||
|  | 
 | |||
|  | // TransactionOptions 事务选项 | |||
|  | type TransactionOptions struct { | |||
|  | 	Timeout  time.Duration | |||
|  | 	ReadOnly bool // 是否只读事务 | |||
|  | } | |||
|  | 
 | |||
|  | // ExecuteInTxWithOptions 在事务中执行函数(带选项) | |||
|  | func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error { | |||
|  | 	// 设置事务选项 | |||
|  | 	tx := tm.db.Begin() | |||
|  | 	if tx.Error != nil { | |||
|  | 		return tx.Error | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 设置只读事务 | |||
|  | 	if options != nil && options.ReadOnly { | |||
|  | 		tx = tx.Session(&gorm.Session{}) | |||
|  | 		// 注意:GORM的只读事务需要数据库支持,这里只是标记 | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 创建带事务的context | |||
|  | 	txCtx := WithTx(ctx, tx) | |||
|  | 
 | |||
|  | 	// 设置超时 | |||
|  | 	if options != nil && options.Timeout > 0 { | |||
|  | 		var cancel context.CancelFunc | |||
|  | 		txCtx, cancel = context.WithTimeout(txCtx, options.Timeout) | |||
|  | 		defer cancel() | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 执行函数 | |||
|  | 	if err := fn(txCtx); err != nil { | |||
|  | 		// 回滚事务 | |||
|  | 		if rbErr := tx.Rollback().Error; rbErr != nil { | |||
|  | 			return err | |||
|  | 		} | |||
|  | 		return err | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 提交事务 | |||
|  | 	return tx.Commit().Error | |||
|  | } | |||
|  | 
 | |||
|  | // TransactionStats 事务统计信息 | |||
|  | type TransactionStats struct { | |||
|  | 	TotalTransactions      int64 | |||
|  | 	SuccessfulTransactions int64 | |||
|  | 	FailedTransactions     int64 | |||
|  | 	AverageDuration        time.Duration | |||
|  | } | |||
|  | 
 | |||
|  | // GetStats 获取事务统计信息(预留接口) | |||
|  | func (tm *TransactionManager) GetStats() *TransactionStats { | |||
|  | 	// TODO: 实现事务统计 | |||
|  | 	return &TransactionStats{} | |||
|  | } | |||
|  | 
 | |||
|  | // RetryableTransactionOptions 可重试事务选项 | |||
|  | type RetryableTransactionOptions struct { | |||
|  | 	MaxRetries   int           // 最大重试次数 | |||
|  | 	RetryDelay   time.Duration // 重试延迟 | |||
|  | 	RetryBackoff float64       // 退避倍数 | |||
|  | } | |||
|  | 
 | |||
|  | // DefaultRetryableOptions 默认重试选项 | |||
|  | func DefaultRetryableOptions() *RetryableTransactionOptions { | |||
|  | 	return &RetryableTransactionOptions{ | |||
|  | 		MaxRetries:   3, | |||
|  | 		RetryDelay:   100 * time.Millisecond, | |||
|  | 		RetryBackoff: 2.0, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // ExecuteInTxWithRetry 在事务中执行函数(支持重试) | |||
|  | // 适用于处理死锁等临时性错误 | |||
|  | func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error { | |||
|  | 	if options == nil { | |||
|  | 		options = DefaultRetryableOptions() | |||
|  | 	} | |||
|  | 
 | |||
|  | 	var lastErr error | |||
|  | 	delay := options.RetryDelay | |||
|  | 
 | |||
|  | 	for attempt := 0; attempt <= options.MaxRetries; attempt++ { | |||
|  | 		// 检查上下文是否已取消 | |||
|  | 		if ctx.Err() != nil { | |||
|  | 			return ctx.Err() | |||
|  | 		} | |||
|  | 
 | |||
|  | 		err := tm.ExecuteInTx(ctx, fn) | |||
|  | 		if err == nil { | |||
|  | 			return nil | |||
|  | 		} | |||
|  | 
 | |||
|  | 		// 检查是否是可重试的错误(死锁、连接错误等) | |||
|  | 		if !isRetryableError(err) { | |||
|  | 			return err | |||
|  | 		} | |||
|  | 
 | |||
|  | 		lastErr = err | |||
|  | 
 | |||
|  | 		// 如果不是最后一次尝试,等待后重试 | |||
|  | 		if attempt < options.MaxRetries { | |||
|  | 			tm.logger.Warn("事务执行失败,准备重试", | |||
|  | 				zap.Int("attempt", attempt+1), | |||
|  | 				zap.Int("max_retries", options.MaxRetries), | |||
|  | 				zap.Duration("delay", delay), | |||
|  | 				zap.Error(err), | |||
|  | 			) | |||
|  | 
 | |||
|  | 			select { | |||
|  | 			case <-time.After(delay): | |||
|  | 				delay = time.Duration(float64(delay) * options.RetryBackoff) | |||
|  | 			case <-ctx.Done(): | |||
|  | 				return ctx.Err() | |||
|  | 			} | |||
|  | 		} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Error("事务执行失败,已超过最大重试次数", | |||
|  | 		zap.Int("max_retries", options.MaxRetries), | |||
|  | 		zap.Error(lastErr), | |||
|  | 	) | |||
|  | 
 | |||
|  | 	return lastErr | |||
|  | } | |||
|  | 
 | |||
|  | // isRetryableError 判断是否是可重试的错误 | |||
|  | func isRetryableError(err error) bool { | |||
|  | 	if err == nil { | |||
|  | 		return false | |||
|  | 	} | |||
|  | 
 | |||
|  | 	errStr := err.Error() | |||
|  | 
 | |||
|  | 	// MySQL 死锁错误 | |||
|  | 	if contains(errStr, "Deadlock found") { | |||
|  | 		return true | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// MySQL 锁等待超时 | |||
|  | 	if contains(errStr, "Lock wait timeout exceeded") { | |||
|  | 		return true | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 连接错误 | |||
|  | 	if contains(errStr, "connection") { | |||
|  | 		return true | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 可以根据需要添加更多的可重试错误类型 | |||
|  | 	return false | |||
|  | } | |||
|  | 
 | |||
|  | // contains 检查字符串是否包含子字符串(不区分大小写) | |||
|  | func contains(s, substr string) bool { | |||
|  | 	return strings.Contains(strings.ToLower(s), strings.ToLower(substr)) | |||
|  | } |