tianyuan-api-server/apps/user/internal/model/walletsmodel.go
2024-10-12 20:41:55 +08:00

90 lines
2.7 KiB
Go

package model
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/sqlx"
)
// Compile-time assertion that *customWalletsModel implements WalletsModel.
var _ WalletsModel = (*customWalletsModel)(nil)

// WalletsModel is the wallets data-access interface. It embeds walletsModel
// (presumably the goctl-generated CRUD methods) and adds hand-written helpers.
// To customize it, add methods here and implement them on customWalletsModel.
type WalletsModel interface {
	walletsModel

	// InsertWalletsTrans inserts wallets using the supplied session.
	InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error)

	// UpdateBalance adds amount to the balance of userId's wallet using the
	// supplied session.
	UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error

	// TransCtx runs fn inside a transaction bound to ctx.
	TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
}

// customWalletsModel implements WalletsModel by embedding defaultWalletsModel
// and layering the custom methods on top of it.
type customWalletsModel struct {
	*defaultWalletsModel
}
// Sentinel errors returned by customWalletsModel.UpdateBalance; compare them
// with errors.Is.
var (
	// ErrBalanceNotEnough reports that applying the requested amount would make
	// the wallet balance negative (message: "insufficient balance").
	ErrBalanceNotEnough = errors.New("余额不足")

	// ErrVersionMismatch reports that the optimistic-lock UPDATE matched no row,
	// typically because the wallet was modified after it was read
	// (message: "version mismatch, please retry").
	ErrVersionMismatch = errors.New("版本号不匹配,请重试")
)
// NewWalletsModel returns a WalletsModel that wraps the default model built by
// newWalletsModel from conn, the cache configuration c and any cache options.
func NewWalletsModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WalletsModel {
	base := newWalletsModel(conn, c, opts...)
	return &customWalletsModel{defaultWalletsModel: base}
}
// TransCtx runs fn inside a database transaction bound to ctx by delegating to
// the embedded TransactCtx. fn receives the transaction's context and session.
// Per go-zero's TransactCtx, the transaction commits when fn returns nil and
// rolls back otherwise; the resulting error is returned unchanged.
func (m *customWalletsModel) TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error {
	// fn already has the exact callback type TransactCtx expects, so it is
	// handed over directly rather than wrapped in a forwarding closure.
	return m.TransactCtx(ctx, fn)
}
// UpdateBalance adds amount (negative to debit) to the balance of the wallet
// owned by userId, executing the write on session with optimistic locking on
// the version column.
//
// The wallet is first loaded with FindOneByUserId. The UPDATE then succeeds
// only if the row still carries the version that was read, and it bumps that
// version. Afterwards the user_id cache key is deleted.
//
// Errors:
//   - ErrBalanceNotEnough if the loaded balance plus amount is negative;
//   - ErrVersionMismatch if the UPDATE matched no row (the wallet changed
//     since it was read); callers may retry;
//   - otherwise any error from the lookup, the UPDATE, RowsAffected or the
//     cache deletion.
//
// NOTE(review): session precedes ctx, unlike Go convention; the order is kept
// for interface compatibility.
// NOTE(review): FindOneByUserId reads through the model, presumably via the
// cache and outside session. A stale cached version therefore yields
// ErrVersionMismatch until the entry is refreshed. The cache key is also
// deleted before the surrounding transaction commits, so a concurrent reader
// could re-cache the pre-commit row. Confirm whether reading via session
// (e.g. SELECT ... FOR UPDATE) is wanted.
// NOTE(review): balances are float64; binary floating point can accumulate
// rounding error for money — consider integer minor units or DECIMAL.
func (m *customWalletsModel) UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error {
	wallet, err := m.FindOneByUserId(ctx, userId)
	if err != nil {
		return err
	}

	// Reject an overdraft up front. The version predicate in the UPDATE below
	// ensures the balance checked here is still the row's balance when the
	// write lands (given every balance change also bumps version).
	if wallet.Balance+amount < 0 {
		return ErrBalanceNotEnough
	}

	// Optimistic lock. Use ExecCtx so the statement honours ctx cancellation and
	// deadlines, and m.table so it targets the same table as the rest of the
	// model instead of a hard-coded name.
	query := fmt.Sprintf("UPDATE %s SET balance = balance + ?, version = version + 1 WHERE user_id = ? AND version = ?", m.table)
	result, err := session.ExecCtx(ctx, query, amount, userId, wallet.Version)
	if err != nil {
		return err
	}

	// Zero affected rows means the version (or the row) changed since the read.
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected == 0 {
		return ErrVersionMismatch
	}

	// Drop the cached row so the next FindOneByUserId reloads the new values.
	walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, userId)
	return m.DelCacheCtx(ctx, walletsUserIdKey)
}
// InsertWalletsTrans inserts wallets using session (e.g. the session handed
// to a TransCtx callback) and returns the driver result.
//
// The columns written are walletsRowsExpectAutoSet, bound in the order
// UserId, Balance, Version.
// NOTE(review): the three "?" placeholders assume walletsRowsExpectAutoSet
// lists exactly these three columns in this order — confirm against the
// generated model.
//
// The write bypasses the model's cache-aware insert path. go-zero's cached
// lookups can store a "not found" placeholder for a missing key, so the
// user_id cache key is deleted explicitly after the insert. Otherwise an
// earlier FindOneByUserId miss could hide the new wallet until the
// placeholder expires. A cache-deletion failure is returned as an error,
// matching UpdateBalance.
func (m *customWalletsModel) InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error) {
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?, ?, ?)", m.table, walletsRowsExpectAutoSet)
	ret, err := session.ExecCtx(ctx, query, wallets.UserId, wallets.Balance, wallets.Version)
	if err != nil {
		return nil, err
	}

	walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, wallets.UserId)
	if err := m.DelCacheCtx(ctx, walletsUserIdKey); err != nil {
		return nil, err
	}
	return ret, nil
}