package model

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

var _ WalletsModel = (*customWalletsModel)(nil)

type (
	// WalletsModel is an interface to be customized, add more methods here,
	// and implement the added methods in customWalletsModel.
	WalletsModel interface {
		walletsModel
		// InsertWalletsTrans inserts wallets using the caller's transaction session.
		InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error)
		// UpdateBalance adds amount to the balance of userId's wallet inside
		// session, guarded by an optimistic version check.
		UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error
		// TransCtx runs fn inside a database transaction.
		TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
	}

	customWalletsModel struct {
		*defaultWalletsModel
	}
)

var (
	// ErrBalanceNotEnough ("insufficient balance") is returned by UpdateBalance
	// when applying amount would make the wallet balance negative.
	ErrBalanceNotEnough = errors.New("余额不足")
	// ErrVersionMismatch ("version mismatch, please retry") is returned by
	// UpdateBalance when the wallet row changed between the read and the
	// optimistic-lock UPDATE; callers may retry.
	ErrVersionMismatch = errors.New("版本号不匹配,请重试")
)

// maxWalletVersion is the largest version UpdateBalance writes before wrapping
// the counter back to 1. It may be raised up to the version column's maximum
// (e.g. UNSIGNED INT) — NOTE(review): confirm against the schema.
const maxWalletVersion = 100000000

// NewWalletsModel returns a model for the database table.
func NewWalletsModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WalletsModel {
	return &customWalletsModel{
		defaultWalletsModel: newWalletsModel(conn, c, opts...),
	}
}

// TransCtx runs fn inside a transaction bound to ctx by delegating to the
// embedded TransactCtx (go-zero commits when fn returns nil and rolls back
// otherwise).
func (m *customWalletsModel) TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error {
	return m.TransactCtx(ctx, fn)
}

// UpdateBalance adds amount (negative to debit) to the balance of the wallet
// owned by userId, executing through session so the change joins the caller's
// transaction.
//
// The wallet row is read and updated through the same session. The UPDATE only
// matches if the version is still the one that was read; otherwise no row is
// affected and ErrVersionMismatch is returned so the caller can retry. If the
// resulting balance would be negative, ErrBalanceNotEnough is returned and
// nothing is written. A missing wallet surfaces as sqlx.ErrNotFound.
//
// The parameter order (session before ctx) is kept for compatibility with
// existing callers, although Go convention puts ctx first.
func (m *customWalletsModel) UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error {
	// Read inside the transaction instead of via the cached FindOneByUserId: a
	// cached row can predate a concurrent commit, and its stale version would
	// make every retry fail with ErrVersionMismatch. Reading on the session also
	// sees earlier writes made by this same transaction.
	var wallet Wallets
	findQuery := fmt.Sprintf("SELECT %s FROM %s WHERE user_id = ? LIMIT 1", walletsRows, m.table)
	if err := session.QueryRowCtx(ctx, &wallet, findQuery, userId); err != nil {
		return err
	}

	// The version guard below ensures this check applies to the row actually
	// being updated.
	if wallet.Balance+amount < 0 {
		return ErrBalanceNotEnough
	}

	// Wrap the version counter so it never outgrows the column.
	newVersion := wallet.Version + 1
	if newVersion > maxWalletVersion {
		newVersion = 1
	}

	// Optimistic lock: only update if nobody changed the row since the read.
	updateQuery := fmt.Sprintf("UPDATE %s SET balance = balance + ?, version = ? WHERE user_id = ? AND version = ?", m.table)
	result, err := session.ExecCtx(ctx, updateQuery, amount, newVersion, userId, wallet.Version)
	if err != nil {
		return err
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected == 0 {
		return ErrVersionMismatch
	}

	// Drop the cached row so later FindOneByUserId calls reload it.
	// NOTE(review): this runs before the transaction commits, so a concurrent
	// reader can still re-cache the old row in between; deleting the key again
	// after commit (in the caller) closes that window. If the generated model
	// also caches by primary key, that key needs invalidating here too.
	walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, userId)
	if err := m.DelCacheCtx(ctx, walletsUserIdKey); err != nil {
		return err
	}
	return nil
}

// InsertWalletsTrans inserts wallets through session so the insert joins the
// caller's transaction, and returns the driver's result.
//
// It binds UserId, Balance and Version, in that order, to the columns named by
// walletsRowsExpectAutoSet — NOTE(review): assumes that list is exactly those
// three columns; confirm if the schema changes.
//
// After a successful insert the user_id cache entry is removed: a lookup made
// before the wallet existed may have cached a not-found placeholder that would
// otherwise keep hiding the new row until it expires.
func (m *customWalletsModel) InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error) {
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?, ?, ?)", m.table, walletsRowsExpectAutoSet)
	ret, err := session.ExecCtx(ctx, query, wallets.UserId, wallets.Balance, wallets.Version)
	if err != nil {
		return nil, err
	}

	walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, wallets.UserId)
	if err := m.DelCacheCtx(ctx, walletsUserIdKey); err != nil {
		return nil, err
	}
	return ret, nil
}