first commit
This commit is contained in:
88
pkg/crypto/crypto.go
Normal file
88
pkg/crypto/crypto.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
// PKCS7填充
|
||||
func PKCS7Padding(ciphertext []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(ciphertext)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padtext...)
|
||||
}
|
||||
|
||||
// 去除PKCS7填充
|
||||
func PKCS7UnPadding(origData []byte) ([]byte, error) {
|
||||
length := len(origData)
|
||||
if length == 0 {
|
||||
return nil, errors.New("input data error")
|
||||
}
|
||||
unpadding := int(origData[length-1])
|
||||
if unpadding > length {
|
||||
return nil, errors.New("unpadding size is invalid")
|
||||
}
|
||||
return origData[:(length - unpadding)], nil
|
||||
}
|
||||
|
||||
// AES CBC模式加密,Base64传入传出
|
||||
func AesEncrypt(plainText, key []byte) (string, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
blockSize := block.BlockSize()
|
||||
plainText = PKCS7Padding(plainText, blockSize)
|
||||
|
||||
cipherText := make([]byte, blockSize+len(plainText))
|
||||
iv := cipherText[:blockSize] // 使用前blockSize字节作为IV
|
||||
_, err = io.ReadFull(rand.Reader, iv)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
mode.CryptBlocks(cipherText[blockSize:], plainText)
|
||||
|
||||
return base64.StdEncoding.EncodeToString(cipherText), nil
|
||||
}
|
||||
|
||||
// AES CBC模式解密,Base64传入传出
|
||||
func AesDecrypt(cipherTextBase64 string, key []byte) ([]byte, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
blockSize := block.BlockSize()
|
||||
if len(cipherText) < blockSize {
|
||||
return nil, errors.New("ciphertext too short")
|
||||
}
|
||||
|
||||
iv := cipherText[:blockSize]
|
||||
cipherText = cipherText[blockSize:]
|
||||
|
||||
if len(cipherText)%blockSize != 0 {
|
||||
return nil, errors.New("ciphertext is not a multiple of the block size")
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
mode.CryptBlocks(cipherText, cipherText)
|
||||
|
||||
plainText, err := PKCS7UnPadding(cipherText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return plainText, nil
|
||||
}
|
||||
31
pkg/crypto/generate.go
Normal file
31
pkg/crypto/generate.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
)
|
||||
|
||||
// 生成AES-128密钥的函数,符合市面规范
|
||||
func GenerateSecretKey() (string, error) {
|
||||
key := make([]byte, 16) // 16字节密钥
|
||||
_, err := io.ReadFull(rand.Reader, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(key), nil
|
||||
}
|
||||
|
||||
func GenerateSecretId() (string, error) {
|
||||
// 创建一个字节数组,用于存储随机数据
|
||||
bytes := make([]byte, 8) // 因为每个字节表示两个16进制字符
|
||||
|
||||
// 读取随机字节到数组中
|
||||
_, err := rand.Read(bytes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 将字节数组转换为16进制字符串
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
150
pkg/crypto/west_crypto.go
Normal file
150
pkg/crypto/west_crypto.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
const (
|
||||
KEY_SIZE = 16 // AES-128, 16 bytes
|
||||
)
|
||||
|
||||
// Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding
|
||||
func WestDexEncrypt(data, secretKey string) (string, error) {
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
ciphertext, err := aesEncrypt([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding
|
||||
func WestDexDecrypt(encodedData, secretKey string) ([]byte, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
plaintext, err := aesDecrypt(ciphertext, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// generateAESKey generates a key for AES encryption using a SHA-1 based PRNG
|
||||
func generateAESKey(length int, password []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write(password)
|
||||
state := h.Sum(nil)
|
||||
|
||||
keyBytes := make([]byte, 0, length/8)
|
||||
for len(keyBytes) < length/8 {
|
||||
h := sha1.New()
|
||||
h.Write(state)
|
||||
state = h.Sum(nil)
|
||||
keyBytes = append(keyBytes, state...)
|
||||
}
|
||||
|
||||
return keyBytes[:length/8]
|
||||
}
|
||||
|
||||
// aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding
|
||||
func aesEncrypt(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
|
||||
ciphertext := make([]byte, len(paddedPlaintext))
|
||||
mode := newECBEncrypter(block)
|
||||
mode.CryptBlocks(ciphertext, paddedPlaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding
|
||||
func aesDecrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode := newECBDecrypter(block)
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs5Unpadding(plaintext), nil
|
||||
}
|
||||
|
||||
// pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding
|
||||
func pkcs5Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
// pkcs5Unpadding removes PKCS5 padding from the input
|
||||
func pkcs5Unpadding(src []byte) []byte {
|
||||
length := len(src)
|
||||
unpadding := int(src[length-1])
|
||||
return src[:(length - unpadding)]
|
||||
}
|
||||
|
||||
// ECB mode encryption/decryption
|
||||
type ecb struct {
|
||||
b cipher.Block
|
||||
blockSize int
|
||||
}
|
||||
|
||||
func newECB(b cipher.Block) *ecb {
|
||||
return &ecb{
|
||||
b: b,
|
||||
blockSize: b.BlockSize(),
|
||||
}
|
||||
}
|
||||
|
||||
type ecbEncrypter ecb
|
||||
|
||||
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Encrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
|
||||
type ecbDecrypter ecb
|
||||
|
||||
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Decrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
68
pkg/jwt/jwtx.go
Normal file
68
pkg/jwt/jwtx.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package jwtx
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/golang-jwt/jwt/v4"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Token 生成逻辑的函数,接收 userId、过期时间和密钥,返回生成的 token
|
||||
func GenerateJwtToken(userId int64, secret string, expireTime int64) (string, error) {
|
||||
// 获取当前时间戳
|
||||
now := time.Now().Unix()
|
||||
// 定义 JWT Claims
|
||||
claims := jwt.MapClaims{
|
||||
"exp": now + expireTime, // token 过期时间
|
||||
"iat": now, // 签发时间
|
||||
"userId": userId, // 用户ID
|
||||
}
|
||||
|
||||
// 创建新的 JWT token
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
|
||||
// 使用密钥对 token 签名
|
||||
signedToken, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return signedToken, nil
|
||||
}
|
||||
func ParseJwtToken(tokenStr string, secret string) (int64, error) {
|
||||
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil || !token.Valid {
|
||||
return 0, errors.New("invalid JWT")
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return 0, errors.New("invalid JWT claims")
|
||||
}
|
||||
|
||||
// 从 claims 中提取 userId
|
||||
userIdRaw, ok := claims["userId"]
|
||||
if !ok {
|
||||
return 0, errors.New("userId not found in JWT")
|
||||
}
|
||||
|
||||
// 处理不同类型的 userId,确保它被转换为 int64
|
||||
switch userId := userIdRaw.(type) {
|
||||
case float64:
|
||||
return int64(userId), nil
|
||||
case int64:
|
||||
return userId, nil
|
||||
case string:
|
||||
// 如果 userId 是字符串,可以尝试将其转换为 int64
|
||||
parsedId, err := strconv.ParseInt(userId, 10, 64)
|
||||
if err != nil {
|
||||
return 0, errors.New("invalid userId in JWT")
|
||||
}
|
||||
return parsedId, nil
|
||||
default:
|
||||
return 0, errors.New("unsupported userId type in JWT")
|
||||
}
|
||||
}
|
||||
69
pkg/response/response.go
Normal file
69
pkg/response/response.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package response
|
||||
|
||||
import (
|
||||
"github.com/zeromicro/go-zero/rest/httpx"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// 定义通用的响应结构
|
||||
type Response struct {
|
||||
Code int `json:"code"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// 定义分页响应结构
|
||||
type PageResult struct {
|
||||
List interface{} `json:"list"`
|
||||
Total int64 `json:"total"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"pageSize"`
|
||||
}
|
||||
|
||||
// 响应成功
|
||||
func Success(w http.ResponseWriter, data interface{}) {
|
||||
result := Response{
|
||||
Code: http.StatusOK,
|
||||
Data: data,
|
||||
Message: "操作成功",
|
||||
}
|
||||
httpx.OkJson(w, result)
|
||||
}
|
||||
|
||||
// 响应失败
|
||||
func Fail(w http.ResponseWriter, code int, message string) {
|
||||
result := Response{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
httpx.WriteJson(w, code, result)
|
||||
}
|
||||
|
||||
// 无权限
|
||||
func Unauthorized(w http.ResponseWriter, message string) {
|
||||
result := Response{
|
||||
Code: http.StatusUnauthorized,
|
||||
Message: message,
|
||||
}
|
||||
httpx.WriteJson(w, http.StatusUnauthorized, result)
|
||||
}
|
||||
|
||||
// 响应分页数据
|
||||
func Page(w http.ResponseWriter, list interface{}, total int64, page int, pageSize int) {
|
||||
result := Response{
|
||||
Code: http.StatusOK,
|
||||
Data: PageResult{List: list, Total: total, Page: page, PageSize: pageSize},
|
||||
Message: "查询成功",
|
||||
}
|
||||
httpx.OkJson(w, result)
|
||||
}
|
||||
|
||||
// 自定义错误响应
|
||||
func CustomError(w http.ResponseWriter, code int, message string, data interface{}) {
|
||||
result := Response{
|
||||
Code: code,
|
||||
Data: data,
|
||||
Message: message,
|
||||
}
|
||||
httpx.WriteJson(w, code, result)
|
||||
}
|
||||
86
pkg/schema/schemaVerify.go
Normal file
86
pkg/schema/schemaVerify.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package schema
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/xeipuuv/gojsonschema"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ValidationResult 结构用于保存校验结果
|
||||
type ValidationResult struct {
|
||||
Valid bool
|
||||
Data map[string]interface{}
|
||||
Errors string
|
||||
}
|
||||
|
||||
// 校验函数:接受 schema 文件路径和 JSON 数据
|
||||
func ValidateJSONWithSchema(schemaFileName string, data []byte) (ValidationResult, error) {
|
||||
// 获取项目根目录
|
||||
rootPath, err := os.Getwd()
|
||||
if err != nil {
|
||||
return ValidationResult{}, fmt.Errorf("无法获取项目根目录: %v", err)
|
||||
}
|
||||
|
||||
// 构建本地 Schema 文件路径
|
||||
schemaPath := filepath.Join(rootPath, "internal", "schema", schemaFileName)
|
||||
|
||||
// 将文件路径转换为 file:// URI 格式
|
||||
schemaURI := "file:///" + filepath.ToSlash(schemaPath)
|
||||
|
||||
// 读取 schema 文件,通过 URI 加载
|
||||
schemaLoader := gojsonschema.NewReferenceLoader(schemaURI)
|
||||
|
||||
// 将传入的 []byte 数据转为 JSON Loader
|
||||
jsonLoader := gojsonschema.NewBytesLoader(data)
|
||||
|
||||
// 执行校验
|
||||
result, err := gojsonschema.Validate(schemaLoader, jsonLoader)
|
||||
if err != nil {
|
||||
return ValidationResult{}, fmt.Errorf("校验过程中出错: %v", err)
|
||||
}
|
||||
|
||||
// 初始化返回结果
|
||||
validationResult := ValidationResult{
|
||||
Valid: result.Valid(),
|
||||
Data: make(map[string]interface{}),
|
||||
Errors: "",
|
||||
}
|
||||
|
||||
// 如果校验失败,收集并自定义错误信息
|
||||
if !result.Valid() {
|
||||
errorMessages := collectErrors(result.Errors())
|
||||
validationResult.Errors = formatErrors(errorMessages)
|
||||
return validationResult, nil
|
||||
}
|
||||
|
||||
// 校验成功,解析 JSON
|
||||
if err := json.Unmarshal(data, &validationResult.Data); err != nil {
|
||||
return validationResult, fmt.Errorf("JSON 解析出错: %v", err)
|
||||
}
|
||||
|
||||
return validationResult, nil
|
||||
}
|
||||
|
||||
// collectErrors 自定义处理错误信息
|
||||
func collectErrors(errors []gojsonschema.ResultError) []string {
|
||||
var errorMessages []string
|
||||
for _, err := range errors {
|
||||
// 从 Details() 中获取真正的字段名
|
||||
details := err.Details()
|
||||
fieldName, ok := details["property"].(string)
|
||||
if !ok {
|
||||
fieldName = err.Field() // 默认使用 err.Field(),如果 property 不存在
|
||||
}
|
||||
|
||||
errorMessages = append(errorMessages, fmt.Sprintf("%s: %s", fieldName, err.Description()))
|
||||
}
|
||||
return errorMessages
|
||||
}
|
||||
|
||||
// formatErrors 将错误列表格式化为美观的字符串
|
||||
func formatErrors(errors []string) string {
|
||||
return strings.Join(errors, ", ") // 用换行符连接每个错误信息
|
||||
}
|
||||
19
pkg/sqlutil/nullstring.go
Normal file
19
pkg/sqlutil/nullstring.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package sqlutil
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// StringToNullString 将 string 转换为 sql.NullString
|
||||
func StringToNullString(s string) sql.NullString {
|
||||
return sql.NullString{
|
||||
String: s,
|
||||
Valid: s != "",
|
||||
}
|
||||
}
|
||||
|
||||
// NullStringToString 将 sql.NullString 转换为 string
|
||||
func NullStringToString(ns sql.NullString) string {
|
||||
if ns.Valid {
|
||||
return ns.String
|
||||
}
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user