90 lines
2.7 KiB
Go
90 lines
2.7 KiB
Go
package model
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/zeromicro/go-zero/core/stores/cache"
|
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
|
)
|
|
|
|
var _ WalletsModel = (*customWalletsModel)(nil)
|
|
|
|
type (
|
|
// WalletsModel is an interface to be customized, add more methods here,
|
|
// and implement the added methods in customWalletsModel.
|
|
WalletsModel interface {
|
|
walletsModel
|
|
InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error)
|
|
UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error
|
|
TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
|
|
}
|
|
|
|
customWalletsModel struct {
|
|
*defaultWalletsModel
|
|
}
|
|
)
|
|
|
|
// Sentinel errors returned by UpdateBalance. They are returned unwrapped, so
// callers may compare them with == or errors.Is.
var (
	// ErrBalanceNotEnough ("insufficient balance") reports that applying the
	// change would make the wallet balance negative.
	ErrBalanceNotEnough = errors.New("余额不足")

	// ErrVersionMismatch ("version mismatch, please retry") reports that the
	// optimistic-locking UPDATE matched no row: the wallet's version changed
	// after it was read. The operation may be retried.
	ErrVersionMismatch = errors.New("版本号不匹配,请重试")
)
|
|
|
|
// NewWalletsModel returns a model for the database table.
|
|
func NewWalletsModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WalletsModel {
|
|
return &customWalletsModel{
|
|
defaultWalletsModel: newWalletsModel(conn, c, opts...),
|
|
}
|
|
}
|
|
func (m *customWalletsModel) TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error {
|
|
// 使用带 ctx 的事务处理
|
|
err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
|
|
return fn(ctx, session)
|
|
})
|
|
return err
|
|
}
|
|
|
|
// 更新余额的方法
|
|
func (m *customWalletsModel) UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error {
|
|
|
|
wallet, err := m.FindOneByUserId(ctx, userId)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查余额是否足够
|
|
if wallet.Balance+amount < 0 {
|
|
return ErrBalanceNotEnough
|
|
}
|
|
|
|
// 使用乐观锁更新余额
|
|
result, err := session.Exec("UPDATE wallets SET balance = balance + ?, version = version + 1 WHERE user_id = ? AND version = ?", amount, userId, wallet.Version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查影响的行数,确保更新成功
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rowsAffected == 0 {
|
|
return ErrVersionMismatch
|
|
}
|
|
|
|
walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, userId)
|
|
cacheErrors := m.DelCacheCtx(ctx, walletsUserIdKey)
|
|
if cacheErrors != nil {
|
|
return cacheErrors
|
|
}
|
|
return nil
|
|
}
|
|
func (m *customWalletsModel) InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error) {
|
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?, ?, ?)", m.table, walletsRowsExpectAutoSet)
|
|
ret, err := session.ExecCtx(ctx, query, wallets.UserId, wallets.Balance, wallets.Version)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, err
|
|
}
|