first commit

This commit is contained in:
2024-10-02 00:57:17 +08:00
commit 6773f86bc5
312 changed files with 19169 additions and 0 deletions

88
pkg/crypto/crypto.go Normal file
View File

@@ -0,0 +1,88 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"io"
)
// PKCS7填充
func PKCS7Padding(ciphertext []byte, blockSize int) []byte {
padding := blockSize - len(ciphertext)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(ciphertext, padtext...)
}
// 去除PKCS7填充
func PKCS7UnPadding(origData []byte) ([]byte, error) {
length := len(origData)
if length == 0 {
return nil, errors.New("input data error")
}
unpadding := int(origData[length-1])
if unpadding > length {
return nil, errors.New("unpadding size is invalid")
}
return origData[:(length - unpadding)], nil
}
// AES CBC模式加密Base64传入传出
func AesEncrypt(plainText, key []byte) (string, error) {
block, err := aes.NewCipher(key)
if err != nil {
return "", err
}
blockSize := block.BlockSize()
plainText = PKCS7Padding(plainText, blockSize)
cipherText := make([]byte, blockSize+len(plainText))
iv := cipherText[:blockSize] // 使用前blockSize字节作为IV
_, err = io.ReadFull(rand.Reader, iv)
if err != nil {
return "", err
}
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(cipherText[blockSize:], plainText)
return base64.StdEncoding.EncodeToString(cipherText), nil
}
// AES CBC模式解密Base64传入传出
func AesDecrypt(cipherTextBase64 string, key []byte) ([]byte, error) {
cipherText, err := base64.StdEncoding.DecodeString(cipherTextBase64)
if err != nil {
return nil, err
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
if len(cipherText) < blockSize {
return nil, errors.New("ciphertext too short")
}
iv := cipherText[:blockSize]
cipherText = cipherText[blockSize:]
if len(cipherText)%blockSize != 0 {
return nil, errors.New("ciphertext is not a multiple of the block size")
}
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(cipherText, cipherText)
plainText, err := PKCS7UnPadding(cipherText)
if err != nil {
return nil, err
}
return plainText, nil
}

31
pkg/crypto/generate.go Normal file
View File

@@ -0,0 +1,31 @@
package crypto
import (
"crypto/rand"
"encoding/hex"
"io"
)
// 生成AES-128密钥的函数符合市面规范
func GenerateSecretKey() (string, error) {
key := make([]byte, 16) // 16字节密钥
_, err := io.ReadFull(rand.Reader, key)
if err != nil {
return "", err
}
return hex.EncodeToString(key), nil
}
func GenerateSecretId() (string, error) {
// 创建一个字节数组,用于存储随机数据
bytes := make([]byte, 8) // 因为每个字节表示两个16进制字符
// 读取随机字节到数组中
_, err := rand.Read(bytes)
if err != nil {
return "", err
}
// 将字节数组转换为16进制字符串
return hex.EncodeToString(bytes), nil
}

150
pkg/crypto/west_crypto.go Normal file
View File

@@ -0,0 +1,150 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/sha1"
"encoding/base64"
)
const (
KEY_SIZE = 16 // AES-128, 16 bytes
)
// Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding
func WestDexEncrypt(data, secretKey string) (string, error) {
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
ciphertext, err := aesEncrypt([]byte(data), key)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding
func WestDexDecrypt(encodedData, secretKey string) ([]byte, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
if err != nil {
return nil, err
}
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
plaintext, err := aesDecrypt(ciphertext, key)
if err != nil {
return nil, err
}
return plaintext, nil
}
// generateAESKey generates a key for AES encryption using a SHA-1 based PRNG
func generateAESKey(length int, password []byte) []byte {
h := sha1.New()
h.Write(password)
state := h.Sum(nil)
keyBytes := make([]byte, 0, length/8)
for len(keyBytes) < length/8 {
h := sha1.New()
h.Write(state)
state = h.Sum(nil)
keyBytes = append(keyBytes, state...)
}
return keyBytes[:length/8]
}
// aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding
func aesEncrypt(plaintext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
ciphertext := make([]byte, len(paddedPlaintext))
mode := newECBEncrypter(block)
mode.CryptBlocks(ciphertext, paddedPlaintext)
return ciphertext, nil
}
// aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding
func aesDecrypt(ciphertext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
plaintext := make([]byte, len(ciphertext))
mode := newECBDecrypter(block)
mode.CryptBlocks(plaintext, ciphertext)
return pkcs5Unpadding(plaintext), nil
}
// pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding
func pkcs5Padding(src []byte, blockSize int) []byte {
padding := blockSize - len(src)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
// pkcs5Unpadding removes PKCS5 padding from the input
func pkcs5Unpadding(src []byte) []byte {
length := len(src)
unpadding := int(src[length-1])
return src[:(length - unpadding)]
}
// ECB mode encryption/decryption
type ecb struct {
b cipher.Block
blockSize int
}
func newECB(b cipher.Block) *ecb {
return &ecb{
b: b,
blockSize: b.BlockSize(),
}
}
type ecbEncrypter ecb
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b))
}
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Encrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
type ecbDecrypter ecb
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b))
}
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Decrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}

68
pkg/jwt/jwtx.go Normal file
View File

@@ -0,0 +1,68 @@
package jwtx
import (
"errors"
"github.com/golang-jwt/jwt/v4"
"strconv"
"time"
)
// Token 生成逻辑的函数,接收 userId、过期时间和密钥返回生成的 token
func GenerateJwtToken(userId int64, secret string, expireTime int64) (string, error) {
// 获取当前时间戳
now := time.Now().Unix()
// 定义 JWT Claims
claims := jwt.MapClaims{
"exp": now + expireTime, // token 过期时间
"iat": now, // 签发时间
"userId": userId, // 用户ID
}
// 创建新的 JWT token
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
// 使用密钥对 token 签名
signedToken, err := token.SignedString([]byte(secret))
if err != nil {
return "", err
}
return signedToken, nil
}
func ParseJwtToken(tokenStr string, secret string) (int64, error) {
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil || !token.Valid {
return 0, errors.New("invalid JWT")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return 0, errors.New("invalid JWT claims")
}
// 从 claims 中提取 userId
userIdRaw, ok := claims["userId"]
if !ok {
return 0, errors.New("userId not found in JWT")
}
// 处理不同类型的 userId确保它被转换为 int64
switch userId := userIdRaw.(type) {
case float64:
return int64(userId), nil
case int64:
return userId, nil
case string:
// 如果 userId 是字符串,可以尝试将其转换为 int64
parsedId, err := strconv.ParseInt(userId, 10, 64)
if err != nil {
return 0, errors.New("invalid userId in JWT")
}
return parsedId, nil
default:
return 0, errors.New("unsupported userId type in JWT")
}
}

69
pkg/response/response.go Normal file
View File

@@ -0,0 +1,69 @@
package response
import (
"github.com/zeromicro/go-zero/rest/httpx"
"net/http"
)
// 定义通用的响应结构
type Response struct {
Code int `json:"code"`
Data interface{} `json:"data,omitempty"`
Message string `json:"message"`
}
// 定义分页响应结构
type PageResult struct {
List interface{} `json:"list"`
Total int64 `json:"total"`
Page int `json:"page"`
PageSize int `json:"pageSize"`
}
// 响应成功
func Success(w http.ResponseWriter, data interface{}) {
result := Response{
Code: http.StatusOK,
Data: data,
Message: "操作成功",
}
httpx.OkJson(w, result)
}
// 响应失败
func Fail(w http.ResponseWriter, code int, message string) {
result := Response{
Code: code,
Message: message,
}
httpx.WriteJson(w, code, result)
}
// 无权限
func Unauthorized(w http.ResponseWriter, message string) {
result := Response{
Code: http.StatusUnauthorized,
Message: message,
}
httpx.WriteJson(w, http.StatusUnauthorized, result)
}
// 响应分页数据
func Page(w http.ResponseWriter, list interface{}, total int64, page int, pageSize int) {
result := Response{
Code: http.StatusOK,
Data: PageResult{List: list, Total: total, Page: page, PageSize: pageSize},
Message: "查询成功",
}
httpx.OkJson(w, result)
}
// 自定义错误响应
func CustomError(w http.ResponseWriter, code int, message string, data interface{}) {
result := Response{
Code: code,
Data: data,
Message: message,
}
httpx.WriteJson(w, code, result)
}

View File

@@ -0,0 +1,86 @@
package schema
import (
"encoding/json"
"fmt"
"github.com/xeipuuv/gojsonschema"
"os"
"path/filepath"
"strings"
)
// ValidationResult 结构用于保存校验结果
type ValidationResult struct {
Valid bool
Data map[string]interface{}
Errors string
}
// 校验函数:接受 schema 文件路径和 JSON 数据
func ValidateJSONWithSchema(schemaFileName string, data []byte) (ValidationResult, error) {
// 获取项目根目录
rootPath, err := os.Getwd()
if err != nil {
return ValidationResult{}, fmt.Errorf("无法获取项目根目录: %v", err)
}
// 构建本地 Schema 文件路径
schemaPath := filepath.Join(rootPath, "internal", "schema", schemaFileName)
// 将文件路径转换为 file:// URI 格式
schemaURI := "file:///" + filepath.ToSlash(schemaPath)
// 读取 schema 文件,通过 URI 加载
schemaLoader := gojsonschema.NewReferenceLoader(schemaURI)
// 将传入的 []byte 数据转为 JSON Loader
jsonLoader := gojsonschema.NewBytesLoader(data)
// 执行校验
result, err := gojsonschema.Validate(schemaLoader, jsonLoader)
if err != nil {
return ValidationResult{}, fmt.Errorf("校验过程中出错: %v", err)
}
// 初始化返回结果
validationResult := ValidationResult{
Valid: result.Valid(),
Data: make(map[string]interface{}),
Errors: "",
}
// 如果校验失败,收集并自定义错误信息
if !result.Valid() {
errorMessages := collectErrors(result.Errors())
validationResult.Errors = formatErrors(errorMessages)
return validationResult, nil
}
// 校验成功,解析 JSON
if err := json.Unmarshal(data, &validationResult.Data); err != nil {
return validationResult, fmt.Errorf("JSON 解析出错: %v", err)
}
return validationResult, nil
}
// collectErrors 自定义处理错误信息
func collectErrors(errors []gojsonschema.ResultError) []string {
var errorMessages []string
for _, err := range errors {
// 从 Details() 中获取真正的字段名
details := err.Details()
fieldName, ok := details["property"].(string)
if !ok {
fieldName = err.Field() // 默认使用 err.Field(),如果 property 不存在
}
errorMessages = append(errorMessages, fmt.Sprintf("%s: %s", fieldName, err.Description()))
}
return errorMessages
}
// formatErrors 将错误列表格式化为美观的字符串
func formatErrors(errors []string) string {
return strings.Join(errors, ", ") // 用换行符连接每个错误信息
}

19
pkg/sqlutil/nullstring.go Normal file
View File

@@ -0,0 +1,19 @@
package sqlutil
import "database/sql"
// StringToNullString 将 string 转换为 sql.NullString
func StringToNullString(s string) sql.NullString {
return sql.NullString{
String: s,
Valid: s != "",
}
}
// NullStringToString 将 sql.NullString 转换为 string
func NullStringToString(ns sql.NullString) string {
if ns.Valid {
return ns.String
}
return ""
}