302 lines
6.9 KiB
Go
302 lines
6.9 KiB
Go
|
|
package database
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"errors"
|
|||
|
|
"strings"
|
|||
|
|
"time"
|
|||
|
|
|
|||
|
|
"go.uber.org/zap"
|
|||
|
|
"gorm.io/gorm"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// 自定义错误类型
|
|||
|
|
var (
|
|||
|
|
ErrTransactionRollback = errors.New("事务回滚失败")
|
|||
|
|
ErrTransactionCommit = errors.New("事务提交失败")
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// 定义context key
|
|||
|
|
type txKey struct{}
|
|||
|
|
|
|||
|
|
// WithTx 将事务对象存储到context中
|
|||
|
|
func WithTx(ctx context.Context, tx *gorm.DB) context.Context {
|
|||
|
|
return context.WithValue(ctx, txKey{}, tx)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetTx 从context中获取事务对象
|
|||
|
|
func GetTx(ctx context.Context) (*gorm.DB, bool) {
|
|||
|
|
tx, ok := ctx.Value(txKey{}).(*gorm.DB)
|
|||
|
|
return tx, ok
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TransactionManager 事务管理器
|
|||
|
|
type TransactionManager struct {
|
|||
|
|
db *gorm.DB
|
|||
|
|
logger *zap.Logger
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewTransactionManager 创建事务管理器
|
|||
|
|
func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager {
|
|||
|
|
return &TransactionManager{
|
|||
|
|
db: db,
|
|||
|
|
logger: logger,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ExecuteInTx 在事务中执行函数(推荐使用)
|
|||
|
|
// 自动处理事务的开启、提交和回滚
|
|||
|
|
func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error {
|
|||
|
|
// 检查是否已经在事务中
|
|||
|
|
if _, ok := GetTx(ctx); ok {
|
|||
|
|
// 如果已经在事务中,直接执行函数,避免嵌套事务
|
|||
|
|
return fn(ctx)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
tx := tm.db.Begin()
|
|||
|
|
if tx.Error != nil {
|
|||
|
|
return tx.Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 创建带事务的context
|
|||
|
|
txCtx := WithTx(ctx, tx)
|
|||
|
|
|
|||
|
|
// 执行函数
|
|||
|
|
if err := fn(txCtx); err != nil {
|
|||
|
|
// 回滚事务
|
|||
|
|
if rbErr := tx.Rollback().Error; rbErr != nil {
|
|||
|
|
tm.logger.Error("事务回滚失败",
|
|||
|
|
zap.Error(err),
|
|||
|
|
zap.Error(rbErr),
|
|||
|
|
)
|
|||
|
|
return errors.Join(err, ErrTransactionRollback, rbErr)
|
|||
|
|
}
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 提交事务
|
|||
|
|
if err := tx.Commit().Error; err != nil {
|
|||
|
|
tm.logger.Error("事务提交失败", zap.Error(err))
|
|||
|
|
return errors.Join(ErrTransactionCommit, err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ExecuteInTxWithTimeout 在事务中执行函数(带超时)
|
|||
|
|
func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error {
|
|||
|
|
ctx, cancel := context.WithTimeout(ctx, timeout)
|
|||
|
|
defer cancel()
|
|||
|
|
|
|||
|
|
return tm.ExecuteInTx(ctx, fn)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BeginTx 开始事务(手动管理)
|
|||
|
|
func (tm *TransactionManager) BeginTx() *gorm.DB {
|
|||
|
|
return tm.db.Begin()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TxWrapper 事务包装器(手动管理)
|
|||
|
|
type TxWrapper struct {
|
|||
|
|
tx *gorm.DB
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewTxWrapper 创建事务包装器
|
|||
|
|
func (tm *TransactionManager) NewTxWrapper() *TxWrapper {
|
|||
|
|
return &TxWrapper{
|
|||
|
|
tx: tm.BeginTx(),
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Commit 提交事务
|
|||
|
|
func (tx *TxWrapper) Commit() error {
|
|||
|
|
return tx.tx.Commit().Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Rollback 回滚事务
|
|||
|
|
func (tx *TxWrapper) Rollback() error {
|
|||
|
|
return tx.tx.Rollback().Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetDB 获取事务数据库实例
|
|||
|
|
func (tx *TxWrapper) GetDB() *gorm.DB {
|
|||
|
|
return tx.tx
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// WithTx 在事务中执行函数(兼容旧接口)
|
|||
|
|
func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error {
|
|||
|
|
tx := tm.BeginTx()
|
|||
|
|
defer func() {
|
|||
|
|
if r := recover(); r != nil {
|
|||
|
|
tx.Rollback()
|
|||
|
|
panic(r)
|
|||
|
|
}
|
|||
|
|
}()
|
|||
|
|
|
|||
|
|
if err := fn(tx); err != nil {
|
|||
|
|
tx.Rollback()
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return tx.Commit().Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TransactionOptions 事务选项
|
|||
|
|
type TransactionOptions struct {
|
|||
|
|
Timeout time.Duration
|
|||
|
|
ReadOnly bool // 是否只读事务
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ExecuteInTxWithOptions 在事务中执行函数(带选项)
|
|||
|
|
func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error {
|
|||
|
|
// 设置事务选项
|
|||
|
|
tx := tm.db.Begin()
|
|||
|
|
if tx.Error != nil {
|
|||
|
|
return tx.Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 设置只读事务
|
|||
|
|
if options != nil && options.ReadOnly {
|
|||
|
|
tx = tx.Session(&gorm.Session{})
|
|||
|
|
// 注意:GORM的只读事务需要数据库支持,这里只是标记
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 创建带事务的context
|
|||
|
|
txCtx := WithTx(ctx, tx)
|
|||
|
|
|
|||
|
|
// 设置超时
|
|||
|
|
if options != nil && options.Timeout > 0 {
|
|||
|
|
var cancel context.CancelFunc
|
|||
|
|
txCtx, cancel = context.WithTimeout(txCtx, options.Timeout)
|
|||
|
|
defer cancel()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 执行函数
|
|||
|
|
if err := fn(txCtx); err != nil {
|
|||
|
|
// 回滚事务
|
|||
|
|
if rbErr := tx.Rollback().Error; rbErr != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 提交事务
|
|||
|
|
return tx.Commit().Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// TransactionStats 事务统计信息
|
|||
|
|
type TransactionStats struct {
|
|||
|
|
TotalTransactions int64
|
|||
|
|
SuccessfulTransactions int64
|
|||
|
|
FailedTransactions int64
|
|||
|
|
AverageDuration time.Duration
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetStats 获取事务统计信息(预留接口)
|
|||
|
|
func (tm *TransactionManager) GetStats() *TransactionStats {
|
|||
|
|
// TODO: 实现事务统计
|
|||
|
|
return &TransactionStats{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// RetryableTransactionOptions 可重试事务选项
|
|||
|
|
type RetryableTransactionOptions struct {
|
|||
|
|
MaxRetries int // 最大重试次数
|
|||
|
|
RetryDelay time.Duration // 重试延迟
|
|||
|
|
RetryBackoff float64 // 退避倍数
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DefaultRetryableOptions 默认重试选项
|
|||
|
|
func DefaultRetryableOptions() *RetryableTransactionOptions {
|
|||
|
|
return &RetryableTransactionOptions{
|
|||
|
|
MaxRetries: 3,
|
|||
|
|
RetryDelay: 100 * time.Millisecond,
|
|||
|
|
RetryBackoff: 2.0,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ExecuteInTxWithRetry 在事务中执行函数(支持重试)
|
|||
|
|
// 适用于处理死锁等临时性错误
|
|||
|
|
func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error {
|
|||
|
|
if options == nil {
|
|||
|
|
options = DefaultRetryableOptions()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
var lastErr error
|
|||
|
|
delay := options.RetryDelay
|
|||
|
|
|
|||
|
|
for attempt := 0; attempt <= options.MaxRetries; attempt++ {
|
|||
|
|
// 检查上下文是否已取消
|
|||
|
|
if ctx.Err() != nil {
|
|||
|
|
return ctx.Err()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
err := tm.ExecuteInTx(ctx, fn)
|
|||
|
|
if err == nil {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 检查是否是可重试的错误(死锁、连接错误等)
|
|||
|
|
if !isRetryableError(err) {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
lastErr = err
|
|||
|
|
|
|||
|
|
// 如果不是最后一次尝试,等待后重试
|
|||
|
|
if attempt < options.MaxRetries {
|
|||
|
|
tm.logger.Warn("事务执行失败,准备重试",
|
|||
|
|
zap.Int("attempt", attempt+1),
|
|||
|
|
zap.Int("max_retries", options.MaxRetries),
|
|||
|
|
zap.Duration("delay", delay),
|
|||
|
|
zap.Error(err),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
select {
|
|||
|
|
case <-time.After(delay):
|
|||
|
|
delay = time.Duration(float64(delay) * options.RetryBackoff)
|
|||
|
|
case <-ctx.Done():
|
|||
|
|
return ctx.Err()
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
tm.logger.Error("事务执行失败,已超过最大重试次数",
|
|||
|
|
zap.Int("max_retries", options.MaxRetries),
|
|||
|
|
zap.Error(lastErr),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
return lastErr
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// isRetryableError 判断是否是可重试的错误
|
|||
|
|
func isRetryableError(err error) bool {
|
|||
|
|
if err == nil {
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
errStr := err.Error()
|
|||
|
|
|
|||
|
|
// MySQL 死锁错误
|
|||
|
|
if contains(errStr, "Deadlock found") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// MySQL 锁等待超时
|
|||
|
|
if contains(errStr, "Lock wait timeout exceeded") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 连接错误
|
|||
|
|
if contains(errStr, "connection") {
|
|||
|
|
return true
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 可以根据需要添加更多的可重试错误类型
|
|||
|
|
return false
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// contains 检查字符串是否包含子字符串(不区分大小写)
|
|||
|
|
func contains(s, substr string) bool {
|
|||
|
|
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
|
|||
|
|
}
|