180 lines
4.2 KiB
Go
180 lines
4.2 KiB
Go
|
package ctxdata
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"encoding/json"
|
||
|
"errors"
|
||
|
"testing"
|
||
|
)
|
||
|
|
||
|
func TestGetUidFromCtx(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
ctxSetup func() context.Context
|
||
|
wantUid int64
|
||
|
wantError error
|
||
|
}{
|
||
|
{
|
||
|
name: "正常情况_有效用户ID_json.Number",
|
||
|
ctxSetup: func() context.Context {
|
||
|
return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("12345"))
|
||
|
},
|
||
|
wantUid: 12345,
|
||
|
wantError: nil,
|
||
|
},
|
||
|
{
|
||
|
name: "正常情况_有效用户ID_int64",
|
||
|
ctxSetup: func() context.Context {
|
||
|
return context.WithValue(context.Background(), CtxKeyJwtUserId, int64(12345))
|
||
|
},
|
||
|
wantUid: 12345,
|
||
|
wantError: nil,
|
||
|
},
|
||
|
{
|
||
|
name: "正常情况_有效用户ID_int",
|
||
|
ctxSetup: func() context.Context {
|
||
|
return context.WithValue(context.Background(), CtxKeyJwtUserId, 12345)
|
||
|
},
|
||
|
wantUid: 12345,
|
||
|
wantError: nil,
|
||
|
},
|
||
|
{
|
||
|
name: "正常情况_有效用户ID_float64",
|
||
|
ctxSetup: func() context.Context {
|
||
|
return context.WithValue(context.Background(), CtxKeyJwtUserId, float64(12345))
|
||
|
},
|
||
|
wantUid: 12345,
|
||
|
wantError: nil,
|
||
|
},
|
||
|
{
|
||
|
name: "异常情况_上下文中无用户ID",
|
||
|
ctxSetup: func() context.Context {
|
||
|
return context.Background()
|
||
|
},
|
||
|
wantUid: 0,
|
||
|
wantError: ErrNoUserIdInCtx,
|
||
|
},
|
||
|
{
|
||
|
name: "异常情况_用户ID类型错误",
|
||
|
ctxSetup: func() context.Context {
|
||
|
return context.WithValue(context.Background(), CtxKeyJwtUserId, "非数字类型")
|
||
|
},
|
||
|
wantUid: 0,
|
||
|
wantError: ErrInvalidUserId,
|
||
|
},
|
||
|
{
|
||
|
name: "异常情况_用户ID无法转换为int64",
|
||
|
ctxSetup: func() context.Context {
|
||
|
return context.WithValue(context.Background(), CtxKeyJwtUserId, json.Number("非数字内容"))
|
||
|
},
|
||
|
wantUid: 0,
|
||
|
wantError: ErrInvalidUserId,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
ctx := tt.ctxSetup()
|
||
|
gotUid, gotErr := GetUidFromCtx(ctx)
|
||
|
|
||
|
// 检查返回的用户ID
|
||
|
if gotUid != tt.wantUid {
|
||
|
t.Errorf("GetUidFromCtx() 返回用户ID = %v, 期望值 %v", gotUid, tt.wantUid)
|
||
|
}
|
||
|
|
||
|
// 检查错误类型
|
||
|
if tt.wantError == nil && gotErr != nil {
|
||
|
t.Errorf("GetUidFromCtx() 返回意外错误 = %v", gotErr)
|
||
|
}
|
||
|
|
||
|
if tt.wantError != nil && !errors.Is(gotErr, tt.wantError) {
|
||
|
t.Errorf("GetUidFromCtx() 错误类型 = %v, 期望错误类型 %v", gotErr, tt.wantError)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestIsNoUserIdError(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
err error
|
||
|
expected bool
|
||
|
}{
|
||
|
{
|
||
|
name: "是未登录错误",
|
||
|
err: ErrNoUserIdInCtx,
|
||
|
expected: true,
|
||
|
},
|
||
|
{
|
||
|
name: "包装的未登录错误",
|
||
|
err: errors.New("外层错误: " + ErrNoUserIdInCtx.Error()),
|
||
|
expected: false, // 直接字符串拼接不会保留错误链
|
||
|
},
|
||
|
{
|
||
|
name: "使用fmt.Errorf包装的未登录错误",
|
||
|
err: errors.New("外层错误"),
|
||
|
expected: false,
|
||
|
},
|
||
|
{
|
||
|
name: "非未登录错误",
|
||
|
err: ErrInvalidUserId,
|
||
|
expected: false,
|
||
|
},
|
||
|
{
|
||
|
name: "nil错误",
|
||
|
err: nil,
|
||
|
expected: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
if got := IsNoUserIdError(tt.err); got != tt.expected {
|
||
|
t.Errorf("IsNoUserIdError() = %v, 期望值 %v", got, tt.expected)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func TestIsInvalidUserIdError(t *testing.T) {
|
||
|
tests := []struct {
|
||
|
name string
|
||
|
err error
|
||
|
expected bool
|
||
|
}{
|
||
|
{
|
||
|
name: "是无效用户ID错误",
|
||
|
err: ErrInvalidUserId,
|
||
|
expected: true,
|
||
|
},
|
||
|
{
|
||
|
name: "包装的无效用户ID错误",
|
||
|
err: errors.New("外层错误: " + ErrInvalidUserId.Error()),
|
||
|
expected: false, // 直接字符串拼接不会保留错误链
|
||
|
},
|
||
|
{
|
||
|
name: "使用fmt.Errorf包装的无效用户ID错误",
|
||
|
err: errors.New("外层错误"),
|
||
|
expected: false,
|
||
|
},
|
||
|
{
|
||
|
name: "非无效用户ID错误",
|
||
|
err: ErrNoUserIdInCtx,
|
||
|
expected: false,
|
||
|
},
|
||
|
{
|
||
|
name: "nil错误",
|
||
|
err: nil,
|
||
|
expected: false,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
for _, tt := range tests {
|
||
|
t.Run(tt.name, func(t *testing.T) {
|
||
|
if got := IsInvalidUserIdError(tt.err); got != tt.expected {
|
||
|
t.Errorf("IsInvalidUserIdError() = %v, 期望值 %v", got, tt.expected)
|
||
|
}
|
||
|
})
|
||
|
}
|
||
|
}
|