first commit
This commit is contained in:
94
common/jwt/jwtx.go
Normal file
94
common/jwt/jwtx.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package jwtx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
const ExtraKey = "extra"
|
||||
|
||||
type JwtClaims struct {
|
||||
UserId int64 `json:"userId"`
|
||||
AgentId int64 `json:"agentId"`
|
||||
Platform string `json:"platform"`
|
||||
// 用户身份类型:0-临时用户,1-正式用户
|
||||
UserType int64 `json:"userType"`
|
||||
// 是否代理:0-否,1-是
|
||||
IsAgent int64 `json:"isAgent"`
|
||||
}
|
||||
|
||||
// MapToJwtClaims 将 map[string]interface{} 转换为 JwtClaims 结构体
|
||||
func MapToJwtClaims(claimsMap map[string]interface{}) (*JwtClaims, error) {
|
||||
// 使用JSON序列化/反序列化的方式自动转换
|
||||
jsonData, err := json.Marshal(claimsMap)
|
||||
if err != nil {
|
||||
return nil, errors.New("序列化claims失败")
|
||||
}
|
||||
|
||||
var claims JwtClaims
|
||||
if err := json.Unmarshal(jsonData, &claims); err != nil {
|
||||
return nil, errors.New("反序列化claims失败")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// GenerateJwtToken 生成JWT token
|
||||
func GenerateJwtToken(claims JwtClaims, secret string, expire int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
// 将 claims 结构体转换为 map[string]interface{}
|
||||
claimsBytes, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var claimsMap map[string]interface{}
|
||||
if err := json.Unmarshal(claimsBytes, &claimsMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jwtClaims := jwt.MapClaims{
|
||||
"exp": now + expire,
|
||||
"iat": now,
|
||||
"userId": claims.UserId,
|
||||
ExtraKey: claimsMap,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims)
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
func ParseJwtToken(tokenStr string, secret string) (*JwtClaims, error) {
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("invalid JWT")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, errors.New("invalid JWT claims")
|
||||
}
|
||||
|
||||
extraInfo, exists := claims[ExtraKey]
|
||||
if !exists {
|
||||
return nil, errors.New("extra not found in JWT")
|
||||
}
|
||||
|
||||
// 尝试直接断言为 JwtClaims 结构体
|
||||
if jwtClaims, ok := extraInfo.(JwtClaims); ok {
|
||||
return &jwtClaims, nil
|
||||
}
|
||||
|
||||
// 尝试从 map[string]interface{} 中解析
|
||||
if claimsMap, ok := extraInfo.(map[string]interface{}); ok {
|
||||
return MapToJwtClaims(claimsMap)
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported extra type in JWT")
|
||||
}
|
||||
393
common/jwt/jwtx_test.go
Normal file
393
common/jwt/jwtx_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package jwtx
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
func TestGenerateJwtToken(t *testing.T) {
|
||||
// 测试数据
|
||||
testClaims := JwtClaims{
|
||||
UserId: 1,
|
||||
AgentId: 0,
|
||||
Platform: "wxh5",
|
||||
UserType: 0,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA"
|
||||
testExpire := int64(2592000) // 1小时
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims JwtClaims
|
||||
secret string
|
||||
expire int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常生成token",
|
||||
claims: testClaims,
|
||||
secret: testSecret,
|
||||
expire: testExpire,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "不同用户数据",
|
||||
claims: JwtClaims{
|
||||
UserId: 99999,
|
||||
AgentId: 11111,
|
||||
Platform: "mobile",
|
||||
UserType: 0,
|
||||
IsAgent: 1,
|
||||
},
|
||||
secret: testSecret,
|
||||
expire: testExpire,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空密钥",
|
||||
claims: testClaims,
|
||||
secret: "",
|
||||
expire: testExpire,
|
||||
wantErr: false, // 空密钥不会导致生成失败,但验证时会失败
|
||||
},
|
||||
{
|
||||
name: "零过期时间",
|
||||
claims: testClaims,
|
||||
secret: testSecret,
|
||||
expire: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "负数过期时间",
|
||||
claims: testClaims,
|
||||
secret: testSecret,
|
||||
expire: -3600,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := GenerateJwtToken(tt.claims, tt.secret, tt.expire)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateJwtToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
// 验证token不为空
|
||||
if token == "" {
|
||||
t.Error("GenerateJwtToken() 返回的token为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token格式(JWT token应该包含两个点分隔符)
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("GenerateJwtToken() 返回的token格式不正确,期望3部分,实际%d部分", len(parts))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token可以被解析(不验证签名,只验证格式)
|
||||
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(tt.secret), nil
|
||||
})
|
||||
|
||||
if err == nil && parsedToken != nil {
|
||||
// 验证claims是否正确设置
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
|
||||
// 验证userId
|
||||
if userId, exists := claims["userId"]; exists {
|
||||
if int64(userId.(float64)) != tt.claims.UserId {
|
||||
t.Errorf("token中的userId不匹配,期望%d,实际%v", tt.claims.UserId, userId)
|
||||
}
|
||||
} else {
|
||||
t.Error("token中缺少userId字段")
|
||||
}
|
||||
|
||||
// 验证extra字段存在
|
||||
if _, exists := claims[ExtraKey]; !exists {
|
||||
t.Error("token中缺少extra字段")
|
||||
}
|
||||
|
||||
// 验证exp字段
|
||||
if exp, exists := claims["exp"]; exists {
|
||||
expTime := int64(exp.(float64))
|
||||
now := time.Now().Unix()
|
||||
expectedExp := now + tt.expire
|
||||
// 允许5秒的时间差异
|
||||
if expTime < expectedExp-5 || expTime > expectedExp+5 {
|
||||
t.Errorf("token过期时间不正确,期望约%d,实际%d", expectedExp, expTime)
|
||||
}
|
||||
} else {
|
||||
t.Error("token中缺少exp字段")
|
||||
}
|
||||
|
||||
// 验证iat字段
|
||||
if _, exists := claims["iat"]; !exists {
|
||||
t.Error("token中缺少iat字段")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("生成的token: %s", token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateJwtTokenAndParse(t *testing.T) {
|
||||
// 测试生成token后能够正确解析
|
||||
testClaims := JwtClaims{
|
||||
UserId: 12345,
|
||||
AgentId: 67890,
|
||||
Platform: "web",
|
||||
UserType: 1,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "test-secret-key"
|
||||
testExpire := int64(3600)
|
||||
|
||||
// 生成token
|
||||
token, err := GenerateJwtToken(testClaims, testSecret, testExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJwtToken() failed: %v", err)
|
||||
}
|
||||
|
||||
// 解析token
|
||||
parsedClaims, err := ParseJwtToken(token, testSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseJwtToken() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证解析出的claims与原始claims一致
|
||||
if parsedClaims.UserId != testClaims.UserId {
|
||||
t.Errorf("UserId不匹配,期望%d,实际%d", testClaims.UserId, parsedClaims.UserId)
|
||||
}
|
||||
if parsedClaims.AgentId != testClaims.AgentId {
|
||||
t.Errorf("AgentId不匹配,期望%d,实际%d", testClaims.AgentId, parsedClaims.AgentId)
|
||||
}
|
||||
if parsedClaims.Platform != testClaims.Platform {
|
||||
t.Errorf("Platform不匹配,期望%s,实际%s", testClaims.Platform, parsedClaims.Platform)
|
||||
}
|
||||
if parsedClaims.UserType != testClaims.UserType {
|
||||
t.Errorf("UserType不匹配,期望%d,实际%d", testClaims.UserType, parsedClaims.UserType)
|
||||
}
|
||||
if parsedClaims.IsAgent != testClaims.IsAgent {
|
||||
t.Errorf("IsAgent不匹配,期望%d,实际%d", testClaims.IsAgent, parsedClaims.IsAgent)
|
||||
}
|
||||
|
||||
t.Logf("测试通过: 生成token并成功解析,claims数据一致")
|
||||
}
|
||||
|
||||
func BenchmarkGenerateJwtToken(t *testing.B) {
|
||||
// 性能测试
|
||||
testClaims := JwtClaims{
|
||||
UserId: 12345,
|
||||
AgentId: 67890,
|
||||
Platform: "web",
|
||||
UserType: 1,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "test-secret-key"
|
||||
testExpire := int64(3600)
|
||||
|
||||
t.ResetTimer()
|
||||
for i := 0; i < t.N; i++ {
|
||||
_, err := GenerateJwtToken(testClaims, testSecret, testExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJwtToken() failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJwtToken(t *testing.T) {
|
||||
// 使用你修改的测试数据
|
||||
testClaims := JwtClaims{
|
||||
UserId: 6,
|
||||
AgentId: 0,
|
||||
Platform: "wxh5",
|
||||
UserType: 0,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA"
|
||||
testExpire := int64(2592000) // 30天
|
||||
|
||||
// 先生成一个token用于测试
|
||||
token, err := GenerateJwtToken(testClaims, testSecret, testExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("生成token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("生成的测试token: %s", token)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
secret string
|
||||
wantErr bool
|
||||
wantClaims *JwtClaims
|
||||
}{
|
||||
{
|
||||
name: "正常解析token",
|
||||
token: token,
|
||||
secret: testSecret,
|
||||
wantErr: false,
|
||||
wantClaims: &testClaims,
|
||||
},
|
||||
{
|
||||
name: "错误的密钥",
|
||||
token: token,
|
||||
secret: "wrong-secret",
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "空token",
|
||||
token: "",
|
||||
secret: testSecret,
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "无效token格式",
|
||||
token: "invalid.token.format",
|
||||
secret: testSecret,
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "缺少点分隔符的token",
|
||||
token: "invalidtoken",
|
||||
secret: testSecret,
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "自定义token",
|
||||
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3NTI5MDA5MTQsImV4dHJhIjp7ImFnZW50SWQiOjAsImlzQWdlbnQiOjAsInBsYXRmb3JtIjoid3hoNSIsInVzZXJJZCI6NiwidXNlclR5cGUiOjF9LCJpYXQiOjE3NTAzMDg5MTQsInVzZXJJZCI6Nn0.GPKgLOaALOIa1ft7Hipuo4YKFf5guYt0rz2MCDCSdCQ",
|
||||
secret: testSecret,
|
||||
wantErr: false,
|
||||
wantClaims: &testClaims,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims, err := ParseJwtToken(tt.token, tt.secret)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseJwtToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && tt.wantClaims != nil {
|
||||
if claims == nil {
|
||||
t.Error("ParseJwtToken() 返回的claims为nil")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证各个字段
|
||||
if claims.UserId != tt.wantClaims.UserId {
|
||||
t.Errorf("UserId不匹配,期望%d,实际%d", tt.wantClaims.UserId, claims.UserId)
|
||||
}
|
||||
if claims.AgentId != tt.wantClaims.AgentId {
|
||||
t.Errorf("AgentId不匹配,期望%d,实际%d", tt.wantClaims.AgentId, claims.AgentId)
|
||||
}
|
||||
if claims.Platform != tt.wantClaims.Platform {
|
||||
t.Errorf("Platform不匹配,期望%s,实际%s", tt.wantClaims.Platform, claims.Platform)
|
||||
}
|
||||
if claims.UserType != tt.wantClaims.UserType {
|
||||
t.Errorf("UserType不匹配,期望%d,实际%d", tt.wantClaims.UserType, claims.UserType)
|
||||
}
|
||||
if claims.IsAgent != tt.wantClaims.IsAgent {
|
||||
t.Errorf("IsAgent不匹配,期望%d,实际%d", tt.wantClaims.IsAgent, claims.IsAgent)
|
||||
}
|
||||
|
||||
t.Logf("解析成功的claims: UserId=%d, AgentId=%d, Platform=%s, UserType=%d, IsAgent=%d",
|
||||
claims.UserId, claims.AgentId, claims.Platform, claims.UserType, claims.IsAgent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseCustomJwtToken 测试解析自定义token - 你可以在这里传入你自己的token
|
||||
func TestParseCustomJwtToken(t *testing.T) {
|
||||
// 在这里修改你想要测试的token和secret
|
||||
customToken := "" // 在这里粘贴你的token
|
||||
customSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA" // 你的密钥
|
||||
|
||||
// 如果没有提供自定义token,跳过测试
|
||||
if customToken == "" {
|
||||
t.Skip("跳过自定义token测试,请在代码中设置customToken值")
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("解析自定义token: %s", customToken)
|
||||
|
||||
claims, err := ParseJwtToken(customToken, customSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("解析自定义token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("解析结果:")
|
||||
t.Logf(" UserId: %d", claims.UserId)
|
||||
t.Logf(" AgentId: %d", claims.AgentId)
|
||||
t.Logf(" Platform: %s", claims.Platform)
|
||||
t.Logf(" UserType: %d", claims.UserType)
|
||||
t.Logf(" IsAgent: %d", claims.IsAgent)
|
||||
}
|
||||
|
||||
// TestGenerateAndParseWithRealData 生成一个真实的token并解析
|
||||
func TestGenerateAndParseWithRealData(t *testing.T) {
|
||||
// 使用真实数据生成token
|
||||
realClaims := JwtClaims{
|
||||
UserId: 1,
|
||||
AgentId: 0,
|
||||
Platform: "wxh5",
|
||||
UserType: 0,
|
||||
IsAgent: 0,
|
||||
}
|
||||
realSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA"
|
||||
realExpire := int64(2592000) // 30天
|
||||
|
||||
// 生成token
|
||||
token, err := GenerateJwtToken(realClaims, realSecret, realExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("生成token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("=== 生成的完整token ===")
|
||||
t.Logf("Token: %s", token)
|
||||
t.Logf("========================")
|
||||
|
||||
// 解析token
|
||||
parsedClaims, err := ParseJwtToken(token, realSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("解析token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("=== 解析结果 ===")
|
||||
t.Logf("UserId: %d", parsedClaims.UserId)
|
||||
t.Logf("AgentId: %d", parsedClaims.AgentId)
|
||||
t.Logf("Platform: %s", parsedClaims.Platform)
|
||||
t.Logf("UserType: %d", parsedClaims.UserType)
|
||||
t.Logf("IsAgent: %d", parsedClaims.IsAgent)
|
||||
t.Logf("================")
|
||||
|
||||
// 验证数据一致性
|
||||
if parsedClaims.UserId != realClaims.UserId ||
|
||||
parsedClaims.AgentId != realClaims.AgentId ||
|
||||
parsedClaims.Platform != realClaims.Platform ||
|
||||
parsedClaims.UserType != realClaims.UserType ||
|
||||
parsedClaims.IsAgent != realClaims.IsAgent {
|
||||
t.Error("解析出的claims与原始数据不一致")
|
||||
} else {
|
||||
t.Log("✅ 数据一致性验证通过")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user