temp
This commit is contained in:
301
internal/shared/database/transaction.go
Normal file
301
internal/shared/database/transaction.go
Normal file
@@ -0,0 +1,301 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// 自定义错误类型
|
||||
var (
|
||||
ErrTransactionRollback = errors.New("事务回滚失败")
|
||||
ErrTransactionCommit = errors.New("事务提交失败")
|
||||
)
|
||||
|
||||
// 定义context key
|
||||
type txKey struct{}
|
||||
|
||||
// WithTx 将事务对象存储到context中
|
||||
func WithTx(ctx context.Context, tx *gorm.DB) context.Context {
|
||||
return context.WithValue(ctx, txKey{}, tx)
|
||||
}
|
||||
|
||||
// GetTx 从context中获取事务对象
|
||||
func GetTx(ctx context.Context) (*gorm.DB, bool) {
|
||||
tx, ok := ctx.Value(txKey{}).(*gorm.DB)
|
||||
return tx, ok
|
||||
}
|
||||
|
||||
// TransactionManager 事务管理器
|
||||
type TransactionManager struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTransactionManager 创建事务管理器
|
||||
func NewTransactionManager(db *gorm.DB, logger *zap.Logger) *TransactionManager {
|
||||
return &TransactionManager{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInTx 在事务中执行函数(推荐使用)
|
||||
// 自动处理事务的开启、提交和回滚
|
||||
func (tm *TransactionManager) ExecuteInTx(ctx context.Context, fn func(context.Context) error) error {
|
||||
// 检查是否已经在事务中
|
||||
if _, ok := GetTx(ctx); ok {
|
||||
// 如果已经在事务中,直接执行函数,避免嵌套事务
|
||||
return fn(ctx)
|
||||
}
|
||||
|
||||
tx := tm.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// 创建带事务的context
|
||||
txCtx := WithTx(ctx, tx)
|
||||
|
||||
// 执行函数
|
||||
if err := fn(txCtx); err != nil {
|
||||
// 回滚事务
|
||||
if rbErr := tx.Rollback().Error; rbErr != nil {
|
||||
tm.logger.Error("事务回滚失败",
|
||||
zap.Error(err),
|
||||
zap.Error(rbErr),
|
||||
)
|
||||
return errors.Join(err, ErrTransactionRollback, rbErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
tm.logger.Error("事务提交失败", zap.Error(err))
|
||||
return errors.Join(ErrTransactionCommit, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteInTxWithTimeout 在事务中执行函数(带超时)
|
||||
func (tm *TransactionManager) ExecuteInTxWithTimeout(ctx context.Context, timeout time.Duration, fn func(context.Context) error) error {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
return tm.ExecuteInTx(ctx, fn)
|
||||
}
|
||||
|
||||
// BeginTx 开始事务(手动管理)
|
||||
func (tm *TransactionManager) BeginTx() *gorm.DB {
|
||||
return tm.db.Begin()
|
||||
}
|
||||
|
||||
// TxWrapper 事务包装器(手动管理)
|
||||
type TxWrapper struct {
|
||||
tx *gorm.DB
|
||||
}
|
||||
|
||||
// NewTxWrapper 创建事务包装器
|
||||
func (tm *TransactionManager) NewTxWrapper() *TxWrapper {
|
||||
return &TxWrapper{
|
||||
tx: tm.BeginTx(),
|
||||
}
|
||||
}
|
||||
|
||||
// Commit 提交事务
|
||||
func (tx *TxWrapper) Commit() error {
|
||||
return tx.tx.Commit().Error
|
||||
}
|
||||
|
||||
// Rollback 回滚事务
|
||||
func (tx *TxWrapper) Rollback() error {
|
||||
return tx.tx.Rollback().Error
|
||||
}
|
||||
|
||||
// GetDB 获取事务数据库实例
|
||||
func (tx *TxWrapper) GetDB() *gorm.DB {
|
||||
return tx.tx
|
||||
}
|
||||
|
||||
// WithTx 在事务中执行函数(兼容旧接口)
|
||||
func (tm *TransactionManager) WithTx(fn func(*gorm.DB) error) error {
|
||||
tx := tm.BeginTx()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// TransactionOptions 事务选项
|
||||
type TransactionOptions struct {
|
||||
Timeout time.Duration
|
||||
ReadOnly bool // 是否只读事务
|
||||
}
|
||||
|
||||
// ExecuteInTxWithOptions 在事务中执行函数(带选项)
|
||||
func (tm *TransactionManager) ExecuteInTxWithOptions(ctx context.Context, options *TransactionOptions, fn func(context.Context) error) error {
|
||||
// 设置事务选项
|
||||
tx := tm.db.Begin()
|
||||
if tx.Error != nil {
|
||||
return tx.Error
|
||||
}
|
||||
|
||||
// 设置只读事务
|
||||
if options != nil && options.ReadOnly {
|
||||
tx = tx.Session(&gorm.Session{})
|
||||
// 注意:GORM的只读事务需要数据库支持,这里只是标记
|
||||
}
|
||||
|
||||
// 创建带事务的context
|
||||
txCtx := WithTx(ctx, tx)
|
||||
|
||||
// 设置超时
|
||||
if options != nil && options.Timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
txCtx, cancel = context.WithTimeout(txCtx, options.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
|
||||
// 执行函数
|
||||
if err := fn(txCtx); err != nil {
|
||||
// 回滚事务
|
||||
if rbErr := tx.Rollback().Error; rbErr != nil {
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
return tx.Commit().Error
|
||||
}
|
||||
|
||||
// TransactionStats 事务统计信息
|
||||
type TransactionStats struct {
|
||||
TotalTransactions int64
|
||||
SuccessfulTransactions int64
|
||||
FailedTransactions int64
|
||||
AverageDuration time.Duration
|
||||
}
|
||||
|
||||
// GetStats 获取事务统计信息(预留接口)
|
||||
func (tm *TransactionManager) GetStats() *TransactionStats {
|
||||
// TODO: 实现事务统计
|
||||
return &TransactionStats{}
|
||||
}
|
||||
|
||||
// RetryableTransactionOptions 可重试事务选项
|
||||
type RetryableTransactionOptions struct {
|
||||
MaxRetries int // 最大重试次数
|
||||
RetryDelay time.Duration // 重试延迟
|
||||
RetryBackoff float64 // 退避倍数
|
||||
}
|
||||
|
||||
// DefaultRetryableOptions 默认重试选项
|
||||
func DefaultRetryableOptions() *RetryableTransactionOptions {
|
||||
return &RetryableTransactionOptions{
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 100 * time.Millisecond,
|
||||
RetryBackoff: 2.0,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteInTxWithRetry 在事务中执行函数(支持重试)
|
||||
// 适用于处理死锁等临时性错误
|
||||
func (tm *TransactionManager) ExecuteInTxWithRetry(ctx context.Context, options *RetryableTransactionOptions, fn func(context.Context) error) error {
|
||||
if options == nil {
|
||||
options = DefaultRetryableOptions()
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
delay := options.RetryDelay
|
||||
|
||||
for attempt := 0; attempt <= options.MaxRetries; attempt++ {
|
||||
// 检查上下文是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
err := tm.ExecuteInTx(ctx, fn)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否是可重试的错误(死锁、连接错误等)
|
||||
if !isRetryableError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
|
||||
// 如果不是最后一次尝试,等待后重试
|
||||
if attempt < options.MaxRetries {
|
||||
tm.logger.Warn("事务执行失败,准备重试",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Int("max_retries", options.MaxRetries),
|
||||
zap.Duration("delay", delay),
|
||||
zap.Error(err),
|
||||
)
|
||||
|
||||
select {
|
||||
case <-time.After(delay):
|
||||
delay = time.Duration(float64(delay) * options.RetryBackoff)
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tm.logger.Error("事务执行失败,已超过最大重试次数",
|
||||
zap.Int("max_retries", options.MaxRetries),
|
||||
zap.Error(lastErr),
|
||||
)
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// isRetryableError 判断是否是可重试的错误
|
||||
func isRetryableError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
|
||||
// MySQL 死锁错误
|
||||
if contains(errStr, "Deadlock found") {
|
||||
return true
|
||||
}
|
||||
|
||||
// MySQL 锁等待超时
|
||||
if contains(errStr, "Lock wait timeout exceeded") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 连接错误
|
||||
if contains(errStr, "connection") {
|
||||
return true
|
||||
}
|
||||
|
||||
// 可以根据需要添加更多的可重试错误类型
|
||||
return false
|
||||
}
|
||||
|
||||
// contains 检查字符串是否包含子字符串(不区分大小写)
|
||||
func contains(s, substr string) bool {
|
||||
return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
|
||||
}
|
||||
Reference in New Issue
Block a user