package model

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"strings"

	"github.com/zeromicro/go-zero/core/stores/cache"
	"github.com/zeromicro/go-zero/core/stores/sqlx"
)

var _ WalletsModel = (*customWalletsModel)(nil)

type (
	// WalletsModel is an interface to be customized, add more methods here,
	// and implement the added methods in customWalletsModel.
	WalletsModel interface {
		walletsModel
		// InsertWalletsTrans inserts a wallet row using the caller's transaction session.
		InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error)
		// UpdateBalance adds amount to a user's balance inside session, guarded by an
		// optimistic-lock version check.
		UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error
		// TransCtx runs fn inside a database transaction.
		TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
		// FindWalletsByUserIds loads the wallets of the given users, keyed by user id.
		FindWalletsByUserIds(ctx context.Context, userIds []int64) (map[int64]*Wallets, error)
	}

	customWalletsModel struct {
		*defaultWalletsModel
	}
)

var (
	// ErrBalanceNotEnough is returned by UpdateBalance when a debit would make the
	// wallet balance negative.
	ErrBalanceNotEnough = errors.New("余额不足")
	// ErrVersionMismatch is returned by UpdateBalance when the optimistic-lock
	// update matched no row because the wallet changed concurrently; callers may retry.
	ErrVersionMismatch = errors.New("版本号不匹配,请重试")
)

// maxWalletVersion caps the optimistic-lock version counter; once the next
// version would exceed it, the counter wraps back to 1. Adjust to the column's
// integer range if needed.
const maxWalletVersion = 100000000

// NewWalletsModel returns a model for the database table.
func NewWalletsModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WalletsModel {
	return &customWalletsModel{
		defaultWalletsModel: newWalletsModel(conn, c, opts...),
	}
}

// TransCtx runs fn inside a transaction by delegating to the embedded
// TransactCtx (go-zero commits when fn returns nil and rolls back otherwise).
func (m *customWalletsModel) TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error {
	return m.TransactCtx(ctx, fn)
}

// UpdateBalance adds amount (negative to debit) to the balance of userId's wallet
// using optimistic locking inside the caller's transaction session.
//
// The current row is read through session, not through the cache, so the version
// compared in the UPDATE reflects the database rather than a possibly stale cache
// entry. Errors from that read (e.g. sqlx.ErrNotFound when the user has no wallet)
// are returned unchanged. ErrBalanceNotEnough is returned when a debit would take
// the balance below zero, and ErrVersionMismatch when another writer changed the
// row first. On success the user-id cache entry is invalidated.
func (m *customWalletsModel) UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error {
	var wallet Wallets
	selectQuery := fmt.Sprintf("SELECT %s FROM %s WHERE `user_id` = ? LIMIT 1", walletsRows, m.table)
	if err := session.QueryRowCtx(ctx, &wallet, selectQuery, userId); err != nil {
		return err
	}

	// Only debits are checked: a credit must succeed even if the balance is
	// already negative. The version check below guarantees the balance we
	// validated is the one being updated.
	if amount < 0 && wallet.Balance+amount < 0 {
		return ErrBalanceNotEnough
	}

	newVersion := wallet.Version + 1
	if newVersion > maxWalletVersion {
		newVersion = 1
	}

	// Optimistic lock: the row is updated only if nobody bumped the version
	// since we read it.
	updateQuery := fmt.Sprintf("UPDATE %s SET `balance` = `balance` + ?, `version` = ? WHERE `user_id` = ? AND `version` = ?", m.table)
	result, err := session.ExecCtx(ctx, updateQuery, amount, newVersion, userId, wallet.Version)
	if err != nil {
		return err
	}

	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected == 0 {
		return ErrVersionMismatch
	}

	// NOTE(review): this invalidation runs before the surrounding transaction
	// commits, so a concurrent reader can still re-cache the pre-commit row;
	// consider deleting the key again after commit at the call site. Any cache
	// entry keyed by the primary key (if FindOne caches one) is not invalidated
	// here — confirm against the generated model.
	walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, userId)
	return m.DelCacheCtx(ctx, walletsUserIdKey)
}

// InsertWalletsTrans inserts wallets inside the caller's transaction session and
// then invalidates the user-id cache entry, so a previously cached "not found"
// lookup for this user does not hide the new row.
//
// The UserId, Balance and Version fields are bound positionally to the columns in
// walletsRowsExpectAutoSet; that list is assumed to be exactly those three columns
// in that order (TODO: confirm against the generated model).
func (m *customWalletsModel) InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error) {
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?, ?, ?)", m.table, walletsRowsExpectAutoSet)
	ret, err := session.ExecCtx(ctx, query, wallets.UserId, wallets.Balance, wallets.Version)
	if err != nil {
		return nil, err
	}

	walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, wallets.UserId)
	if err := m.DelCacheCtx(ctx, walletsUserIdKey); err != nil {
		return nil, err
	}
	return ret, nil
}

// FindWalletsByUserIds loads the wallets of userIds with a single IN query,
// bypassing the cache, and returns them keyed by user id. Users without a wallet
// are absent from the map; an empty input returns an empty, non-nil map without
// touching the database.
func (m *customWalletsModel) FindWalletsByUserIds(ctx context.Context, userIds []int64) (map[int64]*Wallets, error) {
	if len(userIds) == 0 {
		return make(map[int64]*Wallets), nil
	}

	// One "?" per id, comma separated: "?,?,?".
	placeholders := strings.TrimSuffix(strings.Repeat("?,", len(userIds)), ",")
	args := make([]interface{}, len(userIds))
	for i, userId := range userIds {
		args[i] = userId
	}

	// Select every mapped column: go-zero's strict row mapping rejects result
	// sets with fewer columns than the destination struct has fields.
	query := fmt.Sprintf("SELECT %s FROM %s WHERE `user_id` IN (%s)", walletsRows, m.table, placeholders)

	var wallets []*Wallets
	// QueryRows (not QueryRow) is required to scan a multi-row result into a slice.
	if err := m.QueryRowsNoCacheCtx(ctx, &wallets, query, args...); err != nil {
		return nil, err
	}

	walletMap := make(map[int64]*Wallets, len(wallets))
	for _, wallet := range wallets {
		walletMap[wallet.UserId] = wallet
	}
	return walletMap, nil
}