first commit
This commit is contained in:
104
common/ctxdata/ctxData.go
Normal file
104
common/ctxdata/ctxData.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package ctxdata
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hm-server/app/main/model"
|
||||
jwtx "hm-server/common/jwt"
|
||||
)
|
||||
|
||||
const CtxKeyJwtUserId = "userId"
|
||||
|
||||
// 定义错误类型
|
||||
var (
|
||||
ErrNoInCtx = errors.New("上下文中没有相关数据")
|
||||
ErrInvalidUserId = errors.New("用户ID格式无效") // 数据异常
|
||||
)
|
||||
|
||||
// GetUidFromCtx 从 context 中获取用户 ID
|
||||
func GetUidFromCtx(ctx context.Context) (int64, error) {
|
||||
// 尝试从上下文中获取 jwtUserId
|
||||
value := ctx.Value(CtxKeyJwtUserId)
|
||||
if value == nil {
|
||||
claims, err := GetClaimsFromCtx(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return claims.UserId, nil
|
||||
}
|
||||
|
||||
// 根据值的类型进行不同处理
|
||||
switch v := value.(type) {
|
||||
case json.Number:
|
||||
// 如果是 json.Number 类型,转换为 int64
|
||||
uid, err := v.Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("%w: %v", ErrInvalidUserId, err)
|
||||
}
|
||||
return uid, nil
|
||||
case int64:
|
||||
// 如果已经是 int64 类型,直接返回
|
||||
return v, nil
|
||||
case float64:
|
||||
// 有些JSON解析器可能会将数字解析为float64
|
||||
return int64(v), nil
|
||||
case int:
|
||||
// 处理int类型
|
||||
return int64(v), nil
|
||||
default:
|
||||
// 其他类型都视为无效
|
||||
return 0, fmt.Errorf("%w: 期望类型 json.Number 或 int64, 实际类型 %T", ErrInvalidUserId, value)
|
||||
}
|
||||
}
|
||||
|
||||
func GetClaimsFromCtx(ctx context.Context) (*jwtx.JwtClaims, error) {
|
||||
value := ctx.Value(jwtx.ExtraKey)
|
||||
if value == nil {
|
||||
return nil, ErrNoInCtx
|
||||
}
|
||||
|
||||
// 首先尝试直接断言为 *jwtx.JwtClaims
|
||||
if claims, ok := value.(*jwtx.JwtClaims); ok {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// 如果直接断言失败,尝试从 map[string]interface{} 中解析
|
||||
if claimsMap, ok := value.(map[string]interface{}); ok {
|
||||
return jwtx.MapToJwtClaims(claimsMap)
|
||||
}
|
||||
|
||||
return nil, ErrNoInCtx
|
||||
}
|
||||
|
||||
// IsNoUserIdError 判断是否是未登录错误
|
||||
func IsNoUserIdError(err error) bool {
|
||||
return errors.Is(err, ErrNoInCtx)
|
||||
}
|
||||
|
||||
// IsInvalidUserIdError 判断是否是用户ID格式错误
|
||||
func IsInvalidUserIdError(err error) bool {
|
||||
return errors.Is(err, ErrInvalidUserId)
|
||||
}
|
||||
|
||||
// GetPlatformFromCtx 从 context 中获取平台
|
||||
func GetPlatformFromCtx(ctx context.Context) (string, error) {
|
||||
platform, platformOk := ctx.Value("platform").(string)
|
||||
if !platformOk {
|
||||
return "", fmt.Errorf("平台不存在: %s", platform)
|
||||
}
|
||||
|
||||
switch platform {
|
||||
case model.PlatformWxMini:
|
||||
return model.PlatformWxMini, nil
|
||||
case model.PlatformWxH5:
|
||||
return model.PlatformWxH5, nil
|
||||
case model.PlatformApp:
|
||||
return model.PlatformApp, nil
|
||||
case model.PlatformH5:
|
||||
return model.PlatformH5, nil
|
||||
default:
|
||||
return "", fmt.Errorf("不支持的支付平台: %s", platform)
|
||||
}
|
||||
}
|
||||
14
common/globalkey/constantKey.go
Normal file
14
common/globalkey/constantKey.go
Normal file
@@ -0,0 +1,14 @@
|
||||
package globalkey
|
||||
|
||||
/**
|
||||
global constant key
|
||||
*/
|
||||
|
||||
//软删除
|
||||
var DelStateNo int64 = 0 //未删除
|
||||
var DelStateYes int64 = 1 //已删除
|
||||
|
||||
//时间格式化模版
|
||||
var DateTimeFormatTplStandardDateTime = "Y-m-d H:i:s"
|
||||
var DateTimeFormatTplStandardDate = "Y-m-d"
|
||||
var DateTimeFormatTplStandardTime = "H:i:s"
|
||||
9
common/globalkey/redisCacheKey.go
Normal file
9
common/globalkey/redisCacheKey.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package globalkey
|
||||
|
||||
/**
|
||||
redis key except "model cache key" in here,
|
||||
but "model cache key" in model
|
||||
*/
|
||||
|
||||
// CacheUserTokenKey /** 用户登陆的token
|
||||
const CacheUserTokenKey = "user_token:%d"
|
||||
39
common/interceptor/rpcserver/loggerInterceptor.go
Normal file
39
common/interceptor/rpcserver/loggerInterceptor.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package rpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"hm-server/common/xerr"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
/**
|
||||
* @Description rpc service logger interceptor
|
||||
* @Author Mikael
|
||||
* @Date 2021/1/9 13:35
|
||||
* @Version 1.0
|
||||
**/
|
||||
|
||||
func LoggerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
|
||||
|
||||
resp, err = handler(ctx, req)
|
||||
if err != nil {
|
||||
causeErr := errors.Cause(err) // err类型
|
||||
if e, ok := causeErr.(*xerr.CodeError); ok { //自定义错误类型
|
||||
logx.WithContext(ctx).Errorf("【RPC-SRV-ERR】 %v", err)
|
||||
|
||||
//转成grpc err
|
||||
err = status.Error(codes.Code(e.GetErrCode()), e.GetErrMsg())
|
||||
} else {
|
||||
logx.WithContext(ctx).Errorf("【RPC-SRV-ERR】 %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
94
common/jwt/jwtx.go
Normal file
94
common/jwt/jwtx.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package jwtx
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
const ExtraKey = "extra"
|
||||
|
||||
type JwtClaims struct {
|
||||
UserId int64 `json:"userId"`
|
||||
AgentId int64 `json:"agentId"`
|
||||
Platform string `json:"platform"`
|
||||
// 用户身份类型:0-临时用户,1-正式用户
|
||||
UserType int64 `json:"userType"`
|
||||
// 是否代理:0-否,1-是
|
||||
IsAgent int64 `json:"isAgent"`
|
||||
}
|
||||
|
||||
// MapToJwtClaims 将 map[string]interface{} 转换为 JwtClaims 结构体
|
||||
func MapToJwtClaims(claimsMap map[string]interface{}) (*JwtClaims, error) {
|
||||
// 使用JSON序列化/反序列化的方式自动转换
|
||||
jsonData, err := json.Marshal(claimsMap)
|
||||
if err != nil {
|
||||
return nil, errors.New("序列化claims失败")
|
||||
}
|
||||
|
||||
var claims JwtClaims
|
||||
if err := json.Unmarshal(jsonData, &claims); err != nil {
|
||||
return nil, errors.New("反序列化claims失败")
|
||||
}
|
||||
|
||||
return &claims, nil
|
||||
}
|
||||
|
||||
// GenerateJwtToken 生成JWT token
|
||||
func GenerateJwtToken(claims JwtClaims, secret string, expire int64) (string, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
// 将 claims 结构体转换为 map[string]interface{}
|
||||
claimsBytes, err := json.Marshal(claims)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var claimsMap map[string]interface{}
|
||||
if err := json.Unmarshal(claimsBytes, &claimsMap); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
jwtClaims := jwt.MapClaims{
|
||||
"exp": now + expire,
|
||||
"iat": now,
|
||||
"userId": claims.UserId,
|
||||
ExtraKey: claimsMap,
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwtClaims)
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
func ParseJwtToken(tokenStr string, secret string) (*JwtClaims, error) {
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return nil, errors.New("invalid JWT")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, errors.New("invalid JWT claims")
|
||||
}
|
||||
|
||||
extraInfo, exists := claims[ExtraKey]
|
||||
if !exists {
|
||||
return nil, errors.New("extra not found in JWT")
|
||||
}
|
||||
|
||||
// 尝试直接断言为 JwtClaims 结构体
|
||||
if jwtClaims, ok := extraInfo.(JwtClaims); ok {
|
||||
return &jwtClaims, nil
|
||||
}
|
||||
|
||||
// 尝试从 map[string]interface{} 中解析
|
||||
if claimsMap, ok := extraInfo.(map[string]interface{}); ok {
|
||||
return MapToJwtClaims(claimsMap)
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported extra type in JWT")
|
||||
}
|
||||
393
common/jwt/jwtx_test.go
Normal file
393
common/jwt/jwtx_test.go
Normal file
@@ -0,0 +1,393 @@
|
||||
package jwtx
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
)
|
||||
|
||||
func TestGenerateJwtToken(t *testing.T) {
|
||||
// 测试数据
|
||||
testClaims := JwtClaims{
|
||||
UserId: 1,
|
||||
AgentId: 0,
|
||||
Platform: "wxh5",
|
||||
UserType: 0,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA"
|
||||
testExpire := int64(2592000) // 1小时
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
claims JwtClaims
|
||||
secret string
|
||||
expire int64
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常生成token",
|
||||
claims: testClaims,
|
||||
secret: testSecret,
|
||||
expire: testExpire,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "不同用户数据",
|
||||
claims: JwtClaims{
|
||||
UserId: 99999,
|
||||
AgentId: 11111,
|
||||
Platform: "mobile",
|
||||
UserType: 0,
|
||||
IsAgent: 1,
|
||||
},
|
||||
secret: testSecret,
|
||||
expire: testExpire,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空密钥",
|
||||
claims: testClaims,
|
||||
secret: "",
|
||||
expire: testExpire,
|
||||
wantErr: false, // 空密钥不会导致生成失败,但验证时会失败
|
||||
},
|
||||
{
|
||||
name: "零过期时间",
|
||||
claims: testClaims,
|
||||
secret: testSecret,
|
||||
expire: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "负数过期时间",
|
||||
claims: testClaims,
|
||||
secret: testSecret,
|
||||
expire: -3600,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, err := GenerateJwtToken(tt.claims, tt.secret, tt.expire)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("GenerateJwtToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr {
|
||||
// 验证token不为空
|
||||
if token == "" {
|
||||
t.Error("GenerateJwtToken() 返回的token为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token格式(JWT token应该包含两个点分隔符)
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
t.Errorf("GenerateJwtToken() 返回的token格式不正确,期望3部分,实际%d部分", len(parts))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token可以被解析(不验证签名,只验证格式)
|
||||
parsedToken, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(tt.secret), nil
|
||||
})
|
||||
|
||||
if err == nil && parsedToken != nil {
|
||||
// 验证claims是否正确设置
|
||||
if claims, ok := parsedToken.Claims.(jwt.MapClaims); ok {
|
||||
// 验证userId
|
||||
if userId, exists := claims["userId"]; exists {
|
||||
if int64(userId.(float64)) != tt.claims.UserId {
|
||||
t.Errorf("token中的userId不匹配,期望%d,实际%v", tt.claims.UserId, userId)
|
||||
}
|
||||
} else {
|
||||
t.Error("token中缺少userId字段")
|
||||
}
|
||||
|
||||
// 验证extra字段存在
|
||||
if _, exists := claims[ExtraKey]; !exists {
|
||||
t.Error("token中缺少extra字段")
|
||||
}
|
||||
|
||||
// 验证exp字段
|
||||
if exp, exists := claims["exp"]; exists {
|
||||
expTime := int64(exp.(float64))
|
||||
now := time.Now().Unix()
|
||||
expectedExp := now + tt.expire
|
||||
// 允许5秒的时间差异
|
||||
if expTime < expectedExp-5 || expTime > expectedExp+5 {
|
||||
t.Errorf("token过期时间不正确,期望约%d,实际%d", expectedExp, expTime)
|
||||
}
|
||||
} else {
|
||||
t.Error("token中缺少exp字段")
|
||||
}
|
||||
|
||||
// 验证iat字段
|
||||
if _, exists := claims["iat"]; !exists {
|
||||
t.Error("token中缺少iat字段")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("生成的token: %s", token)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateJwtTokenAndParse(t *testing.T) {
|
||||
// 测试生成token后能够正确解析
|
||||
testClaims := JwtClaims{
|
||||
UserId: 12345,
|
||||
AgentId: 67890,
|
||||
Platform: "web",
|
||||
UserType: 1,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "test-secret-key"
|
||||
testExpire := int64(3600)
|
||||
|
||||
// 生成token
|
||||
token, err := GenerateJwtToken(testClaims, testSecret, testExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJwtToken() failed: %v", err)
|
||||
}
|
||||
|
||||
// 解析token
|
||||
parsedClaims, err := ParseJwtToken(token, testSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("ParseJwtToken() failed: %v", err)
|
||||
}
|
||||
|
||||
// 验证解析出的claims与原始claims一致
|
||||
if parsedClaims.UserId != testClaims.UserId {
|
||||
t.Errorf("UserId不匹配,期望%d,实际%d", testClaims.UserId, parsedClaims.UserId)
|
||||
}
|
||||
if parsedClaims.AgentId != testClaims.AgentId {
|
||||
t.Errorf("AgentId不匹配,期望%d,实际%d", testClaims.AgentId, parsedClaims.AgentId)
|
||||
}
|
||||
if parsedClaims.Platform != testClaims.Platform {
|
||||
t.Errorf("Platform不匹配,期望%s,实际%s", testClaims.Platform, parsedClaims.Platform)
|
||||
}
|
||||
if parsedClaims.UserType != testClaims.UserType {
|
||||
t.Errorf("UserType不匹配,期望%d,实际%d", testClaims.UserType, parsedClaims.UserType)
|
||||
}
|
||||
if parsedClaims.IsAgent != testClaims.IsAgent {
|
||||
t.Errorf("IsAgent不匹配,期望%d,实际%d", testClaims.IsAgent, parsedClaims.IsAgent)
|
||||
}
|
||||
|
||||
t.Logf("测试通过: 生成token并成功解析,claims数据一致")
|
||||
}
|
||||
|
||||
func BenchmarkGenerateJwtToken(t *testing.B) {
|
||||
// 性能测试
|
||||
testClaims := JwtClaims{
|
||||
UserId: 12345,
|
||||
AgentId: 67890,
|
||||
Platform: "web",
|
||||
UserType: 1,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "test-secret-key"
|
||||
testExpire := int64(3600)
|
||||
|
||||
t.ResetTimer()
|
||||
for i := 0; i < t.N; i++ {
|
||||
_, err := GenerateJwtToken(testClaims, testSecret, testExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateJwtToken() failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseJwtToken(t *testing.T) {
|
||||
// 使用你修改的测试数据
|
||||
testClaims := JwtClaims{
|
||||
UserId: 6,
|
||||
AgentId: 0,
|
||||
Platform: "wxh5",
|
||||
UserType: 0,
|
||||
IsAgent: 0,
|
||||
}
|
||||
testSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA"
|
||||
testExpire := int64(2592000) // 30天
|
||||
|
||||
// 先生成一个token用于测试
|
||||
token, err := GenerateJwtToken(testClaims, testSecret, testExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("生成token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("生成的测试token: %s", token)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
secret string
|
||||
wantErr bool
|
||||
wantClaims *JwtClaims
|
||||
}{
|
||||
{
|
||||
name: "正常解析token",
|
||||
token: token,
|
||||
secret: testSecret,
|
||||
wantErr: false,
|
||||
wantClaims: &testClaims,
|
||||
},
|
||||
{
|
||||
name: "错误的密钥",
|
||||
token: token,
|
||||
secret: "wrong-secret",
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "空token",
|
||||
token: "",
|
||||
secret: testSecret,
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "无效token格式",
|
||||
token: "invalid.token.format",
|
||||
secret: testSecret,
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "缺少点分隔符的token",
|
||||
token: "invalidtoken",
|
||||
secret: testSecret,
|
||||
wantErr: true,
|
||||
wantClaims: nil,
|
||||
},
|
||||
{
|
||||
name: "自定义token",
|
||||
token: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOjE3NTI5MDA5MTQsImV4dHJhIjp7ImFnZW50SWQiOjAsImlzQWdlbnQiOjAsInBsYXRmb3JtIjoid3hoNSIsInVzZXJJZCI6NiwidXNlclR5cGUiOjF9LCJpYXQiOjE3NTAzMDg5MTQsInVzZXJJZCI6Nn0.GPKgLOaALOIa1ft7Hipuo4YKFf5guYt0rz2MCDCSdCQ",
|
||||
secret: testSecret,
|
||||
wantErr: false,
|
||||
wantClaims: &testClaims,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
claims, err := ParseJwtToken(tt.token, tt.secret)
|
||||
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("ParseJwtToken() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && tt.wantClaims != nil {
|
||||
if claims == nil {
|
||||
t.Error("ParseJwtToken() 返回的claims为nil")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证各个字段
|
||||
if claims.UserId != tt.wantClaims.UserId {
|
||||
t.Errorf("UserId不匹配,期望%d,实际%d", tt.wantClaims.UserId, claims.UserId)
|
||||
}
|
||||
if claims.AgentId != tt.wantClaims.AgentId {
|
||||
t.Errorf("AgentId不匹配,期望%d,实际%d", tt.wantClaims.AgentId, claims.AgentId)
|
||||
}
|
||||
if claims.Platform != tt.wantClaims.Platform {
|
||||
t.Errorf("Platform不匹配,期望%s,实际%s", tt.wantClaims.Platform, claims.Platform)
|
||||
}
|
||||
if claims.UserType != tt.wantClaims.UserType {
|
||||
t.Errorf("UserType不匹配,期望%d,实际%d", tt.wantClaims.UserType, claims.UserType)
|
||||
}
|
||||
if claims.IsAgent != tt.wantClaims.IsAgent {
|
||||
t.Errorf("IsAgent不匹配,期望%d,实际%d", tt.wantClaims.IsAgent, claims.IsAgent)
|
||||
}
|
||||
|
||||
t.Logf("解析成功的claims: UserId=%d, AgentId=%d, Platform=%s, UserType=%d, IsAgent=%d",
|
||||
claims.UserId, claims.AgentId, claims.Platform, claims.UserType, claims.IsAgent)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseCustomJwtToken 测试解析自定义token - 你可以在这里传入你自己的token
|
||||
func TestParseCustomJwtToken(t *testing.T) {
|
||||
// 在这里修改你想要测试的token和secret
|
||||
customToken := "" // 在这里粘贴你的token
|
||||
customSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA" // 你的密钥
|
||||
|
||||
// 如果没有提供自定义token,跳过测试
|
||||
if customToken == "" {
|
||||
t.Skip("跳过自定义token测试,请在代码中设置customToken值")
|
||||
return
|
||||
}
|
||||
|
||||
t.Logf("解析自定义token: %s", customToken)
|
||||
|
||||
claims, err := ParseJwtToken(customToken, customSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("解析自定义token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("解析结果:")
|
||||
t.Logf(" UserId: %d", claims.UserId)
|
||||
t.Logf(" AgentId: %d", claims.AgentId)
|
||||
t.Logf(" Platform: %s", claims.Platform)
|
||||
t.Logf(" UserType: %d", claims.UserType)
|
||||
t.Logf(" IsAgent: %d", claims.IsAgent)
|
||||
}
|
||||
|
||||
// TestGenerateAndParseWithRealData 生成一个真实的token并解析
|
||||
func TestGenerateAndParseWithRealData(t *testing.T) {
|
||||
// 使用真实数据生成token
|
||||
realClaims := JwtClaims{
|
||||
UserId: 1,
|
||||
AgentId: 0,
|
||||
Platform: "wxh5",
|
||||
UserType: 0,
|
||||
IsAgent: 0,
|
||||
}
|
||||
realSecret := "WUvoIwL-FK0qnlxhvxR9tV6SjfOpeJMpKmY2QvT99lA"
|
||||
realExpire := int64(2592000) // 30天
|
||||
|
||||
// 生成token
|
||||
token, err := GenerateJwtToken(realClaims, realSecret, realExpire)
|
||||
if err != nil {
|
||||
t.Fatalf("生成token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("=== 生成的完整token ===")
|
||||
t.Logf("Token: %s", token)
|
||||
t.Logf("========================")
|
||||
|
||||
// 解析token
|
||||
parsedClaims, err := ParseJwtToken(token, realSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("解析token失败: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("=== 解析结果 ===")
|
||||
t.Logf("UserId: %d", parsedClaims.UserId)
|
||||
t.Logf("AgentId: %d", parsedClaims.AgentId)
|
||||
t.Logf("Platform: %s", parsedClaims.Platform)
|
||||
t.Logf("UserType: %d", parsedClaims.UserType)
|
||||
t.Logf("IsAgent: %d", parsedClaims.IsAgent)
|
||||
t.Logf("================")
|
||||
|
||||
// 验证数据一致性
|
||||
if parsedClaims.UserId != realClaims.UserId ||
|
||||
parsedClaims.AgentId != realClaims.AgentId ||
|
||||
parsedClaims.Platform != realClaims.Platform ||
|
||||
parsedClaims.UserType != realClaims.UserType ||
|
||||
parsedClaims.IsAgent != realClaims.IsAgent {
|
||||
t.Error("解析出的claims与原始数据不一致")
|
||||
} else {
|
||||
t.Log("✅ 数据一致性验证通过")
|
||||
}
|
||||
}
|
||||
8
common/kqueue/message.go
Normal file
8
common/kqueue/message.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//KqMessage
|
||||
package kqueue
|
||||
|
||||
//第三方支付回调更改支付状态通知
|
||||
type ThirdPaymentUpdatePayStatusNotifyMessage struct {
|
||||
PayStatus int64 `json:"payStatus"`
|
||||
OrderSn string `json:"orderSn"`
|
||||
}
|
||||
31
common/middleware/commonJwtAuthMiddleware.go
Normal file
31
common/middleware/commonJwtAuthMiddleware.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/rest/handler"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// CommonJwtAuthMiddleware : with jwt on the verification, no jwt on the verification
|
||||
type CommonJwtAuthMiddleware struct {
|
||||
secret string
|
||||
}
|
||||
|
||||
func NewCommonJwtAuthMiddleware(secret string) *CommonJwtAuthMiddleware {
|
||||
return &CommonJwtAuthMiddleware{
|
||||
secret: secret,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *CommonJwtAuthMiddleware) Handle(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if len(r.Header.Get("Authorization")) > 0 {
|
||||
//has jwt Authorization
|
||||
authHandler := handler.Authorize(m.secret)
|
||||
authHandler(next).ServeHTTP(w, r)
|
||||
return
|
||||
} else {
|
||||
//no jwt Authorization
|
||||
next(w, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
89
common/result/httpResult.go
Normal file
89
common/result/httpResult.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package result
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"hm-server/common/xerr"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// http返回
|
||||
func HttpResult(r *http.Request, w http.ResponseWriter, resp interface{}, err error) {
|
||||
|
||||
if err == nil {
|
||||
httpx.WriteJson(w, http.StatusOK, Success(resp))
|
||||
} else {
|
||||
//错误返回
|
||||
errcode := xerr.SERVER_COMMON_ERROR
|
||||
errmsg := "服务器开小差啦,稍后再来试一试"
|
||||
|
||||
causeErr := errors.Cause(err) // err类型
|
||||
if e, ok := causeErr.(*xerr.CodeError); ok { //自定义错误类型
|
||||
//自定义CodeError
|
||||
errcode = e.GetErrCode()
|
||||
errmsg = e.GetErrMsg()
|
||||
} else {
|
||||
if gstatus, ok := status.FromError(causeErr); ok { // grpc err错误
|
||||
grpcCode := uint32(gstatus.Code())
|
||||
if xerr.IsCodeErr(grpcCode) { //区分自定义错误跟系统底层、db等错误,底层、db错误不能返回给前端
|
||||
errcode = grpcCode
|
||||
errmsg = gstatus.Message()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logx.WithContext(r.Context()).Errorf("【API-ERR】 : %+v ", err)
|
||||
|
||||
httpx.WriteJson(w, http.StatusOK, Error(errcode, errmsg))
|
||||
}
|
||||
}
|
||||
|
||||
// 授权的http方法
|
||||
func AuthHttpResult(r *http.Request, w http.ResponseWriter, resp interface{}, err error) {
|
||||
|
||||
if err == nil {
|
||||
//成功返回
|
||||
r := Success(resp)
|
||||
httpx.WriteJson(w, http.StatusOK, r)
|
||||
} else {
|
||||
//错误返回
|
||||
errcode := xerr.SERVER_COMMON_ERROR
|
||||
errmsg := "服务器开小差啦,稍后再来试一试"
|
||||
|
||||
causeErr := errors.Cause(err) // err类型
|
||||
if e, ok := causeErr.(*xerr.CodeError); ok { //自定义错误类型
|
||||
//自定义CodeError
|
||||
errcode = e.GetErrCode()
|
||||
errmsg = e.GetErrMsg()
|
||||
} else {
|
||||
if gstatus, ok := status.FromError(causeErr); ok { // grpc err错误
|
||||
grpcCode := uint32(gstatus.Code())
|
||||
if xerr.IsCodeErr(grpcCode) { //区分自定义错误跟系统底层、db等错误,底层、db错误不能返回给前端
|
||||
errcode = grpcCode
|
||||
errmsg = gstatus.Message()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logx.WithContext(r.Context()).Errorf("【GATEWAY-ERR】 : %+v ", err)
|
||||
|
||||
httpx.WriteJson(w, http.StatusUnauthorized, Error(errcode, errmsg))
|
||||
}
|
||||
}
|
||||
|
||||
// http 参数错误返回
|
||||
func ParamErrorResult(r *http.Request, w http.ResponseWriter, err error) {
|
||||
errMsg := fmt.Sprintf("%s,%s", xerr.MapErrMsg(xerr.REUQEST_PARAM_ERROR), err.Error())
|
||||
httpx.WriteJson(w, http.StatusOK, Error(xerr.REUQEST_PARAM_ERROR, errMsg))
|
||||
}
|
||||
|
||||
// http 参数校验失败返回
|
||||
func ParamValidateErrorResult(r *http.Request, w http.ResponseWriter, err error) {
|
||||
//errMsg := fmt.Sprintf("%s,%s", xerr.MapErrMsg(xerr.REUQEST_PARAM_ERROR), err.Error())
|
||||
httpx.WriteJson(w, http.StatusOK, Error(xerr.PARAM_VERIFICATION_ERROR, err.Error()))
|
||||
}
|
||||
44
common/result/jobResult.go
Normal file
44
common/result/jobResult.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package result
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"hm-server/common/xerr"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// job返回
|
||||
func JobResult(ctx context.Context, resp interface{}, err error) {
|
||||
if err == nil {
|
||||
// 成功返回 ,只有dev环境下才会打印info,线上不显示
|
||||
if resp != nil {
|
||||
logx.Infof("resp: %+v", resp)
|
||||
}
|
||||
return
|
||||
} else {
|
||||
errCode := xerr.SERVER_COMMON_ERROR
|
||||
errMsg := "服务器开小差啦,稍后再来试一试"
|
||||
|
||||
// 错误返回
|
||||
causeErr := errors.Cause(err) // err类型
|
||||
if e, ok := causeErr.(*xerr.CodeError); ok { // 自定义错误类型
|
||||
// 自定义CodeError
|
||||
errCode = e.GetErrCode()
|
||||
errMsg = e.GetErrMsg()
|
||||
} else {
|
||||
if gstatus, ok := status.FromError(causeErr); ok { // grpc err错误
|
||||
grpcCode := uint32(gstatus.Code())
|
||||
if xerr.IsCodeErr(grpcCode) { // 区分自定义错误跟系统底层、db等错误,底层、db错误不能返回给前端
|
||||
errCode = grpcCode
|
||||
errMsg = gstatus.Message()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logx.WithContext(ctx).Errorf("【JOB-ERR】 : %+v ,errCode:%d , errMsg:%s ", err, errCode, errMsg)
|
||||
return
|
||||
}
|
||||
}
|
||||
21
common/result/responseBean.go
Normal file
21
common/result/responseBean.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package result
|
||||
|
||||
type ResponseSuccessBean struct {
|
||||
Code uint32 `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
type NullJson struct{}
|
||||
|
||||
func Success(data interface{}) *ResponseSuccessBean {
|
||||
return &ResponseSuccessBean{200, "OK", data}
|
||||
}
|
||||
|
||||
type ResponseErrorBean struct {
|
||||
Code uint32 `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
}
|
||||
|
||||
func Error(errCode uint32, errMsg string) *ResponseErrorBean {
|
||||
return &ResponseErrorBean{errCode, errMsg}
|
||||
}
|
||||
19
common/tool/coinconvert.go
Normal file
19
common/tool/coinconvert.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package tool
|
||||
|
||||
import "github.com/shopspring/decimal"
|
||||
|
||||
var oneHundredDecimal decimal.Decimal = decimal.NewFromInt(100)
|
||||
|
||||
//分转元
|
||||
func Fen2Yuan(fen int64) float64 {
|
||||
y, _ := decimal.NewFromInt(fen).Div(oneHundredDecimal).Truncate(2).Float64()
|
||||
return y
|
||||
}
|
||||
|
||||
//元转分
|
||||
func Yuan2Fen(yuan float64) int64 {
|
||||
|
||||
f, _ := decimal.NewFromFloat(yuan).Mul(oneHundredDecimal).Truncate(0).Float64()
|
||||
return int64(f)
|
||||
|
||||
}
|
||||
23
common/tool/encryption.go
Normal file
23
common/tool/encryption.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
/** 加密方式 **/
|
||||
|
||||
func Md5ByString(str string) string {
|
||||
m := md5.New()
|
||||
_, err := io.WriteString(m, str)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
arr := m.Sum(nil)
|
||||
return fmt.Sprintf("%x", arr)
|
||||
}
|
||||
|
||||
func Md5ByBytes(b []byte) string {
|
||||
return fmt.Sprintf("%x", md5.Sum(b))
|
||||
}
|
||||
28
common/tool/krand.go
Normal file
28
common/tool/krand.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package tool
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
KC_RAND_KIND_NUM = 0 // 纯数字
|
||||
KC_RAND_KIND_LOWER = 1 // 小写字母
|
||||
KC_RAND_KIND_UPPER = 2 // 大写字母
|
||||
KC_RAND_KIND_ALL = 3 // 数字、大小写字母
|
||||
)
|
||||
|
||||
// 随机字符串
|
||||
func Krand(size int, kind int) string {
|
||||
ikind, kinds, result := kind, [][]int{[]int{10, 48}, []int{26, 97}, []int{26, 65}}, make([]byte, size)
|
||||
is_all := kind > 2 || kind < 0
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
for i := 0; i < size; i++ {
|
||||
if is_all { // random ikind
|
||||
ikind = rand.Intn(3)
|
||||
}
|
||||
scope, base := kinds[ikind][0], kinds[ikind][1]
|
||||
result[i] = uint8(base + rand.Intn(scope))
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
8
common/tool/krand_test.go
Normal file
8
common/tool/krand_test.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package tool
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMd5ByString(t *testing.T) {
|
||||
s := Md5ByString("AAA")
|
||||
t.Log(s)
|
||||
}
|
||||
15
common/tool/placeholders.go
Normal file
15
common/tool/placeholders.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package tool
|
||||
|
||||
import "strings"
|
||||
|
||||
//替换
|
||||
func InPlaceholders(n int) string {
|
||||
var b strings.Builder
|
||||
for i := 0; i < n-1; i++ {
|
||||
b.WriteString("?,")
|
||||
}
|
||||
if n > 0 {
|
||||
b.WriteString("?")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
20
common/uniqueid/sn.go
Normal file
20
common/uniqueid/sn.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package uniqueid
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hm-server/common/tool"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 生成sn单号
|
||||
type SnPrefix string
|
||||
|
||||
const (
|
||||
SN_PREFIX_HOMESTAY_ORDER SnPrefix = "HSO" //民宿订单前缀 hm-server_order/homestay_order
|
||||
SN_PREFIX_THIRD_PAYMENT SnPrefix = "PMT" //第三方支付流水记录前缀 hm-server_payment/third_payment
|
||||
)
|
||||
|
||||
// 生成单号
|
||||
func GenSn(snPrefix SnPrefix) string {
|
||||
return fmt.Sprintf("%s%s%s", snPrefix, time.Now().Format("20060102150405"), tool.Krand(8, tool.KC_RAND_KIND_NUM))
|
||||
}
|
||||
7
common/uniqueid/sn_test.go
Normal file
7
common/uniqueid/sn_test.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package uniqueid
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGenSn(t *testing.T) {
|
||||
GenSn(SN_PREFIX_HOMESTAY_ORDER)
|
||||
}
|
||||
23
common/uniqueid/uniqueid.go
Normal file
23
common/uniqueid/uniqueid.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package uniqueid
|
||||
|
||||
import (
|
||||
"github.com/sony/sonyflake"
|
||||
"github.com/zeromicro/go-zero/core/logx"
|
||||
)
|
||||
|
||||
var flake *sonyflake.Sonyflake
|
||||
|
||||
func init() {
|
||||
flake = sonyflake.NewSonyflake(sonyflake.Settings{})
|
||||
}
|
||||
|
||||
func GenId() int64 {
|
||||
|
||||
id, err := flake.NextID()
|
||||
if err != nil {
|
||||
logx.Severef("flake NextID failed with %s \n", err)
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return int64(id)
|
||||
}
|
||||
7
common/wxminisub/tpl.go
Normal file
7
common/wxminisub/tpl.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package wxminisub
|
||||
|
||||
//订单支付成功
|
||||
const OrderPaySuccessTemplateID = "QIJPmfxaNqYzSjOlXGk1T6Xfw94JwbSPuOd3u_hi3WE"
|
||||
|
||||
//支付成功入驻通知
|
||||
const OrderPaySuccessLiveKnowTemplateID = "kmm-maRr6v_9eMxEPpj-5clJ2YW_EFpd8-ngyYk63e4"
|
||||
23
common/xerr/errCode.go
Normal file
23
common/xerr/errCode.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package xerr
|
||||
|
||||
// 成功返回
|
||||
const OK uint32 = 200
|
||||
|
||||
/**(前3位代表业务,后三位代表具体功能)**/
|
||||
|
||||
// 全局错误码
|
||||
const SERVER_COMMON_ERROR uint32 = 100001
|
||||
const REUQEST_PARAM_ERROR uint32 = 100002
|
||||
const TOKEN_EXPIRE_ERROR uint32 = 100003
|
||||
const TOKEN_GENERATE_ERROR uint32 = 100004
|
||||
const DB_ERROR uint32 = 100005
|
||||
const DB_UPDATE_AFFECTED_ZERO_ERROR uint32 = 100006
|
||||
const PARAM_VERIFICATION_ERROR uint32 = 100007
|
||||
const CUSTOM_ERROR uint32 = 100008
|
||||
const USER_NOT_FOUND uint32 = 100009
|
||||
const USER_NEED_BIND_MOBILE uint32 = 100010
|
||||
|
||||
const LOGIN_FAILED uint32 = 200001
|
||||
const LOGIC_QUERY_WAIT uint32 = 200002
|
||||
const LOGIC_QUERY_ERROR uint32 = 200003
|
||||
const LOGIC_QUERY_NOT_FOUND uint32 = 200004
|
||||
30
common/xerr/errMsg.go
Normal file
30
common/xerr/errMsg.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package xerr
|
||||
|
||||
var message map[uint32]string
|
||||
|
||||
func init() {
|
||||
message = make(map[uint32]string)
|
||||
message[OK] = "SUCCESS"
|
||||
message[SERVER_COMMON_ERROR] = "系统正在升级,请稍后再试"
|
||||
message[REUQEST_PARAM_ERROR] = "参数错误"
|
||||
message[TOKEN_EXPIRE_ERROR] = "token失效,请重新登陆"
|
||||
message[TOKEN_GENERATE_ERROR] = "生成token失败"
|
||||
message[DB_ERROR] = "系统维护升级中,请稍后再试"
|
||||
message[DB_UPDATE_AFFECTED_ZERO_ERROR] = "更新数据影响行数为0"
|
||||
}
|
||||
|
||||
func MapErrMsg(errcode uint32) string {
|
||||
if msg, ok := message[errcode]; ok {
|
||||
return msg
|
||||
} else {
|
||||
return "系统正在升级,请稍后再试"
|
||||
}
|
||||
}
|
||||
|
||||
func IsCodeErr(errcode uint32) bool {
|
||||
if _, ok := message[errcode]; ok {
|
||||
return true
|
||||
} else {
|
||||
return false
|
||||
}
|
||||
}
|
||||
39
common/xerr/errors.go
Normal file
39
common/xerr/errors.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package xerr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
/**
|
||||
常用通用固定错误
|
||||
*/
|
||||
|
||||
type CodeError struct {
|
||||
errCode uint32
|
||||
errMsg string
|
||||
}
|
||||
|
||||
// 返回给前端的错误码
|
||||
func (e *CodeError) GetErrCode() uint32 {
|
||||
return e.errCode
|
||||
}
|
||||
|
||||
// 返回给前端显示端错误信息
|
||||
func (e *CodeError) GetErrMsg() string {
|
||||
return e.errMsg
|
||||
}
|
||||
|
||||
func (e *CodeError) Error() string {
|
||||
return fmt.Sprintf("ErrCode:%d,ErrMsg:%s", e.errCode, e.errMsg)
|
||||
}
|
||||
|
||||
func NewErrCodeMsg(errCode uint32, errMsg string) *CodeError {
|
||||
return &CodeError{errCode: errCode, errMsg: errMsg}
|
||||
}
|
||||
func NewErrCode(errCode uint32) *CodeError {
|
||||
return &CodeError{errCode: errCode, errMsg: MapErrMsg(errCode)}
|
||||
}
|
||||
|
||||
func NewErrMsg(errMsg string) *CodeError {
|
||||
return &CodeError{errCode: CUSTOM_ERROR, errMsg: errMsg}
|
||||
}
|
||||
Reference in New Issue
Block a user