package ctxdata import ( "context" "encoding/json" "errors" "testing" ) func TestGetUidFromCtx(t *testing.T) { tests := []struct { name string ctxSetup func() context.Context wantUid int64 wantError error }{ { name: "正常情况_有效用户ID_json.Number", ctxSetup: func() context.Context { return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("12345")) }, wantUid: 12345, wantError: nil, }, { name: "正常情况_有效用户ID_int64", ctxSetup: func() context.Context { return context.WithValue(context.Background(), CtxKeyJwtUserId, int64(12345)) }, wantUid: 12345, wantError: nil, }, { name: "正常情况_有效用户ID_int", ctxSetup: func() context.Context { return context.WithValue(context.Background(), CtxKeyJwtUserId, 12345) }, wantUid: 12345, wantError: nil, }, { name: "正常情况_有效用户ID_float64", ctxSetup: func() context.Context { return context.WithValue(context.Background(), CtxKeyJwtUserId, float64(12345)) }, wantUid: 12345, wantError: nil, }, { name: "异常情况_上下文中无用户ID", ctxSetup: func() context.Context { return context.Background() }, wantUid: 0, wantError: ErrNoUserIdInCtx, }, { name: "异常情况_用户ID类型错误", ctxSetup: func() context.Context { return context.WithValue(context.Background(), CtxKeyJwtUserId, "非数字类型") }, wantUid: 0, wantError: ErrInvalidUserId, }, { name: "异常情况_用户ID无法转换为int64", ctxSetup: func() context.Context { return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("非数字内容")) }, wantUid: 0, wantError: ErrInvalidUserId, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := tt.ctxSetup() gotUid, gotErr := GetUidFromCtx(ctx) // 检查返回的用户ID if gotUid != tt.wantUid { t.Errorf("GetUidFromCtx() 返回用户ID = %v, 期望值 %v", gotUid, tt.wantUid) } // 检查错误类型 if tt.wantError == nil && gotErr != nil { t.Errorf("GetUidFromCtx() 返回意外错误 = %v", gotErr) } if tt.wantError != nil && !errors.Is(gotErr, tt.wantError) { t.Errorf("GetUidFromCtx() 错误类型 = %v, 期望错误类型 %v", gotErr, tt.wantError) } }) } } func TestIsNoUserIdError(t *testing.T) { tests := []struct { name string err error expected bool }{ { name: "是未登录错误", err: ErrNoUserIdInCtx, expected: true, }, { name: "包装的未登录错误", err: errors.New("外层错误: " + ErrNoUserIdInCtx.Error()), expected: false, // 直接字符串拼接不会保留错误链 }, { name: "使用fmt.Errorf包装的未登录错误", err: errors.New("外层错误"), expected: false, }, { name: "非未登录错误", err: ErrInvalidUserId, expected: false, }, { name: "nil错误", err: nil, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := IsNoUserIdError(tt.err); got != tt.expected { t.Errorf("IsNoUserIdError() = %v, 期望值 %v", got, tt.expected) } }) } } func TestIsInvalidUserIdError(t *testing.T) { tests := []struct { name string err error expected bool }{ { name: "是无效用户ID错误", err: ErrInvalidUserId, expected: true, }, { name: "包装的无效用户ID错误", err: errors.New("外层错误: " + ErrInvalidUserId.Error()), expected: false, // 直接字符串拼接不会保留错误链 }, { name: "使用fmt.Errorf包装的无效用户ID错误", err: errors.New("外层错误"), expected: false, }, { name: "非无效用户ID错误", err: ErrNoUserIdInCtx, expected: false, }, { name: "nil错误", err: nil, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := IsInvalidUserIdError(tt.err); got != tt.expected { t.Errorf("IsInvalidUserIdError() = %v, 期望值 %v", got, tt.expected) } }) } }