216 lines
6.9 KiB
Go
216 lines
6.9 KiB
Go
|
|
package validator
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"fmt"
|
|||
|
|
"strings"
|
|||
|
|
|
|||
|
|
"tyapi-server/internal/shared/interfaces"
|
|||
|
|
|
|||
|
|
"github.com/gin-gonic/gin"
|
|||
|
|
"github.com/gin-gonic/gin/binding"
|
|||
|
|
"github.com/go-playground/locales/zh"
|
|||
|
|
ut "github.com/go-playground/universal-translator"
|
|||
|
|
"github.com/go-playground/validator/v10"
|
|||
|
|
zh_translations "github.com/go-playground/validator/v10/translations/zh"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// RequestValidator HTTP请求验证器
|
|||
|
|
type RequestValidator struct {
|
|||
|
|
response interfaces.ResponseBuilder
|
|||
|
|
translator ut.Translator
|
|||
|
|
validator *validator.Validate
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewRequestValidator 创建HTTP请求验证器
|
|||
|
|
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
|
|||
|
|
// 创建中文locale
|
|||
|
|
zhLocale := zh.New()
|
|||
|
|
uni := ut.New(zhLocale, zhLocale)
|
|||
|
|
|
|||
|
|
// 获取中文翻译器
|
|||
|
|
trans, _ := uni.GetTranslator("zh")
|
|||
|
|
|
|||
|
|
// 获取gin默认的validator实例
|
|||
|
|
ginValidator := binding.Validator.Engine().(*validator.Validate)
|
|||
|
|
|
|||
|
|
// 注册官方中文翻译
|
|||
|
|
zh_translations.RegisterDefaultTranslations(ginValidator, trans)
|
|||
|
|
|
|||
|
|
// 注册自定义验证器到gin的全局validator
|
|||
|
|
RegisterCustomValidators(ginValidator)
|
|||
|
|
|
|||
|
|
// 注册自定义翻译
|
|||
|
|
RegisterCustomTranslations(ginValidator, trans)
|
|||
|
|
|
|||
|
|
return &RequestValidator{
|
|||
|
|
response: response,
|
|||
|
|
translator: trans,
|
|||
|
|
validator: ginValidator,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Validate 验证请求体
|
|||
|
|
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
|
|||
|
|
return v.BindAndValidate(c, dto)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidateQuery 验证查询参数
|
|||
|
|
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
|
|||
|
|
if err := c.ShouldBindQuery(dto); err != nil {
|
|||
|
|
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
|||
|
|
validationErrorsMap := v.formatValidationErrors(validationErrors)
|
|||
|
|
v.response.ValidationError(c, validationErrorsMap)
|
|||
|
|
} else {
|
|||
|
|
v.response.BadRequest(c, "查询参数格式错误")
|
|||
|
|
}
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidateParam 验证路径参数
|
|||
|
|
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
|
|||
|
|
if err := c.ShouldBindUri(dto); err != nil {
|
|||
|
|
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
|||
|
|
validationErrorsMap := v.formatValidationErrors(validationErrors)
|
|||
|
|
v.response.ValidationError(c, validationErrorsMap)
|
|||
|
|
} else {
|
|||
|
|
v.response.BadRequest(c, "路径参数格式错误")
|
|||
|
|
}
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// BindAndValidate 绑定并验证请求
|
|||
|
|
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
|
|||
|
|
if err := c.ShouldBindJSON(dto); err != nil {
|
|||
|
|
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
|||
|
|
validationErrorsMap := v.formatValidationErrors(validationErrors)
|
|||
|
|
v.response.ValidationError(c, validationErrorsMap)
|
|||
|
|
} else {
|
|||
|
|
v.response.BadRequest(c, "请求体格式错误")
|
|||
|
|
}
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// formatValidationErrors 格式化验证错误
|
|||
|
|
func (v *RequestValidator) formatValidationErrors(validationErrors validator.ValidationErrors) map[string][]string {
|
|||
|
|
errors := make(map[string][]string)
|
|||
|
|
|
|||
|
|
for _, fieldError := range validationErrors {
|
|||
|
|
fieldName := v.getFieldName(fieldError)
|
|||
|
|
errorMessage := v.getErrorMessage(fieldError)
|
|||
|
|
|
|||
|
|
if _, exists := errors[fieldName]; !exists {
|
|||
|
|
errors[fieldName] = []string{}
|
|||
|
|
}
|
|||
|
|
errors[fieldName] = append(errors[fieldName], errorMessage)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return errors
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getErrorMessage 获取错误消息
|
|||
|
|
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
|
|||
|
|
fieldDisplayName := getFieldDisplayName(fieldError.Field())
|
|||
|
|
|
|||
|
|
// 优先使用翻译器
|
|||
|
|
errorMessage := fieldError.Translate(v.translator)
|
|||
|
|
if errorMessage != fieldError.Error() {
|
|||
|
|
// 替换字段名为中文
|
|||
|
|
if fieldDisplayName != fieldError.Field() {
|
|||
|
|
errorMessage = strings.ReplaceAll(errorMessage, fieldError.Field(), fieldDisplayName)
|
|||
|
|
}
|
|||
|
|
return errorMessage
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 回退到手动翻译
|
|||
|
|
return v.getFallbackErrorMessage(fieldError, fieldDisplayName)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getFallbackErrorMessage 获取回退错误消息
|
|||
|
|
func (v *RequestValidator) getFallbackErrorMessage(fieldError validator.FieldError, fieldDisplayName string) string {
|
|||
|
|
tag := fieldError.Tag()
|
|||
|
|
param := fieldError.Param()
|
|||
|
|
|
|||
|
|
switch tag {
|
|||
|
|
case "required":
|
|||
|
|
return fmt.Sprintf("%s不能为空", fieldDisplayName)
|
|||
|
|
case "email":
|
|||
|
|
return fmt.Sprintf("%s必须是有效的邮箱地址", fieldDisplayName)
|
|||
|
|
case "min":
|
|||
|
|
return fmt.Sprintf("%s长度不能少于%s位", fieldDisplayName, param)
|
|||
|
|
case "max":
|
|||
|
|
return fmt.Sprintf("%s长度不能超过%s位", fieldDisplayName, param)
|
|||
|
|
case "len":
|
|||
|
|
return fmt.Sprintf("%s长度必须为%s位", fieldDisplayName, param)
|
|||
|
|
case "eqfield":
|
|||
|
|
paramDisplayName := getFieldDisplayName(param)
|
|||
|
|
return fmt.Sprintf("%s必须与%s一致", fieldDisplayName, paramDisplayName)
|
|||
|
|
case "phone":
|
|||
|
|
return fmt.Sprintf("%s必须是有效的手机号", fieldDisplayName)
|
|||
|
|
case "username":
|
|||
|
|
return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线,且必须以字母开头,长度3-20位", fieldDisplayName)
|
|||
|
|
case "strong_password":
|
|||
|
|
return fmt.Sprintf("%s强度不足,必须包含大小写字母和数字,且不少于8位", fieldDisplayName)
|
|||
|
|
case "social_credit_code":
|
|||
|
|
return fmt.Sprintf("%s格式不正确,必须是18位统一社会信用代码", fieldDisplayName)
|
|||
|
|
case "id_card":
|
|||
|
|
return fmt.Sprintf("%s格式不正确,必须是18位身份证号", fieldDisplayName)
|
|||
|
|
case "price":
|
|||
|
|
return fmt.Sprintf("%s必须是非负数", fieldDisplayName)
|
|||
|
|
case "sort_order":
|
|||
|
|
return fmt.Sprintf("%s必须是 asc 或 desc", fieldDisplayName)
|
|||
|
|
case "product_code":
|
|||
|
|
return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线、连字符,长度3-50位", fieldDisplayName)
|
|||
|
|
case "uuid":
|
|||
|
|
return fmt.Sprintf("%s必须是有效的UUID格式", fieldDisplayName)
|
|||
|
|
case "url":
|
|||
|
|
return fmt.Sprintf("%s必须是有效的URL地址", fieldDisplayName)
|
|||
|
|
case "oneof":
|
|||
|
|
return fmt.Sprintf("%s必须是以下值之一: %s", fieldDisplayName, param)
|
|||
|
|
case "gt":
|
|||
|
|
return fmt.Sprintf("%s必须大于%s", fieldDisplayName, param)
|
|||
|
|
default:
|
|||
|
|
return fmt.Sprintf("%s格式不正确", fieldDisplayName)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// getFieldName 获取字段名
|
|||
|
|
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
|
|||
|
|
fieldName := fieldError.Field()
|
|||
|
|
return v.toSnakeCase(fieldName)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// toSnakeCase 转换为snake_case
|
|||
|
|
func (v *RequestValidator) toSnakeCase(str string) string {
|
|||
|
|
var result strings.Builder
|
|||
|
|
for i, r := range str {
|
|||
|
|
if i > 0 && (r >= 'A' && r <= 'Z') {
|
|||
|
|
result.WriteRune('_')
|
|||
|
|
}
|
|||
|
|
result.WriteRune(r)
|
|||
|
|
}
|
|||
|
|
return strings.ToLower(result.String())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetValidator 获取validator实例(用于业务逻辑)
|
|||
|
|
func (v *RequestValidator) GetValidator() *validator.Validate {
|
|||
|
|
return v.validator
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidateValue 验证单个值(用于业务逻辑)
|
|||
|
|
func (v *RequestValidator) ValidateValue(field interface{}, tag string) error {
|
|||
|
|
return v.validator.Var(field, tag)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ValidateStruct 验证结构体(用于业务逻辑)
|
|||
|
|
func (v *RequestValidator) ValidateStruct(s interface{}) error {
|
|||
|
|
return v.validator.Struct(s)
|
|||
|
|
}
|