三端用户手机号联通,增加临时用户
This commit is contained in:
@@ -5,14 +5,16 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"qnc-server/app/main/model"
|
||||
jwtx "qnc-server/common/jwt"
|
||||
)
|
||||
|
||||
const CtxKeyJwtUserId = "userId"
|
||||
|
||||
// 定义错误类型
|
||||
var (
|
||||
ErrNoUserIdInCtx = errors.New("上下文中没有用户ID") // 未登录
|
||||
ErrInvalidUserId = errors.New("用户ID格式无效") // 数据异常
|
||||
ErrNoInCtx = errors.New("上下文中没有相关数据")
|
||||
ErrInvalidUserId = errors.New("用户ID格式无效") // 数据异常
|
||||
)
|
||||
|
||||
// GetUidFromCtx 从 context 中获取用户 ID
|
||||
@@ -20,7 +22,11 @@ func GetUidFromCtx(ctx context.Context) (int64, error) {
|
||||
// 尝试从上下文中获取 jwtUserId
|
||||
value := ctx.Value(CtxKeyJwtUserId)
|
||||
if value == nil {
|
||||
return 0, ErrNoUserIdInCtx
|
||||
claims, err := GetClaimsFromCtx(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
// 根据值的类型进行不同处理
|
||||
@@ -47,12 +53,52 @@ func GetUidFromCtx(ctx context.Context) (int64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func GetClaimsFromCtx(ctx context.Context) (*jwtx.JwtClaims, error) {
|
||||
value := ctx.Value(jwtx.ExtraKey)
|
||||
if value == nil {
|
||||
return nil, ErrNoInCtx
|
||||
}
|
||||
|
||||
// 首先尝试直接断言为 *jwtx.JwtClaims
|
||||
if claims, ok := value.(*jwtx.JwtClaims); ok {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// 如果直接断言失败,尝试从 map[string]interface{} 中解析
|
||||
if claimsMap, ok := value.(map[string]interface{}); ok {
|
||||
return jwtx.MapToJwtClaims(claimsMap)
|
||||
}
|
||||
|
||||
return nil, ErrNoInCtx
|
||||
}
|
||||
|
||||
// IsNoUserIdError 判断是否是未登录错误
|
||||
func IsNoUserIdError(err error) bool {
|
||||
return errors.Is(err, ErrNoUserIdInCtx)
|
||||
return errors.Is(err, ErrNoInCtx)
|
||||
}
|
||||
|
||||
// IsInvalidUserIdError 判断是否是用户ID格式错误
|
||||
func IsInvalidUserIdError(err error) bool {
|
||||
return errors.Is(err, ErrInvalidUserId)
|
||||
}
|
||||
|
||||
// GetPlatformFromCtx 从 context 中获取平台
|
||||
func GetPlatformFromCtx(ctx context.Context) (string, error) {
|
||||
platform, platformOk := ctx.Value("platform").(string)
|
||||
if !platformOk {
|
||||
return "", fmt.Errorf("平台不存在: %s", platform)
|
||||
}
|
||||
|
||||
switch platform {
|
||||
case model.PlatformWxMini:
|
||||
return model.PlatformWxMini, nil
|
||||
case model.PlatformWxH5:
|
||||
return model.PlatformWxH5, nil
|
||||
case model.PlatformApp:
|
||||
return model.PlatformApp, nil
|
||||
case model.PlatformH5:
|
||||
return model.PlatformH5, nil
|
||||
default:
|
||||
return "", fmt.Errorf("不支持的支付平台: %s", platform)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,68 +1,94 @@
|
||||
package jwtx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
// Token 生成逻辑的函数,接收 userId、过期时间和密钥,返回生成的 token
|
||||
func GenerateJwtToken(userId int64, secret string, expireTime int64) (string, error) {
|
||||
// 获取当前时间戳
|
||||
now := time.Now().Unix()
|
||||
// 定义 JWT Claims
|
||||
claims := jwt.MapClaims{
|
||||
"exp": now + expireTime, // token 过期时间
|
||||
"iat": now, // 签发时间
|
||||
"userId": userId, // 用户ID
|
||||
const ExtraKey = "extra"
|
||||
|
||||
type JwtClaims struct {
|
||||
UserId int64 `json:"userId"`
|
||||
AgentId int64 `json:"agentId"`
|
||||
Platform string `json:"platform"`
|
||||
// 用户身份类型:0-临时用户,1-正式用户
|
||||
UserType int64 `json:"userType"`
|
||||
// 是否代理:0-否,1-是
|
||||
IsAgent int64 `json:"isAgent"`
|
||||
}
|
||||
|
||||
// MapToJwtClaims 将 map[string]interface{} 转换为 JwtClaims 结构体
|
||||
func MapToJwtClaims(claimsMap map[string]interface{}) (*JwtClaims, error) {
|
||||
// 使用JSON序列化/反序列化的方式自动转换
|
||||
jsonData, err := json.Marshal(claimsMap)
|
||||
if err != nil {
|
||||
return nil, errors.New("序列化claims失败")
|
||||
}
|
||||
|
||||
// 创建新的 JWT token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
var claims JwtClaims
|
||||
if err := json.Unmarshal(jsonData, &claims); err != nil {
|
||||
return nil, errors.New("反序列化claims失败")
|
||||
}
|
||||
|
||||
// 使用密钥对 token 签名
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// GenerateJwtToken 生成JWT token
|
||||
func GenerateJwtToken(claims JwtClaims, secret string, expire int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
// 将 claims 结构体转换为 map[string]interface{}
|
||||
claimsBytes, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedToken, nil
|
||||
var claimsMap map[string]interface{}
|
||||
if err := json.Unmarshal(claimsBytes, &claimsMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jwtClaims := jwt.MapClaims{
|
||||
"exp": now + expire,
|
||||
"iat": now,
|
||||
"userId": claims.UserId,
|
||||
ExtraKey: claimsMap,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims)
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
func ParseJwtToken(tokenStr string, secret string) (int64, error) {
|
||||
|
||||
func ParseJwtToken(tokenStr string, secret string) (*JwtClaims, error) {
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return 0, errors.New("invalid JWT")
|
||||
return nil, errors.New("invalid JWT")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return 0, errors.New("invalid JWT claims")
|
||||
return nil, errors.New("invalid JWT claims")
|
||||
}
|
||||
|
||||
// 从 claims 中提取 userId
|
||||
userIdRaw, ok := claims["userId"]
|
||||
if !ok {
|
||||
return 0, errors.New("userId not found in JWT")
|
||||
extraInfo, exists := claims[ExtraKey]
|
||||
if !exists {
|
||||
return nil, errors.New("extra not found in JWT")
|
||||
}
|
||||
|
||||
// 处理不同类型的 userId,确保它被转换为 int64
|
||||
switch userId := userIdRaw.(type) {
|
||||
case float64:
|
||||
return int64(userId), nil
|
||||
case int64:
|
||||
return userId, nil
|
||||
case string:
|
||||
// 如果 userId 是字符串,可以尝试将其转换为 int64
|
||||
parsedId, err := strconv.ParseInt(userId, 10, 64)
|
||||
if err != nil {
|
||||
return 0, errors.New("invalid userId in JWT")
|
||||
}
|
||||
return parsedId, nil
|
||||
default:
|
||||
return 0, errors.New("unsupported userId type in JWT")
|
||||
// 尝试直接断言为 JwtClaims 结构体
|
||||
if jwtClaims, ok := extraInfo.(JwtClaims); ok {
|
||||
return &jwtClaims, nil
|
||||
}
|
||||
|
||||
// 尝试从 map[string]interface{} 中解析
|
||||
if claimsMap, ok := extraInfo.(map[string]interface{}); ok {
|
||||
return MapToJwtClaims(claimsMap)
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported extra type in JWT")
|
||||
}
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
package jwtx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateAndParseJwtToken(t *testing.T) {
|
||||
// 测试参数
|
||||
userId := int64(39)
|
||||
secret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA"
|
||||
expireTime := int64(2592000) // 1小时过期
|
||||
|
||||
// 生成token
|
||||
token, err := GenerateJwtToken(userId, secret, expireTime)
|
||||
if err != nil {
|
||||
t.Fatalf("生成JWT令牌失败: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Fatal("生成的JWT令牌为空")
|
||||
}
|
||||
fmt.Println(token)
|
||||
// 解析token
|
||||
parsedUserId, err := ParseJwtToken("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3NTA5MjYyNTgsImlhdCI6MTc0ODMzNDI1OCwidXNlcklkIjoxNDN9.V4qV3dAjE6G-xm0KmB6QHCVhy2SmRDGAWcvl32hLNmI", secret)
|
||||
if err != nil {
|
||||
t.Fatalf("解析JWT令牌失败: %v", err)
|
||||
}
|
||||
fmt.Printf("解析出的userId: %d\n", parsedUserId)
|
||||
// 验证解析出的userId是否正确
|
||||
// if parsedUserId != userId {
|
||||
// t.Errorf("解析出的userId不匹配: 期望 %d, 实际 %d", userId, parsedUserId)
|
||||
// }
|
||||
}
|
||||
|
||||
func TestTokenExpiration(t *testing.T) {
|
||||
// 测试参数
|
||||
userId := int64(10086)
|
||||
secret := "test_secret_key"
|
||||
expireTime := int64(1) // 1秒过期
|
||||
|
||||
// 生成token
|
||||
token, err := GenerateJwtToken(userId, secret, expireTime)
|
||||
if err != nil {
|
||||
t.Fatalf("生成JWT令牌失败: %v", err)
|
||||
}
|
||||
|
||||
// 等待令牌过期
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
// 解析已过期token
|
||||
_, err = ParseJwtToken(token, secret)
|
||||
if err == nil {
|
||||
t.Error("期望令牌过期错误,但没有发生错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidToken(t *testing.T) {
|
||||
secret := "test_secret_key"
|
||||
|
||||
// 测试无效token
|
||||
invalidToken := "invalid.token.string"
|
||||
_, err := ParseJwtToken(invalidToken, secret)
|
||||
if err == nil {
|
||||
t.Error("期望无效令牌错误,但没有发生错误")
|
||||
}
|
||||
|
||||
// 测试密钥不匹配
|
||||
userId := int64(10086)
|
||||
expireTime := int64(3600)
|
||||
token, _ := GenerateJwtToken(userId, "original_secret", expireTime)
|
||||
_, err = ParseJwtToken(token, "wrong_secret")
|
||||
if err == nil {
|
||||
t.Error("期望密钥不匹配错误,但没有发生错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserIdTypes(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
setupFn func() (string, string, int64)
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "正常int64类型",
|
||||
setupFn: func() (string, string, int64) {
|
||||
userId := int64(10086)
|
||||
secret := "test_secret"
|
||||
expireTime := int64(3600)
|
||||
token, _ := GenerateJwtToken(userId, secret, expireTime)
|
||||
return token, secret, userId
|
||||
},
|
||||
expected: 10086,
|
||||
},
|
||||
// 其他类型在实际场景中通过手动修改token内容测试,这里省略
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token, secret, expected := tc.setupFn()
|
||||
userId, err := ParseJwtToken(token, secret)
|
||||
if err != nil {
|
||||
t.Fatalf("解析失败: %v", err)
|
||||
}
|
||||
if userId != expected {
|
||||
t.Errorf("用户ID不匹配: 期望 %d, 实际 %d", expected, userId)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -6,17 +6,44 @@ const OK uint32 = 200
|
||||
/**(前3位代表业务,后三位代表具体功能)**/
|
||||
|
||||
// 全局错误码
|
||||
// 服务器通用错误
|
||||
const SERVER_COMMON_ERROR uint32 = 100001
|
||||
|
||||
// 请求参数错误
|
||||
const REUQEST_PARAM_ERROR uint32 = 100002
|
||||
|
||||
// token过期错误
|
||||
const TOKEN_EXPIRE_ERROR uint32 = 100003
|
||||
|
||||
// token生成错误
|
||||
const TOKEN_GENERATE_ERROR uint32 = 100004
|
||||
|
||||
// 数据库错误
|
||||
const DB_ERROR uint32 = 100005
|
||||
|
||||
// 数据库更新影响行数为0错误
|
||||
const DB_UPDATE_AFFECTED_ZERO_ERROR uint32 = 100006
|
||||
|
||||
// 参数验证错误
|
||||
const PARAM_VERIFICATION_ERROR uint32 = 100007
|
||||
|
||||
// 自定义错误
|
||||
const CUSTOM_ERROR uint32 = 100008
|
||||
|
||||
// 用户不存在错误
|
||||
const USER_NOT_FOUND uint32 = 100009
|
||||
|
||||
// 用户需要绑定手机号
|
||||
const USER_NEED_BIND_MOBILE uint32 = 100010
|
||||
|
||||
// 登录失败错误
|
||||
const LOGIN_FAILED uint32 = 200001
|
||||
|
||||
// 查询等待中
|
||||
const LOGIC_QUERY_WAIT uint32 = 200002
|
||||
|
||||
// 查询错误
|
||||
const LOGIC_QUERY_ERROR uint32 = 200003
|
||||
|
||||
// 查询结果不存在
|
||||
const LOGIC_QUERY_NOT_FOUND uint32 = 200004
|
||||
|
||||
Reference in New Issue
Block a user