diff --git a/app/user/cmd/api/internal/middleware/authinterceptormiddleware.go b/app/user/cmd/api/internal/middleware/authinterceptormiddleware.go index ae4ec86..3bf8160 100644 --- a/app/user/cmd/api/internal/middleware/authinterceptormiddleware.go +++ b/app/user/cmd/api/internal/middleware/authinterceptormiddleware.go @@ -2,6 +2,8 @@ package middleware import ( "context" + "encoding/json" + "fmt" "net/http" "tydata-server/app/user/cmd/api/internal/config" @@ -47,8 +49,10 @@ func (m *AuthInterceptorMiddleware) Handle(next http.HandlerFunc) http.HandlerFu return } - // 将用户ID添加到请求上下文 - ctx := context.WithValue(r.Context(), ctxdata.CtxKeyJwtUserId, userId) + // 将用户ID转换为json.Number类型后添加到请求上下文 + userIdStr := fmt.Sprintf("%d", userId) + userIdJsonNum := json.Number(userIdStr) + ctx := context.WithValue(r.Context(), ctxdata.CtxKeyJwtUserId, userIdJsonNum) // 使用新的上下文继续处理请求 next(w, r.WithContext(ctx)) diff --git a/common/ctxdata/ctxData.go b/common/ctxdata/ctxData.go index ba3e4c6..10782f5 100644 --- a/common/ctxdata/ctxData.go +++ b/common/ctxdata/ctxData.go @@ -23,19 +23,28 @@ func GetUidFromCtx(ctx context.Context) (int64, error) { return 0, ErrNoUserIdInCtx } - // 尝试转换为 json.Number - jsonUid, ok := value.(json.Number) - if !ok { - return 0, fmt.Errorf("%w: 期望类型 json.Number, 实际类型 %T", ErrInvalidUserId, value) + // 根据值的类型进行不同处理 + switch v := value.(type) { + case json.Number: + // 如果是 json.Number 类型,转换为 int64 + uid, err := v.Int64() + if err != nil { + return 0, fmt.Errorf("%w: %v", ErrInvalidUserId, err) + } + return uid, nil + case int64: + // 如果已经是 int64 类型,直接返回 + return v, nil + case float64: + // 有些JSON解析器可能会将数字解析为float64 + return int64(v), nil + case int: + // 处理int类型 + return int64(v), nil + default: + // 其他类型都视为无效 + return 0, fmt.Errorf("%w: 期望类型 json.Number 或 int64, 实际类型 %T", ErrInvalidUserId, value) } - - // 转换为 int64 - uid, err := jsonUid.Int64() - if err != nil { - return 0, fmt.Errorf("%w: %v", ErrInvalidUserId, err) - } - - return uid, nil } // IsNoUserIdError 判断是否是未登录错误 diff --git a/common/ctxdata/ctxData_test.go b/common/ctxdata/ctxData_test.go new file mode 100644 index 0000000..6d7fb0e --- /dev/null +++ b/common/ctxdata/ctxData_test.go @@ -0,0 +1,179 @@ +package ctxdata + +import ( + "context" + "encoding/json" + "errors" + "testing" +) + +func TestGetUidFromCtx(t *testing.T) { + tests := []struct { + name string + ctxSetup func() context.Context + wantUid int64 + wantError error + }{ + { + name: "正常情况_有效用户ID_json.Number", + ctxSetup: func() context.Context { + return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("12345")) + }, + wantUid: 12345, + wantError: nil, + }, + { + name: "正常情况_有效用户ID_int64", + ctxSetup: func() context.Context { + return context.WithValue(context.Background(), CtxKeyJwtUserId, int64(12345)) + }, + wantUid: 12345, + wantError: nil, + }, + { + name: "正常情况_有效用户ID_int", + ctxSetup: func() context.Context { + return context.WithValue(context.Background(), CtxKeyJwtUserId, 12345) + }, + wantUid: 12345, + wantError: nil, + }, + { + name: "正常情况_有效用户ID_float64", + ctxSetup: func() context.Context { + return context.WithValue(context.Background(), CtxKeyJwtUserId, float64(12345)) + }, + wantUid: 12345, + wantError: nil, + }, + { + name: "异常情况_上下文中无用户ID", + ctxSetup: func() context.Context { + return context.Background() + }, + wantUid: 0, + wantError: ErrNoUserIdInCtx, + }, + { + name: "异常情况_用户ID类型错误", + ctxSetup: func() context.Context { + return context.WithValue(context.Background(), CtxKeyJwtUserId, "非数字类型") + }, + wantUid: 0, + wantError: ErrInvalidUserId, + }, + { + name: "异常情况_用户ID无法转换为int64", + ctxSetup: func() context.Context { + return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("非数字内容")) + }, + wantUid: 0, + wantError: ErrInvalidUserId, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := tt.ctxSetup() + gotUid, gotErr := GetUidFromCtx(ctx) + + // 检查返回的用户ID + if gotUid != tt.wantUid { + t.Errorf("GetUidFromCtx() 返回用户ID = %v, 期望值 %v", gotUid, tt.wantUid) + } + + // 检查错误类型 + if tt.wantError == nil && gotErr != nil { + t.Errorf("GetUidFromCtx() 返回意外错误 = %v", gotErr) + } + + if tt.wantError != nil && !errors.Is(gotErr, tt.wantError) { + t.Errorf("GetUidFromCtx() 错误类型 = %v, 期望错误类型 %v", gotErr, tt.wantError) + } + }) + } +} + +func TestIsNoUserIdError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "是未登录错误", + err: ErrNoUserIdInCtx, + expected: true, + }, + { + name: "包装的未登录错误", + err: errors.New("外层错误: " + ErrNoUserIdInCtx.Error()), + expected: false, // 直接字符串拼接不会保留错误链 + }, + { + name: "使用fmt.Errorf包装的未登录错误", + err: errors.New("外层错误"), + expected: false, + }, + { + name: "非未登录错误", + err: ErrInvalidUserId, + expected: false, + }, + { + name: "nil错误", + err: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsNoUserIdError(tt.err); got != tt.expected { + t.Errorf("IsNoUserIdError() = %v, 期望值 %v", got, tt.expected) + } + }) + } +} + +func TestIsInvalidUserIdError(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "是无效用户ID错误", + err: ErrInvalidUserId, + expected: true, + }, + { + name: "包装的无效用户ID错误", + err: errors.New("外层错误: " + ErrInvalidUserId.Error()), + expected: false, // 直接字符串拼接不会保留错误链 + }, + { + name: "使用fmt.Errorf包装的无效用户ID错误", + err: errors.New("外层错误"), + expected: false, + }, + { + name: "非无效用户ID错误", + err: ErrNoUserIdInCtx, + expected: false, + }, + { + name: "nil错误", + err: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsInvalidUserIdError(tt.err); got != tt.expected { + t.Errorf("IsInvalidUserIdError() = %v, 期望值 %v", got, tt.expected) + } + }) + } +}