302 lines
6.9 KiB
Go
302 lines
6.9 KiB
Go
package database
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"strings"
|
||
"time"
|
||
|
||
"go.uber.org/zap"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// 自定义错误类型
|
||
var (
|
||
ErrTransactionRollback = errors.New("事务回滚失败")
|
||
ErrTransactionCommit = errors.New("事务提交失败")
|
||
)
|
||
|
||
// 定义context key
|
||
type txKey struct{}
|
||
|
||
// WithTx 将事务对象存储到context中
|
||
func WithTx(ctx context.Context, tx *gorm.DB) context.Context {
|
||
return context.WithValue(ctx, txKey{}, tx)
|
||
}
|
||
|
||
// GetTx 从context中获取事务对象
|
||
func GetTx(ctx context.Context) (*gorm.DB, bool) {
|
||
tx, ok := ctx.Value(txKey{}).(*gorm.DB)
|
||
return tx, ok
|
||
}
|
||
|
||
// TransactionManager 事务管理器
|
||
type TransactionManager struct {
|
||
db *gorm.DB
|
||
logger *zap.Logger
|
||
}
|
||
|
||
// NewTransactionManager 创建事务管理器
|
||
func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager {
|
||
return &TransactionManager{
|
||
db: db,
|
||
logger: logger,
|
||
}
|
||
}
|
||
|
||
// ExecuteInTx 在事务中执行函数(推荐使用)
|
||
// 自动处理事务的开启、提交和回滚
|
||
func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error {
|
||
// 检查是否已经在事务中
|
||
if _, ok := GetTx(ctx); ok {
|
||
// 如果已经在事务中,直接执行函数,避免嵌套事务
|
||
return fn(ctx)
|
||
}
|
||
|
||
tx := tm.db.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
|
||
// 创建带事务的context
|
||
txCtx := WithTx(ctx, tx)
|
||
|
||
// 执行函数
|
||
if err := fn(txCtx); err != nil {
|
||
// 回滚事务
|
||
if rbErr := tx.Rollback().Error; rbErr != nil {
|
||
tm.logger.Error("事务回滚失败",
|
||
zap.Error(err),
|
||
zap.Error(rbErr),
|
||
)
|
||
return errors.Join(err, ErrTransactionRollback, rbErr)
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 提交事务
|
||
if err := tx.Commit().Error; err != nil {
|
||
tm.logger.Error("事务提交失败", zap.Error(err))
|
||
return errors.Join(ErrTransactionCommit, err)
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// ExecuteInTxWithTimeout 在事务中执行函数(带超时)
|
||
func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error {
|
||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||
defer cancel()
|
||
|
||
return tm.ExecuteInTx(ctx, fn)
|
||
}
|
||
|
||
// BeginTx 开始事务(手动管理)
|
||
func (tm *TransactionManager) BeginTx() *gorm.DB {
|
||
return tm.db.Begin()
|
||
}
|
||
|
||
// TxWrapper 事务包装器(手动管理)
|
||
type TxWrapper struct {
|
||
tx *gorm.DB
|
||
}
|
||
|
||
// NewTxWrapper 创建事务包装器
|
||
func (tm *TransactionManager) NewTxWrapper() *TxWrapper {
|
||
return &TxWrapper{
|
||
tx: tm.BeginTx(),
|
||
}
|
||
}
|
||
|
||
// Commit 提交事务
|
||
func (tx *TxWrapper) Commit() error {
|
||
return tx.tx.Commit().Error
|
||
}
|
||
|
||
// Rollback 回滚事务
|
||
func (tx *TxWrapper) Rollback() error {
|
||
return tx.tx.Rollback().Error
|
||
}
|
||
|
||
// GetDB 获取事务数据库实例
|
||
func (tx *TxWrapper) GetDB() *gorm.DB {
|
||
return tx.tx
|
||
}
|
||
|
||
// WithTx 在事务中执行函数(兼容旧接口)
|
||
func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error {
|
||
tx := tm.BeginTx()
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
tx.Rollback()
|
||
panic(r)
|
||
}
|
||
}()
|
||
|
||
if err := fn(tx); err != nil {
|
||
tx.Rollback()
|
||
return err
|
||
}
|
||
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
// TransactionOptions 事务选项
|
||
type TransactionOptions struct {
|
||
Timeout time.Duration
|
||
ReadOnly bool // 是否只读事务
|
||
}
|
||
|
||
// ExecuteInTxWithOptions 在事务中执行函数(带选项)
|
||
func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error {
|
||
// 设置事务选项
|
||
tx := tm.db.Begin()
|
||
if tx.Error != nil {
|
||
return tx.Error
|
||
}
|
||
|
||
// 设置只读事务
|
||
if options != nil && options.ReadOnly {
|
||
tx = tx.Session(&gorm.Session{})
|
||
// 注意:GORM的只读事务需要数据库支持,这里只是标记
|
||
}
|
||
|
||
// 创建带事务的context
|
||
txCtx := WithTx(ctx, tx)
|
||
|
||
// 设置超时
|
||
if options != nil && options.Timeout > 0 {
|
||
var cancel context.CancelFunc
|
||
txCtx, cancel = context.WithTimeout(txCtx, options.Timeout)
|
||
defer cancel()
|
||
}
|
||
|
||
// 执行函数
|
||
if err := fn(txCtx); err != nil {
|
||
// 回滚事务
|
||
if rbErr := tx.Rollback().Error; rbErr != nil {
|
||
return err
|
||
}
|
||
return err
|
||
}
|
||
|
||
// 提交事务
|
||
return tx.Commit().Error
|
||
}
|
||
|
||
// TransactionStats 事务统计信息
|
||
type TransactionStats struct {
|
||
TotalTransactions int64
|
||
SuccessfulTransactions int64
|
||
FailedTransactions int64
|
||
AverageDuration time.Duration
|
||
}
|
||
|
||
// GetStats 获取事务统计信息(预留接口)
|
||
func (tm *TransactionManager) GetStats() *TransactionStats {
|
||
// TODO: 实现事务统计
|
||
return &TransactionStats{}
|
||
}
|
||
|
||
// RetryableTransactionOptions 可重试事务选项
|
||
type RetryableTransactionOptions struct {
|
||
MaxRetries int // 最大重试次数
|
||
RetryDelay time.Duration // 重试延迟
|
||
RetryBackoff float64 // 退避倍数
|
||
}
|
||
|
||
// DefaultRetryableOptions 默认重试选项
|
||
func DefaultRetryableOptions() *RetryableTransactionOptions {
|
||
return &RetryableTransactionOptions{
|
||
MaxRetries: 3,
|
||
RetryDelay: 100 * time.Millisecond,
|
||
RetryBackoff: 2.0,
|
||
}
|
||
}
|
||
|
||
// ExecuteInTxWithRetry 在事务中执行函数(支持重试)
|
||
// 适用于处理死锁等临时性错误
|
||
func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error {
|
||
if options == nil {
|
||
options = DefaultRetryableOptions()
|
||
}
|
||
|
||
var lastErr error
|
||
delay := options.RetryDelay
|
||
|
||
for attempt := 0; attempt <= options.MaxRetries; attempt++ {
|
||
// 检查上下文是否已取消
|
||
if ctx.Err() != nil {
|
||
return ctx.Err()
|
||
}
|
||
|
||
err := tm.ExecuteInTx(ctx, fn)
|
||
if err == nil {
|
||
return nil
|
||
}
|
||
|
||
// 检查是否是可重试的错误(死锁、连接错误等)
|
||
if !isRetryableError(err) {
|
||
return err
|
||
}
|
||
|
||
lastErr = err
|
||
|
||
// 如果不是最后一次尝试,等待后重试
|
||
if attempt < options.MaxRetries {
|
||
tm.logger.Warn("事务执行失败,准备重试",
|
||
zap.Int("attempt", attempt+1),
|
||
zap.Int("max_retries", options.MaxRetries),
|
||
zap.Duration("delay", delay),
|
||
zap.Error(err),
|
||
)
|
||
|
||
select {
|
||
case <-time.After(delay):
|
||
delay = time.Duration(float64(delay) * options.RetryBackoff)
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
}
|
||
}
|
||
}
|
||
|
||
tm.logger.Error("事务执行失败,已超过最大重试次数",
|
||
zap.Int("max_retries", options.MaxRetries),
|
||
zap.Error(lastErr),
|
||
)
|
||
|
||
return lastErr
|
||
}
|
||
|
||
// isRetryableError 判断是否是可重试的错误
|
||
func isRetryableError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
|
||
errStr := err.Error()
|
||
|
||
// MySQL 死锁错误
|
||
if contains(errStr, "Deadlock found") {
|
||
return true
|
||
}
|
||
|
||
// MySQL 锁等待超时
|
||
if contains(errStr, "Lock wait timeout exceeded") {
|
||
return true
|
||
}
|
||
|
||
// 连接错误
|
||
if contains(errStr, "connection") {
|
||
return true
|
||
}
|
||
|
||
// 可以根据需要添加更多的可重试错误类型
|
||
return false
|
||
}
|
||
|
||
// contains 检查字符串是否包含子字符串(不区分大小写)
|
||
func contains(s, substr string) bool {
|
||
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
|
||
}
|