89 lines
2.2 KiB
Go
89 lines
2.2 KiB
Go
|
|
package ctxdata
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"context"
|
|||
|
|
"qnc-server/app/main/model"
|
|||
|
|
jwtx "qnc-server/common/jwt"
|
|||
|
|
"errors"
|
|||
|
|
"fmt"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
const CtxKeyJwtUserId = "userId"
|
|||
|
|
|
|||
|
|
// 定义错误类型
|
|||
|
|
var (
|
|||
|
|
ErrNoInCtx = errors.New("上下文中没有相关数据")
|
|||
|
|
ErrInvalidUserId = errors.New("用户ID格式无效") // 数据异常
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// GetUidFromCtx 从 context 中获取用户 ID(字符串)
|
|||
|
|
func GetUidFromCtx(ctx context.Context) (string, error) {
|
|||
|
|
// 尝试从上下文中获取 jwtUserId
|
|||
|
|
value := ctx.Value(CtxKeyJwtUserId)
|
|||
|
|
if value == nil {
|
|||
|
|
claims, err := GetClaimsFromCtx(ctx)
|
|||
|
|
if err != nil {
|
|||
|
|
return "", err
|
|||
|
|
}
|
|||
|
|
return claims.UserId, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 根据值的类型进行不同处理
|
|||
|
|
switch v := value.(type) {
|
|||
|
|
case string:
|
|||
|
|
return v, nil
|
|||
|
|
default:
|
|||
|
|
return "", fmt.Errorf("%w: 期望类型 string, 实际类型 %T", ErrInvalidUserId, value)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func GetClaimsFromCtx(ctx context.Context) (*jwtx.JwtClaims, error) {
|
|||
|
|
value := ctx.Value(jwtx.ExtraKey)
|
|||
|
|
if value == nil {
|
|||
|
|
return nil, ErrNoInCtx
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 首先尝试直接断言为 *jwtx.JwtClaims
|
|||
|
|
if claims, ok := value.(*jwtx.JwtClaims); ok {
|
|||
|
|
return claims, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 如果直接断言失败,尝试从 map[string]interface{} 中解析
|
|||
|
|
if claimsMap, ok := value.(map[string]interface{}); ok {
|
|||
|
|
return jwtx.MapToJwtClaims(claimsMap)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil, ErrNoInCtx
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IsNoUserIdError 判断是否是未登录错误
|
|||
|
|
func IsNoUserIdError(err error) bool {
|
|||
|
|
return errors.Is(err, ErrNoInCtx)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IsInvalidUserIdError 判断是否是用户ID格式错误
|
|||
|
|
func IsInvalidUserIdError(err error) bool {
|
|||
|
|
return errors.Is(err, ErrInvalidUserId)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetPlatformFromCtx 从 context 中获取平台
|
|||
|
|
func GetPlatformFromCtx(ctx context.Context) (string, error) {
|
|||
|
|
platform, platformOk := ctx.Value("platform").(string)
|
|||
|
|
if !platformOk {
|
|||
|
|
return "", fmt.Errorf("平台不存在: %s", platform)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
switch platform {
|
|||
|
|
case model.PlatformWxMini:
|
|||
|
|
return model.PlatformWxMini, nil
|
|||
|
|
case model.PlatformWxH5:
|
|||
|
|
return model.PlatformWxH5, nil
|
|||
|
|
case model.PlatformApp:
|
|||
|
|
return model.PlatformApp, nil
|
|||
|
|
case model.PlatformH5:
|
|||
|
|
return model.PlatformH5, nil
|
|||
|
|
default:
|
|||
|
|
return "", fmt.Errorf("不支持的支付平台: %s", platform)
|
|||
|
|
}
|
|||
|
|
}
|