tianyuan-api-server/apps/user/internal/model/walletsmodel.go

126 lines
3.9 KiB
Go
Raw Normal View History

2024-10-12 20:41:55 +08:00
package model
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/sqlx"
"strings"
2024-10-12 20:41:55 +08:00
)
var _ WalletsModel = (*customWalletsModel)(nil)
type (
// WalletsModel is an interface to be customized, add more methods here,
// and implement the added methods in customWalletsModel.
WalletsModel interface {
walletsModel
InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error)
UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error
TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error
FindWalletsByUserIds(ctx context.Context, userIds []int64) (map[int64]*Wallets, error)
2024-10-12 20:41:55 +08:00
}
customWalletsModel struct {
*defaultWalletsModel
}
)
var (
	// ErrBalanceNotEnough reports that a wallet's balance is insufficient
	// (message: "insufficient balance").
	ErrBalanceNotEnough = errors.New("余额不足")

	// ErrVersionMismatch reports that an optimistic-lock update matched no row
	// because the wallet's version changed concurrently; the operation should
	// be retried (message: "version mismatch, please retry").
	ErrVersionMismatch = errors.New("版本号不匹配,请重试")
)
// NewWalletsModel builds a WalletsModel backed by conn and the cache
// configuration c. Any opts are forwarded unchanged to newWalletsModel.
func NewWalletsModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WalletsModel {
	base := newWalletsModel(conn, c, opts...)
	return &customWalletsModel{defaultWalletsModel: base}
}
// TransCtx runs fn inside a transaction opened via the embedded TransactCtx,
// passing fn the transaction's context and session. Whatever TransactCtx
// returns is returned unchanged.
func (m *customWalletsModel) TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error {
	// fn already has the callback signature TransactCtx expects, so it can be
	// handed over directly instead of being wrapped in a forwarding closure.
	return m.TransactCtx(ctx, fn)
}
// 更新余额的方法
func (m *customWalletsModel) UpdateBalance(session sqlx.Session, ctx context.Context, userId int64, amount float64) error {
2024-10-21 17:07:25 +08:00
//wallet, err := m.FindOneByUserId(ctx, userId)
//if err != nil {
// return err
//}
var wallet Wallets
query := fmt.Sprintf("SELECT %s FROM %s WHERE `user_id` = ? limit 1", walletsRows, m.table)
if err := m.QueryRowNoCacheCtx(ctx, &wallet, query, userId); err != nil {
2024-10-12 20:41:55 +08:00
return err
}
2024-10-15 23:58:36 +08:00
// 检查当前版本是否达到了上限
newVersion := wallet.Version + 1
if newVersion > 100000000 { // 你可以根据实际需要设定上限,比如 UNSIGNED INT 的最大值
newVersion = 1 // 或者设置为其他重置值,比如 0
}
// 使用乐观锁更新余额和版本
result, err := session.Exec("UPDATE wallets SET balance = balance + ?, version = ? WHERE user_id = ? AND version = ?", amount, newVersion, userId, wallet.Version)
2024-10-12 20:41:55 +08:00
if err != nil {
return err
}
// 检查影响的行数,确保更新成功
rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}
if rowsAffected == 0 {
return ErrVersionMismatch
}
return nil
}
// InsertWalletsTrans inserts wallets through session, so the insert takes part
// in the caller's transaction. It returns the driver result on success, or a
// nil result together with the error on failure.
//
// NOTE(review): the three placeholders are bound to UserId, Balance and
// Version in that order. This assumes walletsRowsExpectAutoSet lists exactly
// those columns in the same order; confirm against the generated model.
func (m *customWalletsModel) InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error) {
	stmt := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?, ?, ?)", m.table, walletsRowsExpectAutoSet)

	res, execErr := session.ExecCtx(ctx, stmt, wallets.UserId, wallets.Balance, wallets.Version)
	if execErr != nil {
		return nil, execErr
	}
	return res, nil
}
func (m *customWalletsModel) FindWalletsByUserIds(ctx context.Context, userIds []int64) (map[int64]*Wallets, error) {
if len(userIds) == 0 {
return make(map[int64]*Wallets), nil
}
queryBuilder := strings.Builder{}
2024-10-21 17:07:25 +08:00
queryBuilder.WriteString(fmt.Sprintf("SELECT %s FROM %s WHERE user_id IN (", walletsRows, m.table))
placeholders := make([]string, len(userIds))
args := make([]interface{}, len(userIds))
for i, userId := range userIds {
placeholders[i] = "?"
args[i] = userId
}
queryBuilder.WriteString(strings.Join(placeholders, ","))
queryBuilder.WriteString(")")
query := queryBuilder.String()
2024-10-21 17:07:25 +08:00
var wallets []Wallets
err := m.QueryRowsNoCacheCtx(ctx, &wallets, query, args...)
if err != nil {
return nil, err
}
walletMap := make(map[int64]*Wallets)
for _, wallet := range wallets {
2024-10-21 17:07:25 +08:00
walletCopy := wallet // Create a copy to ensure the correct address is used
walletMap[wallet.UserId] = &walletCopy
}
return walletMap, nil
}