133 lines
4.1 KiB
Go
133 lines
4.1 KiB
Go
package model
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"github.com/zeromicro/go-zero/core/stores/cache"
|
|
"github.com/zeromicro/go-zero/core/stores/sqlx"
|
|
"strings"
|
|
)
|
|
|
|
// Compile-time assertion that *customWalletsModel implements WalletsModel.
var _ WalletsModel = (*customWalletsModel)(nil)
|
|
|
|
type (
|
|
// WalletsModel is an interface to be customized, add more methods here,
|
|
// and implement the added methods in customWalletsModel.
|
|
WalletsModel interface {
|
|
walletsModel
|
|
InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error)
|
|
UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error
|
|
TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
|
|
FindWalletsByUserIds(ctx context.Context, userIds []int64) (map[int64]*Wallets, error)
|
|
}
|
|
|
|
customWalletsModel struct {
|
|
*defaultWalletsModel
|
|
}
|
|
)
|
|
|
|
// Sentinel errors for wallet operations.
var (
	// ErrBalanceNotEnough reports that a wallet balance is insufficient.
	ErrBalanceNotEnough = errors.New("余额不足")
	// ErrVersionMismatch reports that the stored version no longer matches the
	// version that was read; the caller is expected to retry.
	ErrVersionMismatch = errors.New("版本号不匹配,请重试")
)
|
|
|
|
// NewWalletsModel returns a model for the database table.
|
|
func NewWalletsModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WalletsModel {
|
|
return &customWalletsModel{
|
|
defaultWalletsModel: newWalletsModel(conn, c, opts...),
|
|
}
|
|
}
|
|
func (m *customWalletsModel) TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error {
|
|
// 使用带 ctx 的事务处理
|
|
err := m.TransactCtx(ctx, func(ctx context.Context, session sqlx.Session) error {
|
|
return fn(ctx, session)
|
|
})
|
|
return err
|
|
}
|
|
|
|
// 更新余额的方法
|
|
func (m *customWalletsModel) UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error {
|
|
|
|
//wallet, err := m.FindOneByUserId(ctx, userId)
|
|
//if err != nil {
|
|
// return err
|
|
//}
|
|
var wallet Wallets
|
|
query := fmt.Sprintf("SELECT %s FROM %s WHERE `user_id` = ? limit 1", walletsRows, m.table)
|
|
if err := m.QueryRowNoCacheCtx(ctx, &wallet, query, userId); err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查当前版本是否达到了上限
|
|
newVersion := wallet.Version + 1
|
|
if newVersion > 100000000 { // 你可以根据实际需要设定上限,比如 UNSIGNED INT 的最大值
|
|
newVersion = 1 // 或者设置为其他重置值,比如 0
|
|
}
|
|
|
|
// 使用乐观锁更新余额和版本
|
|
result, err := session.Exec("UPDATE wallets SET balance = balance + ?, version = ? WHERE user_id = ? AND version = ?", amount, newVersion, userId, wallet.Version)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查影响的行数,确保更新成功
|
|
rowsAffected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rowsAffected == 0 {
|
|
return ErrVersionMismatch
|
|
}
|
|
|
|
walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, userId)
|
|
cacheErrors := m.DelCacheCtx(ctx, walletsUserIdKey)
|
|
if cacheErrors != nil {
|
|
return cacheErrors
|
|
}
|
|
|
|
return nil
|
|
}
|
|
func (m *customWalletsModel) InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error) {
|
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?, ?, ?)", m.table, walletsRowsExpectAutoSet)
|
|
ret, err := session.ExecCtx(ctx, query, wallets.UserId, wallets.Balance, wallets.Version)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ret, err
|
|
}
|
|
func (m *customWalletsModel) FindWalletsByUserIds(ctx context.Context, userIds []int64) (map[int64]*Wallets, error) {
|
|
if len(userIds) == 0 {
|
|
return make(map[int64]*Wallets), nil
|
|
}
|
|
|
|
queryBuilder := strings.Builder{}
|
|
queryBuilder.WriteString(fmt.Sprintf("SELECT %s FROM %s WHERE user_id IN (", walletsRows, m.table))
|
|
|
|
placeholders := make([]string, len(userIds))
|
|
args := make([]interface{}, len(userIds))
|
|
for i, userId := range userIds {
|
|
placeholders[i] = "?"
|
|
args[i] = userId
|
|
}
|
|
|
|
queryBuilder.WriteString(strings.Join(placeholders, ","))
|
|
queryBuilder.WriteString(")")
|
|
|
|
query := queryBuilder.String()
|
|
|
|
var wallets []Wallets
|
|
err := m.QueryRowsNoCacheCtx(ctx, &wallets, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
walletMap := make(map[int64]*Wallets)
|
|
for _, wallet := range wallets {
|
|
walletCopy := wallet // Create a copy to ensure the correct address is used
|
|
walletMap[wallet.UserId] = &walletCopy
|
|
}
|
|
|
|
return walletMap, nil
|
|
}
|