fix wxh5 jwt userid type
This commit is contained in:
parent
176ad00e35
commit
beda62f833
@ -2,6 +2,8 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"tydata-server/app/user/cmd/api/internal/config"
|
||||
@ -47,8 +49,10 @@ func (m *AuthInterceptorMiddleware) Handle(next http.HandlerFunc) http.HandlerFu
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户ID添加到请求上下文
|
||||
ctx := context.WithValue(r.Context(), ctxdata.CtxKeyJwtUserId, userId)
|
||||
// 将用户ID转换为json.Number类型后添加到请求上下文
|
||||
userIdStr := fmt.Sprintf("%d", userId)
|
||||
userIdJsonNum := json.Number(userIdStr)
|
||||
ctx := context.WithValue(r.Context(), ctxdata.CtxKeyJwtUserId, userIdJsonNum)
|
||||
|
||||
// 使用新的上下文继续处理请求
|
||||
next(w, r.WithContext(ctx))
|
||||
|
@ -23,19 +23,28 @@ func GetUidFromCtx(ctx context.Context) (int64, error) {
|
||||
return 0, ErrNoUserIdInCtx
|
||||
}
|
||||
|
||||
// 尝试转换为 json.Number
|
||||
jsonUid, ok := value.(json.Number)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("%w: 期望类型 json.Number, 实际类型 %T", ErrInvalidUserId, value)
|
||||
}
|
||||
|
||||
// 转换为 int64
|
||||
uid, err := jsonUid.Int64()
|
||||
// 根据值的类型进行不同处理
|
||||
switch v := value.(type) {
|
||||
case json.Number:
|
||||
// 如果是 json.Number 类型,转换为 int64
|
||||
uid, err := v.Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: %v", ErrInvalidUserId, err)
|
||||
}
|
||||
|
||||
return uid, nil
|
||||
case int64:
|
||||
// 如果已经是 int64 类型,直接返回
|
||||
return v, nil
|
||||
case float64:
|
||||
// 有些JSON解析器可能会将数字解析为float64
|
||||
return int64(v), nil
|
||||
case int:
|
||||
// 处理int类型
|
||||
return int64(v), nil
|
||||
default:
|
||||
// 其他类型都视为无效
|
||||
return 0, fmt.Errorf("%w: 期望类型 json.Number 或 int64, 实际类型 %T", ErrInvalidUserId, value)
|
||||
}
|
||||
}
|
||||
|
||||
// IsNoUserIdError 判断是否是未登录错误
|
||||
|
179
common/ctxdata/ctxData_test.go
Normal file
179
common/ctxdata/ctxData_test.go
Normal file
@ -0,0 +1,179 @@
|
||||
package ctxdata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetUidFromCtx(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ctxSetup func() context.Context
|
||||
wantUid int64
|
||||
wantError error
|
||||
}{
|
||||
{
|
||||
name: "正常情况_有效用户ID_json.Number",
|
||||
ctxSetup: func() context.Context {
|
||||
return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("12345"))
|
||||
},
|
||||
wantUid: 12345,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "正常情况_有效用户ID_int64",
|
||||
ctxSetup: func() context.Context {
|
||||
return context.WithValue(context.Background(), CtxKeyJwtUserId, int64(12345))
|
||||
},
|
||||
wantUid: 12345,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "正常情况_有效用户ID_int",
|
||||
ctxSetup: func() context.Context {
|
||||
return context.WithValue(context.Background(), CtxKeyJwtUserId, 12345)
|
||||
},
|
||||
wantUid: 12345,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "正常情况_有效用户ID_float64",
|
||||
ctxSetup: func() context.Context {
|
||||
return context.WithValue(context.Background(), CtxKeyJwtUserId, float64(12345))
|
||||
},
|
||||
wantUid: 12345,
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "异常情况_上下文中无用户ID",
|
||||
ctxSetup: func() context.Context {
|
||||
return context.Background()
|
||||
},
|
||||
wantUid: 0,
|
||||
wantError: ErrNoUserIdInCtx,
|
||||
},
|
||||
{
|
||||
name: "异常情况_用户ID类型错误",
|
||||
ctxSetup: func() context.Context {
|
||||
return context.WithValue(context.Background(), CtxKeyJwtUserId, "非数字类型")
|
||||
},
|
||||
wantUid: 0,
|
||||
wantError: ErrInvalidUserId,
|
||||
},
|
||||
{
|
||||
name: "异常情况_用户ID无法转换为int64",
|
||||
ctxSetup: func() context.Context {
|
||||
return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("非数字内容"))
|
||||
},
|
||||
wantUid: 0,
|
||||
wantError: ErrInvalidUserId,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := tt.ctxSetup()
|
||||
gotUid, gotErr := GetUidFromCtx(ctx)
|
||||
|
||||
// 检查返回的用户ID
|
||||
if gotUid != tt.wantUid {
|
||||
t.Errorf("GetUidFromCtx() 返回用户ID = %v, 期望值 %v", gotUid, tt.wantUid)
|
||||
}
|
||||
|
||||
// 检查错误类型
|
||||
if tt.wantError == nil && gotErr != nil {
|
||||
t.Errorf("GetUidFromCtx() 返回意外错误 = %v", gotErr)
|
||||
}
|
||||
|
||||
if tt.wantError != nil && !errors.Is(gotErr, tt.wantError) {
|
||||
t.Errorf("GetUidFromCtx() 错误类型 = %v, 期望错误类型 %v", gotErr, tt.wantError)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsNoUserIdError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "是未登录错误",
|
||||
err: ErrNoUserIdInCtx,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "包装的未登录错误",
|
||||
err: errors.New("外层错误: " + ErrNoUserIdInCtx.Error()),
|
||||
expected: false, // 直接字符串拼接不会保留错误链
|
||||
},
|
||||
{
|
||||
name: "使用fmt.Errorf包装的未登录错误",
|
||||
err: errors.New("外层错误"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "非未登录错误",
|
||||
err: ErrInvalidUserId,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil错误",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsNoUserIdError(tt.err); got != tt.expected {
|
||||
t.Errorf("IsNoUserIdError() = %v, 期望值 %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsInvalidUserIdError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "是无效用户ID错误",
|
||||
err: ErrInvalidUserId,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "包装的无效用户ID错误",
|
||||
err: errors.New("外层错误: " + ErrInvalidUserId.Error()),
|
||||
expected: false, // 直接字符串拼接不会保留错误链
|
||||
},
|
||||
{
|
||||
name: "使用fmt.Errorf包装的无效用户ID错误",
|
||||
err: errors.New("外层错误"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "非无效用户ID错误",
|
||||
err: ErrNoUserIdInCtx,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "nil错误",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := IsInvalidUserIdError(tt.err); got != tt.expected {
|
||||
t.Errorf("IsInvalidUserIdError() = %v, 期望值 %v", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user