Files
tyapi-server/internal/shared/database/transaction.go
2025-07-20 20:53:26 +08:00

302 lines
6.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package database
import (
"context"
"errors"
"strings"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
)
// 自定义错误类型
var (
ErrTransactionRollback = errors.New("事务回滚失败")
ErrTransactionCommit = errors.New("事务提交失败")
)
// 定义context key
type txKey struct{}
// WithTx 将事务对象存储到context中
func WithTx(ctx context.Context, tx *gorm.DB) context.Context {
return context.WithValue(ctx, txKey{}, tx)
}
// GetTx 从context中获取事务对象
func GetTx(ctx context.Context) (*gorm.DB, bool) {
tx, ok := ctx.Value(txKey{}).(*gorm.DB)
return tx, ok
}
// TransactionManager 事务管理器
type TransactionManager struct {
db *gorm.DB
logger *zap.Logger
}
// NewTransactionManager 创建事务管理器
func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager {
return &TransactionManager{
db: db,
logger: logger,
}
}
// ExecuteInTx 在事务中执行函数(推荐使用)
// 自动处理事务的开启、提交和回滚
func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error {
// 检查是否已经在事务中
if _, ok := GetTx(ctx); ok {
// 如果已经在事务中,直接执行函数,避免嵌套事务
return fn(ctx)
}
tx := tm.db.Begin()
if tx.Error != nil {
return tx.Error
}
// 创建带事务的context
txCtx := WithTx(ctx, tx)
// 执行函数
if err := fn(txCtx); err != nil {
// 回滚事务
if rbErr := tx.Rollback().Error; rbErr != nil {
tm.logger.Error("事务回滚失败",
zap.Error(err),
zap.Error(rbErr),
)
return errors.Join(err, ErrTransactionRollback, rbErr)
}
return err
}
// 提交事务
if err := tx.Commit().Error; err != nil {
tm.logger.Error("事务提交失败", zap.Error(err))
return errors.Join(ErrTransactionCommit, err)
}
return nil
}
// ExecuteInTxWithTimeout 在事务中执行函数(带超时)
func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
return tm.ExecuteInTx(ctx, fn)
}
// BeginTx 开始事务(手动管理)
func (tm *TransactionManager) BeginTx() *gorm.DB {
return tm.db.Begin()
}
// TxWrapper 事务包装器(手动管理)
type TxWrapper struct {
tx *gorm.DB
}
// NewTxWrapper 创建事务包装器
func (tm *TransactionManager) NewTxWrapper() *TxWrapper {
return &TxWrapper{
tx: tm.BeginTx(),
}
}
// Commit 提交事务
func (tx *TxWrapper) Commit() error {
return tx.tx.Commit().Error
}
// Rollback 回滚事务
func (tx *TxWrapper) Rollback() error {
return tx.tx.Rollback().Error
}
// GetDB 获取事务数据库实例
func (tx *TxWrapper) GetDB() *gorm.DB {
return tx.tx
}
// WithTx 在事务中执行函数(兼容旧接口)
func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error {
tx := tm.BeginTx()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}
// TransactionOptions 事务选项
type TransactionOptions struct {
Timeout time.Duration
ReadOnly bool // 是否只读事务
}
// ExecuteInTxWithOptions 在事务中执行函数(带选项)
func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error {
// 设置事务选项
tx := tm.db.Begin()
if tx.Error != nil {
return tx.Error
}
// 设置只读事务
if options != nil && options.ReadOnly {
tx = tx.Session(&gorm.Session{})
// 注意GORM的只读事务需要数据库支持这里只是标记
}
// 创建带事务的context
txCtx := WithTx(ctx, tx)
// 设置超时
if options != nil && options.Timeout > 0 {
var cancel context.CancelFunc
txCtx, cancel = context.WithTimeout(txCtx, options.Timeout)
defer cancel()
}
// 执行函数
if err := fn(txCtx); err != nil {
// 回滚事务
if rbErr := tx.Rollback().Error; rbErr != nil {
return err
}
return err
}
// 提交事务
return tx.Commit().Error
}
// TransactionStats 事务统计信息
type TransactionStats struct {
TotalTransactions int64
SuccessfulTransactions int64
FailedTransactions int64
AverageDuration time.Duration
}
// GetStats 获取事务统计信息(预留接口)
func (tm *TransactionManager) GetStats() *TransactionStats {
// TODO: 实现事务统计
return &TransactionStats{}
}
// RetryableTransactionOptions 可重试事务选项
type RetryableTransactionOptions struct {
MaxRetries int // 最大重试次数
RetryDelay time.Duration // 重试延迟
RetryBackoff float64 // 退避倍数
}
// DefaultRetryableOptions 默认重试选项
func DefaultRetryableOptions() *RetryableTransactionOptions {
return &RetryableTransactionOptions{
MaxRetries: 3,
RetryDelay: 100 * time.Millisecond,
RetryBackoff: 2.0,
}
}
// ExecuteInTxWithRetry 在事务中执行函数(支持重试)
// 适用于处理死锁等临时性错误
func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error {
if options == nil {
options = DefaultRetryableOptions()
}
var lastErr error
delay := options.RetryDelay
for attempt := 0; attempt <= options.MaxRetries; attempt++ {
// 检查上下文是否已取消
if ctx.Err() != nil {
return ctx.Err()
}
err := tm.ExecuteInTx(ctx, fn)
if err == nil {
return nil
}
// 检查是否是可重试的错误(死锁、连接错误等)
if !isRetryableError(err) {
return err
}
lastErr = err
// 如果不是最后一次尝试,等待后重试
if attempt < options.MaxRetries {
tm.logger.Warn("事务执行失败,准备重试",
zap.Int("attempt", attempt+1),
zap.Int("max_retries", options.MaxRetries),
zap.Duration("delay", delay),
zap.Error(err),
)
select {
case <-time.After(delay):
delay = time.Duration(float64(delay) * options.RetryBackoff)
case <-ctx.Done():
return ctx.Err()
}
}
}
tm.logger.Error("事务执行失败,已超过最大重试次数",
zap.Int("max_retries", options.MaxRetries),
zap.Error(lastErr),
)
return lastErr
}
// isRetryableError 判断是否是可重试的错误
func isRetryableError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// MySQL 死锁错误
if contains(errStr, "Deadlock found") {
return true
}
// MySQL 锁等待超时
if contains(errStr, "Lock wait timeout exceeded") {
return true
}
// 连接错误
if contains(errStr, "connection") {
return true
}
// 可以根据需要添加更多的可重试错误类型
return false
}
// contains 检查字符串是否包含子字符串(不区分大小写)
func contains(s, substr string) bool {
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}