Files
tyapi-server/internal/shared/database/transaction.go

302 lines
6.9 KiB
Go
Raw Normal View History

2025-07-20 20:53:26 +08:00
package database
import (
"context"
"errors"
"strings"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
)
// 自定义错误类型
var (
ErrTransactionRollback = errors.New("事务回滚失败")
ErrTransactionCommit = errors.New("事务提交失败")
)
// 定义context key
type txKey struct{}
// WithTx 将事务对象存储到context中
func WithTx(ctx context.Context, tx *gorm.DB) context.Context {
return context.WithValue(ctx, txKey{}, tx)
}
// GetTx 从context中获取事务对象
func GetTx(ctx context.Context) (*gorm.DB, bool) {
tx, ok := ctx.Value(txKey{}).(*gorm.DB)
return tx, ok
}
// TransactionManager 事务管理器
type TransactionManager struct {
db *gorm.DB
logger *zap.Logger
}
// NewTransactionManager 创建事务管理器
func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager {
return &TransactionManager{
db: db,
logger: logger,
}
}
// ExecuteInTx 在事务中执行函数(推荐使用)
// 自动处理事务的开启、提交和回滚
func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error {
// 检查是否已经在事务中
if _, ok := GetTx(ctx); ok {
// 如果已经在事务中,直接执行函数,避免嵌套事务
return fn(ctx)
}
tx := tm.db.Begin()
if tx.Error != nil {
return tx.Error
}
// 创建带事务的context
txCtx := WithTx(ctx, tx)
// 执行函数
if err := fn(txCtx); err != nil {
// 回滚事务
if rbErr := tx.Rollback().Error; rbErr != nil {
tm.logger.Error("事务回滚失败",
zap.Error(err),
zap.Error(rbErr),
)
return errors.Join(err, ErrTransactionRollback, rbErr)
}
return err
}
// 提交事务
if err := tx.Commit().Error; err != nil {
tm.logger.Error("事务提交失败", zap.Error(err))
return errors.Join(ErrTransactionCommit, err)
}
return nil
}
// ExecuteInTxWithTimeout 在事务中执行函数(带超时)
func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return tm.ExecuteInTx(ctx, fn)
}
// BeginTx 开始事务(手动管理)
func (tm *TransactionManager) BeginTx() *gorm.DB {
return tm.db.Begin()
}
// TxWrapper 事务包装器(手动管理)
type TxWrapper struct {
tx *gorm.DB
}
// NewTxWrapper 创建事务包装器
func (tm *TransactionManager) NewTxWrapper() *TxWrapper {
return &TxWrapper{
tx: tm.BeginTx(),
}
}
// Commit 提交事务
func (tx *TxWrapper) Commit() error {
return tx.tx.Commit().Error
}
// Rollback 回滚事务
func (tx *TxWrapper) Rollback() error {
return tx.tx.Rollback().Error
}
// GetDB 获取事务数据库实例
func (tx *TxWrapper) GetDB() *gorm.DB {
return tx.tx
}
// WithTx 在事务中执行函数(兼容旧接口)
func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error {
tx := tm.BeginTx()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// TransactionOptions 事务选项
type TransactionOptions struct {
Timeout time.Duration
ReadOnly bool // 是否只读事务
}
// ExecuteInTxWithOptions 在事务中执行函数(带选项)
func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error {
// 设置事务选项
tx := tm.db.Begin()
if tx.Error != nil {
return tx.Error
}
// 设置只读事务
if options != nil && options.ReadOnly {
tx = tx.Session(&gorm.Session{})
// 注意GORM的只读事务需要数据库支持这里只是标记
}
// 创建带事务的context
txCtx := WithTx(ctx, tx)
// 设置超时
if options != nil && options.Timeout > 0 {
var cancel context.CancelFunc
txCtx, cancel = context.WithTimeout(txCtx, options.Timeout)
defer cancel()
}
// 执行函数
if err := fn(txCtx); err != nil {
// 回滚事务
if rbErr := tx.Rollback().Error; rbErr != nil {
return err
}
return err
}
// 提交事务
return tx.Commit().Error
}
// TransactionStats 事务统计信息
type TransactionStats struct {
TotalTransactions int64
SuccessfulTransactions int64
FailedTransactions int64
AverageDuration time.Duration
}
// GetStats 获取事务统计信息(预留接口)
func (tm *TransactionManager) GetStats() *TransactionStats {
// TODO: 实现事务统计
return &TransactionStats{}
}
// RetryableTransactionOptions 可重试事务选项
type RetryableTransactionOptions struct {
MaxRetries int // 最大重试次数
RetryDelay time.Duration // 重试延迟
RetryBackoff float64 // 退避倍数
}
// DefaultRetryableOptions 默认重试选项
func DefaultRetryableOptions() *RetryableTransactionOptions {
return &RetryableTransactionOptions{
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
RetryBackoff: 2.0,
}
}
// ExecuteInTxWithRetry 在事务中执行函数(支持重试)
// 适用于处理死锁等临时性错误
func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error {
if options == nil {
options = DefaultRetryableOptions()
}
var lastErr error
delay := options.RetryDelay
for attempt := 0; attempt <= options.MaxRetries; attempt++ {
// 检查上下文是否已取消
if ctx.Err() != nil {
return ctx.Err()
}
err := tm.ExecuteInTx(ctx, fn)
if err == nil {
return nil
}
// 检查是否是可重试的错误(死锁、连接错误等)
if !isRetryableError(err) {
return err
}
lastErr = err
// 如果不是最后一次尝试,等待后重试
if attempt < options.MaxRetries {
tm.logger.Warn("事务执行失败,准备重试",
zap.Int("attempt", attempt+1),
zap.Int("max_retries", options.MaxRetries),
zap.Duration("delay", delay),
zap.Error(err),
)
select {
case <-time.After(delay):
delay = time.Duration(float64(delay) * options.RetryBackoff)
case <-ctx.Done():
return ctx.Err()
}
}
}
tm.logger.Error("事务执行失败,已超过最大重试次数",
zap.Int("max_retries", options.MaxRetries),
zap.Error(lastErr),
)
return lastErr
}
// isRetryableError 判断是否是可重试的错误
func isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// MySQL 死锁错误
if contains(errStr, "Deadlock found") {
return true
}
// MySQL 锁等待超时
if contains(errStr, "Lock wait timeout exceeded") {
return true
}
// 连接错误
if contains(errStr, "connection") {
return true
}
// 可以根据需要添加更多的可重试错误类型
return false
}
// contains 检查字符串是否包含子字符串(不区分大小写)
func contains(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}