105 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			105 lines
		
	
	
		
			2.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package ctxdata
 | |
| 
 | |
| import (
 | |
| 	"context"
 | |
| 	"encoding/json"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"hm-server/app/main/model"
 | |
| 	jwtx "hm-server/common/jwt"
 | |
| )
 | |
| 
 | |
| const CtxKeyJwtUserId = "userId"
 | |
| 
 | |
| // 定义错误类型
 | |
| var (
 | |
| 	ErrNoInCtx       = errors.New("上下文中没有相关数据")
 | |
| 	ErrInvalidUserId = errors.New("用户ID格式无效") // 数据异常
 | |
| )
 | |
| 
 | |
| // GetUidFromCtx 从 context 中获取用户 ID
 | |
| func GetUidFromCtx(ctx context.Context) (int64, error) {
 | |
| 	// 尝试从上下文中获取 jwtUserId
 | |
| 	value := ctx.Value(CtxKeyJwtUserId)
 | |
| 	if value == nil {
 | |
| 		claims, err := GetClaimsFromCtx(ctx)
 | |
| 		if err != nil {
 | |
| 			return 0, err
 | |
| 		}
 | |
| 		return claims.UserId, nil
 | |
| 	}
 | |
| 
 | |
| 	// 根据值的类型进行不同处理
 | |
| 	switch v := value.(type) {
 | |
| 	case json.Number:
 | |
| 		// 如果是 json.Number 类型,转换为 int64
 | |
| 		uid, err := v.Int64()
 | |
| 		if err != nil {
 | |
| 			return 0, fmt.Errorf("%w: %v", ErrInvalidUserId, err)
 | |
| 		}
 | |
| 		return uid, nil
 | |
| 	case int64:
 | |
| 		// 如果已经是 int64 类型,直接返回
 | |
| 		return v, nil
 | |
| 	case float64:
 | |
| 		// 有些JSON解析器可能会将数字解析为float64
 | |
| 		return int64(v), nil
 | |
| 	case int:
 | |
| 		// 处理int类型
 | |
| 		return int64(v), nil
 | |
| 	default:
 | |
| 		// 其他类型都视为无效
 | |
| 		return 0, fmt.Errorf("%w: 期望类型 json.Number 或 int64, 实际类型 %T", ErrInvalidUserId, value)
 | |
| 	}
 | |
| }
 | |
| 
 | |
| func GetClaimsFromCtx(ctx context.Context) (*jwtx.JwtClaims, error) {
 | |
| 	value := ctx.Value(jwtx.ExtraKey)
 | |
| 	if value == nil {
 | |
| 		return nil, ErrNoInCtx
 | |
| 	}
 | |
| 
 | |
| 	// 首先尝试直接断言为 *jwtx.JwtClaims
 | |
| 	if claims, ok := value.(*jwtx.JwtClaims); ok {
 | |
| 		return claims, nil
 | |
| 	}
 | |
| 
 | |
| 	// 如果直接断言失败,尝试从 map[string]interface{} 中解析
 | |
| 	if claimsMap, ok := value.(map[string]interface{}); ok {
 | |
| 		return jwtx.MapToJwtClaims(claimsMap)
 | |
| 	}
 | |
| 
 | |
| 	return nil, ErrNoInCtx
 | |
| }
 | |
| 
 | |
| // IsNoUserIdError 判断是否是未登录错误
 | |
| func IsNoUserIdError(err error) bool {
 | |
| 	return errors.Is(err, ErrNoInCtx)
 | |
| }
 | |
| 
 | |
| // IsInvalidUserIdError 判断是否是用户ID格式错误
 | |
| func IsInvalidUserIdError(err error) bool {
 | |
| 	return errors.Is(err, ErrInvalidUserId)
 | |
| }
 | |
| 
 | |
| // GetPlatformFromCtx 从 context 中获取平台
 | |
| func GetPlatformFromCtx(ctx context.Context) (string, error) {
 | |
| 	platform, platformOk := ctx.Value("platform").(string)
 | |
| 	if !platformOk {
 | |
| 		return "", fmt.Errorf("平台不存在: %s", platform)
 | |
| 	}
 | |
| 
 | |
| 	switch platform {
 | |
| 	case model.PlatformWxMini:
 | |
| 		return model.PlatformWxMini, nil
 | |
| 	case model.PlatformWxH5:
 | |
| 		return model.PlatformWxH5, nil
 | |
| 	case model.PlatformApp:
 | |
| 		return model.PlatformApp, nil
 | |
| 	case model.PlatformH5:
 | |
| 		return model.PlatformH5, nil
 | |
| 	default:
 | |
| 		return "", fmt.Errorf("不支持的支付平台: %s", platform)
 | |
| 	}
 | |
| }
 |