fix wxh5 jwt userid type
This commit is contained in:
		| @@ -2,6 +2,8 @@ package middleware | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"fmt" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"tydata-server/app/user/cmd/api/internal/config" | ||||
| @@ -47,8 +49,10 @@ func (m *AuthInterceptorMiddleware) Handle(next http.HandlerFunc) http.HandlerFu | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// 将用户ID添加到请求上下文 | ||||
| 		ctx := context.WithValue(r.Context(), ctxdata.CtxKeyJwtUserId, userId) | ||||
| 		// 将用户ID转换为json.Number类型后添加到请求上下文 | ||||
| 		userIdStr := fmt.Sprintf("%d", userId) | ||||
| 		userIdJsonNum := json.Number(userIdStr) | ||||
| 		ctx := context.WithValue(r.Context(), ctxdata.CtxKeyJwtUserId, userIdJsonNum) | ||||
|  | ||||
| 		// 使用新的上下文继续处理请求 | ||||
| 		next(w, r.WithContext(ctx)) | ||||
|   | ||||
| @@ -23,19 +23,28 @@ func GetUidFromCtx(ctx context.Context) (int64, error) { | ||||
| 		return 0, ErrNoUserIdInCtx | ||||
| 	} | ||||
|  | ||||
| 	// 尝试转换为 json.Number | ||||
| 	jsonUid, ok := value.(json.Number) | ||||
| 	if !ok { | ||||
| 		return 0, fmt.Errorf("%w: 期望类型 json.Number, 实际类型 %T", ErrInvalidUserId, value) | ||||
| 	// 根据值的类型进行不同处理 | ||||
| 	switch v := value.(type) { | ||||
| 	case json.Number: | ||||
| 		// 如果是 json.Number 类型,转换为 int64 | ||||
| 		uid, err := v.Int64() | ||||
| 		if err != nil { | ||||
| 			return 0, fmt.Errorf("%w: %v", ErrInvalidUserId, err) | ||||
| 		} | ||||
| 		return uid, nil | ||||
| 	case int64: | ||||
| 		// 如果已经是 int64 类型,直接返回 | ||||
| 		return v, nil | ||||
| 	case float64: | ||||
| 		// 有些JSON解析器可能会将数字解析为float64 | ||||
| 		return int64(v), nil | ||||
| 	case int: | ||||
| 		// 处理int类型 | ||||
| 		return int64(v), nil | ||||
| 	default: | ||||
| 		// 其他类型都视为无效 | ||||
| 		return 0, fmt.Errorf("%w: 期望类型 json.Number 或 int64, 实际类型 %T", ErrInvalidUserId, value) | ||||
| 	} | ||||
|  | ||||
| 	// 转换为 int64 | ||||
| 	uid, err := jsonUid.Int64() | ||||
| 	if err != nil { | ||||
| 		return 0, fmt.Errorf("%w: %v", ErrInvalidUserId, err) | ||||
| 	} | ||||
|  | ||||
| 	return uid, nil | ||||
| } | ||||
|  | ||||
| // IsNoUserIdError 判断是否是未登录错误 | ||||
|   | ||||
							
								
								
									
										179
									
								
								common/ctxdata/ctxData_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										179
									
								
								common/ctxdata/ctxData_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,179 @@ | ||||
| package ctxdata | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestGetUidFromCtx(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name      string | ||||
| 		ctxSetup  func() context.Context | ||||
| 		wantUid   int64 | ||||
| 		wantError error | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name: "正常情况_有效用户ID_json.Number", | ||||
| 			ctxSetup: func() context.Context { | ||||
| 				return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("12345")) | ||||
| 			}, | ||||
| 			wantUid:   12345, | ||||
| 			wantError: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "正常情况_有效用户ID_int64", | ||||
| 			ctxSetup: func() context.Context { | ||||
| 				return context.WithValue(context.Background(), CtxKeyJwtUserId, int64(12345)) | ||||
| 			}, | ||||
| 			wantUid:   12345, | ||||
| 			wantError: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "正常情况_有效用户ID_int", | ||||
| 			ctxSetup: func() context.Context { | ||||
| 				return context.WithValue(context.Background(), CtxKeyJwtUserId, 12345) | ||||
| 			}, | ||||
| 			wantUid:   12345, | ||||
| 			wantError: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "正常情况_有效用户ID_float64", | ||||
| 			ctxSetup: func() context.Context { | ||||
| 				return context.WithValue(context.Background(), CtxKeyJwtUserId, float64(12345)) | ||||
| 			}, | ||||
| 			wantUid:   12345, | ||||
| 			wantError: nil, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "异常情况_上下文中无用户ID", | ||||
| 			ctxSetup: func() context.Context { | ||||
| 				return context.Background() | ||||
| 			}, | ||||
| 			wantUid:   0, | ||||
| 			wantError: ErrNoUserIdInCtx, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "异常情况_用户ID类型错误", | ||||
| 			ctxSetup: func() context.Context { | ||||
| 				return context.WithValue(context.Background(), CtxKeyJwtUserId, "非数字类型") | ||||
| 			}, | ||||
| 			wantUid:   0, | ||||
| 			wantError: ErrInvalidUserId, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name: "异常情况_用户ID无法转换为int64", | ||||
| 			ctxSetup: func() context.Context { | ||||
| 				return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("非数字内容")) | ||||
| 			}, | ||||
| 			wantUid:   0, | ||||
| 			wantError: ErrInvalidUserId, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			ctx := tt.ctxSetup() | ||||
| 			gotUid, gotErr := GetUidFromCtx(ctx) | ||||
|  | ||||
| 			// 检查返回的用户ID | ||||
| 			if gotUid != tt.wantUid { | ||||
| 				t.Errorf("GetUidFromCtx() 返回用户ID = %v, 期望值 %v", gotUid, tt.wantUid) | ||||
| 			} | ||||
|  | ||||
| 			// 检查错误类型 | ||||
| 			if tt.wantError == nil && gotErr != nil { | ||||
| 				t.Errorf("GetUidFromCtx() 返回意外错误 = %v", gotErr) | ||||
| 			} | ||||
|  | ||||
| 			if tt.wantError != nil && !errors.Is(gotErr, tt.wantError) { | ||||
| 				t.Errorf("GetUidFromCtx() 错误类型 = %v, 期望错误类型 %v", gotErr, tt.wantError) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIsNoUserIdError(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name     string | ||||
| 		err      error | ||||
| 		expected bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:     "是未登录错误", | ||||
| 			err:      ErrNoUserIdInCtx, | ||||
| 			expected: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "包装的未登录错误", | ||||
| 			err:      errors.New("外层错误: " + ErrNoUserIdInCtx.Error()), | ||||
| 			expected: false, // 直接字符串拼接不会保留错误链 | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "使用fmt.Errorf包装的未登录错误", | ||||
| 			err:      errors.New("外层错误"), | ||||
| 			expected: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "非未登录错误", | ||||
| 			err:      ErrInvalidUserId, | ||||
| 			expected: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "nil错误", | ||||
| 			err:      nil, | ||||
| 			expected: false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			if got := IsNoUserIdError(tt.err); got != tt.expected { | ||||
| 				t.Errorf("IsNoUserIdError() = %v, 期望值 %v", got, tt.expected) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestIsInvalidUserIdError(t *testing.T) { | ||||
| 	tests := []struct { | ||||
| 		name     string | ||||
| 		err      error | ||||
| 		expected bool | ||||
| 	}{ | ||||
| 		{ | ||||
| 			name:     "是无效用户ID错误", | ||||
| 			err:      ErrInvalidUserId, | ||||
| 			expected: true, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "包装的无效用户ID错误", | ||||
| 			err:      errors.New("外层错误: " + ErrInvalidUserId.Error()), | ||||
| 			expected: false, // 直接字符串拼接不会保留错误链 | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "使用fmt.Errorf包装的无效用户ID错误", | ||||
| 			err:      errors.New("外层错误"), | ||||
| 			expected: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "非无效用户ID错误", | ||||
| 			err:      ErrNoUserIdInCtx, | ||||
| 			expected: false, | ||||
| 		}, | ||||
| 		{ | ||||
| 			name:     "nil错误", | ||||
| 			err:      nil, | ||||
| 			expected: false, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tt := range tests { | ||||
| 		t.Run(tt.name, func(t *testing.T) { | ||||
| 			if got := IsInvalidUserIdError(tt.err); got != tt.expected { | ||||
| 				t.Errorf("IsInvalidUserIdError() = %v, 期望值 %v", got, tt.expected) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
		Reference in New Issue
	
	Block a user