2024-10-12 20:41:55 +08:00
package model
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/zeromicro/go-zero/core/stores/cache"
"github.com/zeromicro/go-zero/core/stores/sqlx"
2024-10-21 16:01:20 +08:00
"strings"
2024-10-12 20:41:55 +08:00
)
var _ WalletsModel = ( * customWalletsModel ) ( nil )
type (
// WalletsModel is an interface to be customized, add more methods here,
// and implement the added methods in customWalletsModel.
WalletsModel interface {
walletsModel
InsertWalletsTrans ( ctx context . Context , wallets * Wallets , session sqlx . Session ) ( sql . Result , error )
UpdateBalance ( session sqlx . Session , ctx context . Context , userId int64 , amount float64 ) error
TransCtx ( ctx context . Context , fn func ( ctx context . Context , session sqlx . Session ) error ) error
2024-10-21 16:01:20 +08:00
FindWalletsByUserIds ( ctx context . Context , userIds [ ] int64 ) ( map [ int64 ] * Wallets , error )
2024-10-12 20:41:55 +08:00
}
customWalletsModel struct {
* defaultWalletsModel
}
)
// Sentinel errors for the wallets model; compare them with errors.Is.
var (
	// ErrBalanceNotEnough reports that a wallet's balance is insufficient
	// (message: "insufficient balance").
	ErrBalanceNotEnough = errors.New("余额不足")

	// ErrVersionMismatch reports that an optimistic-lock update matched no row
	// because the version changed since it was read; the caller may retry
	// (message: "version mismatch, please retry").
	ErrVersionMismatch = errors.New("版本号不匹配,请重试")
)
// NewWalletsModel returns a WalletsModel for the wallets table. conn, c and
// opts are forwarded unchanged to the generated newWalletsModel constructor.
func NewWalletsModel(conn sqlx.SqlConn, c cache.CacheConf, opts ...cache.Option) WalletsModel {
	generated := newWalletsModel(conn, c, opts...)
	return &customWalletsModel{defaultWalletsModel: generated}
}
// TransCtx runs fn in a transaction via the embedded TransactCtx. fn receives
// the ctx and transaction session supplied by TransactCtx. Whatever
// TransactCtx returns is returned to the caller.
func (m *customWalletsModel) TransCtx(ctx context.Context, fn func(ctx context.Context, session sqlx.Session) error) error {
	// fn already has the exact callback signature TransactCtx expects, so no
	// wrapper closure is needed.
	return m.TransactCtx(ctx, fn)
}
// 更新余额的方法
func ( m * customWalletsModel ) UpdateBalance ( session sqlx . Session , ctx context . Context , userId int64 , amount float64 ) error {
2024-10-21 17:07:25 +08:00
//wallet, err := m.FindOneByUserId(ctx, userId)
//if err != nil {
// return err
//}
var wallet Wallets
query := fmt . Sprintf ( "SELECT %s FROM %s WHERE `user_id` = ? limit 1" , walletsRows , m . table )
if err := m . QueryRowNoCacheCtx ( ctx , & wallet , query , userId ) ; err != nil {
2024-10-12 20:41:55 +08:00
return err
}
2024-10-15 23:58:36 +08:00
// 检查当前版本是否达到了上限
newVersion := wallet . Version + 1
if newVersion > 100000000 { // 你可以根据实际需要设定上限,比如 UNSIGNED INT 的最大值
newVersion = 1 // 或者设置为其他重置值,比如 0
}
// 使用乐观锁更新余额和版本
result , err := session . Exec ( "UPDATE wallets SET balance = balance + ?, version = ? WHERE user_id = ? AND version = ?" , amount , newVersion , userId , wallet . Version )
2024-10-12 20:41:55 +08:00
if err != nil {
return err
}
// 检查影响的行数,确保更新成功
rowsAffected , err := result . RowsAffected ( )
if err != nil {
return err
}
if rowsAffected == 0 {
return ErrVersionMismatch
}
2024-10-21 17:49:50 +08:00
walletsUserIdKey := fmt . Sprintf ( "%s%v" , cacheWalletsUserIdPrefix , userId )
cacheErrors := m . DelCacheCtx ( ctx , walletsUserIdKey )
if cacheErrors != nil {
return cacheErrors
}
2024-10-12 20:41:55 +08:00
return nil
}
// InsertWalletsTrans inserts wallets through the caller's transaction session
// and then deletes the cache entry keyed by the wallet's user_id.
//
// The cache entry is dropped for the same reason UpdateBalance drops it:
// cached by-user_id reads must not keep serving state from before this write.
// NOTE(review): go-zero's cache may hold a short-lived "not found" placeholder
// after a miss, which would otherwise hide the new row until it expires;
// confirm for this cache configuration.
//
// Returns the INSERT's sql.Result, or nil and the error from the INSERT or
// the cache deletion.
func (m *customWalletsModel) InsertWalletsTrans(ctx context.Context, wallets *Wallets, session sqlx.Session) (sql.Result, error) {
	// Assumes walletsRowsExpectAutoSet lists exactly user_id, balance, version
	// in this order; verify against the generated model.
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (?, ?, ?)", m.table, walletsRowsExpectAutoSet)
	ret, err := session.ExecCtx(ctx, query, wallets.UserId, wallets.Balance, wallets.Version)
	if err != nil {
		return nil, err
	}

	// As in UpdateBalance, this runs before the transaction commits. Returning
	// the error lets the caller's transaction roll the insert back.
	walletsUserIdKey := fmt.Sprintf("%s%v", cacheWalletsUserIdPrefix, wallets.UserId)
	if err := m.DelCacheCtx(ctx, walletsUserIdKey); err != nil {
		return nil, err
	}
	return ret, nil
}
2024-10-21 16:01:20 +08:00
func ( m * customWalletsModel ) FindWalletsByUserIds ( ctx context . Context , userIds [ ] int64 ) ( map [ int64 ] * Wallets , error ) {
if len ( userIds ) == 0 {
return make ( map [ int64 ] * Wallets ) , nil
}
queryBuilder := strings . Builder { }
2024-10-21 17:07:25 +08:00
queryBuilder . WriteString ( fmt . Sprintf ( "SELECT %s FROM %s WHERE user_id IN (" , walletsRows , m . table ) )
2024-10-21 16:01:20 +08:00
placeholders := make ( [ ] string , len ( userIds ) )
args := make ( [ ] interface { } , len ( userIds ) )
for i , userId := range userIds {
placeholders [ i ] = "?"
args [ i ] = userId
}
queryBuilder . WriteString ( strings . Join ( placeholders , "," ) )
queryBuilder . WriteString ( ")" )
query := queryBuilder . String ( )
2024-10-21 17:07:25 +08:00
var wallets [ ] Wallets
err := m . QueryRowsNoCacheCtx ( ctx , & wallets , query , args ... )
2024-10-21 16:01:20 +08:00
if err != nil {
return nil , err
}
walletMap := make ( map [ int64 ] * Wallets )
for _ , wallet := range wallets {
2024-10-21 17:07:25 +08:00
walletCopy := wallet // Create a copy to ensure the correct address is used
walletMap [ wallet . UserId ] = & walletCopy
2024-10-21 16:01:20 +08:00
}
return walletMap , nil
}