Files
tyapi-server/internal/shared/validator/validator.go

254 lines
7.9 KiB
Go
Raw Normal View History

2025-07-20 20:53:26 +08:00
package validator
import (
"fmt"
"strings"
2025-08-27 22:19:19 +08:00
"sync"
2025-07-20 20:53:26 +08:00
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
zh_translations "github.com/go-playground/validator/v10/translations/zh"
)
2025-08-27 22:19:19 +08:00
// Package-level singletons shared by every RequestValidator. They are filled in
// exactly once by InitGlobalValidator, under the protection of once.
var (
	once             sync.Once           // guards the one-time initialization
	globalValidator  *validator.Validate // shared validator with the custom rules registered
	globalTranslator ut.Translator       // "zh" translator used to render error messages
)
// InitGlobalValidator initializes the package-wide validator and Chinese
// translator exactly once. Repeated and concurrent calls are safe because the
// work runs inside once.Do.
//
// Steps: create the validator, set up the "zh" translator, register the
// official Chinese translations, register the project's custom rules and their
// Chinese messages, then copy the configured validator into Gin's binding
// engine so c.ShouldBind* uses the same configuration.
func InitGlobalValidator() {
	once.Do(func() {
		// 1. Fresh validator instance.
		globalValidator = validator.New()

		// 2. Chinese translator. "zh" is the only locale given to the
		// universal translator here, so the found flag is not checked.
		zhLocale := zh.New()
		uni := ut.New(zhLocale, zhLocale)
		globalTranslator, _ = uni.GetTranslator("zh")

		// 3. Official Chinese translations for the built-in tags. The error
		// is deliberately ignored (best-effort).
		// TODO(review): surface this error once the function can report one.
		_ = zh_translations.RegisterDefaultTranslations(globalValidator, globalTranslator)

		// 4. Project-specific validation rules.
		RegisterCustomValidators(globalValidator)

		// 5. Chinese messages for the project-specific rules.
		RegisterCustomTranslations(globalValidator, globalTranslator)

		// 6. Make Gin's binding use the same configuration. binding.Validator
		// may be nil (Gin's binding package treats that as "no validation") or
		// a custom StructValidator whose Engine is not *validator.Validate.
		// In either case there is nothing to overwrite, and an unchecked
		// access would panic at startup.
		//
		// NOTE(review): the struct copy replaces the engine's whole
		// configuration, including its struct tag name. Gin's default engine
		// is set up for `binding` tags, while validator.New() uses `validate`.
		// Confirm the request DTOs use the tag name globalValidator expects.
		if binding.Validator != nil {
			if ginEngine, ok := binding.Validator.Engine().(*validator.Validate); ok && ginEngine != nil {
				*ginEngine = *globalValidator
			}
		}
	})
}
// GetGlobalValidator returns the shared validator, initializing the
// package-wide validator state on first use.
//
// InitGlobalValidator is invoked unconditionally rather than behind a nil
// check, for two reasons:
//   - Reading globalValidator outside once.Do is a data race.
//   - The variable is assigned before the custom rules are registered, so a
//     nil check could let a caller use a half-configured validator.
//
// sync.Once keeps repeat calls cheap and orders them after the first
// initialization completes.
//
// NOTE(review): must not be called from inside InitGlobalValidator's setup
// (e.g. from RegisterCustomValidators), because re-entering once.Do deadlocks.
func GetGlobalValidator() *validator.Validate {
	InitGlobalValidator()
	return globalValidator
}
// GetGlobalTranslator returns the shared "zh" translator, initializing the
// package-wide validator state on first use.
//
// InitGlobalValidator is invoked unconditionally rather than behind a nil
// check, for two reasons:
//   - Reading globalTranslator outside once.Do is a data race.
//   - The variable is assigned before the custom translations are registered,
//     so a nil check could let a caller proceed too early.
//
// sync.Once keeps repeat calls cheap and orders them after the first
// initialization completes.
//
// NOTE(review): must not be called from inside InitGlobalValidator's setup
// (e.g. from RegisterCustomTranslations), because re-entering once.Do
// deadlocks.
func GetGlobalTranslator() ut.Translator {
	InitGlobalValidator()
	return globalTranslator
}
2025-07-20 20:53:26 +08:00
// RequestValidator is the HTTP request validator. It binds request data
// (JSON body, query string, URI parameters) into DTOs, validates them, and
// writes an error response through the ResponseBuilder when binding or
// validation fails.
type RequestValidator struct {
// response writes the error responses (ValidationError / BadRequest).
response interfaces.ResponseBuilder
// translator renders validation errors as localized messages.
translator ut.Translator
// validator backs the ad-hoc checks in GetValidator / ValidateValue / ValidateStruct.
validator *validator.Validate
}
// NewRequestValidator 创建HTTP请求验证器
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
2025-08-27 22:19:19 +08:00
// 确保全局校验器已初始化
InitGlobalValidator()
2025-07-20 20:53:26 +08:00
return &RequestValidator{
response: response,
2025-08-27 22:19:19 +08:00
translator: globalTranslator, // 使用全局翻译器
validator: globalValidator, // 使用全局校验器
2025-07-20 20:53:26 +08:00
}
}
// Validate binds the JSON request body into dto and validates it. It is a thin
// alias for BindAndValidate and has the same response/return behavior.
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
	err := v.BindAndValidate(c, dto)
	return err
}
// ValidateQuery 验证查询参数
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindQuery(dto); err != nil {
if validationErrors, ok := err.(validator.ValidationErrors); ok {
validationErrorsMap := v.formatValidationErrors(validationErrors)
v.response.ValidationError(c, validationErrorsMap)
} else {
v.response.BadRequest(c, "查询参数格式错误")
}
return err
}
2025-07-28 01:46:39 +08:00
2025-07-20 20:53:26 +08:00
return nil
}
// ValidateParam 验证路径参数
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindUri(dto); err != nil {
if validationErrors, ok := err.(validator.ValidationErrors); ok {
validationErrorsMap := v.formatValidationErrors(validationErrors)
v.response.ValidationError(c, validationErrorsMap)
} else {
v.response.BadRequest(c, "路径参数格式错误")
}
return err
}
2025-07-28 01:46:39 +08:00
2025-07-20 20:53:26 +08:00
return nil
}
// BindAndValidate decodes the JSON request body into dto and validates it.
//
// On failure the error response has already been written when this returns:
// per-field messages for validation failures, or a generic "request body
// format" message for malformed input. The bind error is then returned
// unchanged.
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
	err := c.ShouldBindJSON(dto)
	if err == nil {
		return nil
	}

	if fieldErrs, ok := err.(validator.ValidationErrors); ok {
		v.response.ValidationError(c, v.formatValidationErrors(fieldErrs))
		return err
	}
	v.response.BadRequest(c, "请求体格式错误")
	return err
}
// formatValidationErrors groups validation failures by snake_case field name.
// Each failed rule contributes one localized message to its field's list.
func (v *RequestValidator) formatValidationErrors(validationErrors validator.ValidationErrors) map[string][]string {
	byField := make(map[string][]string, len(validationErrors))
	for i := range validationErrors {
		fe := validationErrors[i]
		key := v.getFieldName(fe)
		// append on a missing key starts from a nil slice, so no explicit
		// initialization is required.
		byField[key] = append(byField[key], v.getErrorMessage(fe))
	}
	return byField
}
// getErrorMessage renders a localized message for one failed validation rule.
//
// It prefers the registered translation and swaps the Go field name for its
// display name. If no translation is registered for the tag, Translate returns
// the raw error text, and the message comes from getFallbackErrorMessage.
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
	field := fieldError.Field()
	fieldDisplayName := getFieldDisplayName(field)

	errorMessage := fieldError.Translate(v.translator)
	if errorMessage == fieldError.Error() {
		// No translation available for this tag.
		return v.getFallbackErrorMessage(fieldError, fieldDisplayName)
	}

	if field != "" && fieldDisplayName != field {
		// Replace only the first occurrence. The field name normally leads
		// the translated message, while later occurrences can belong to other
		// identifiers (e.g. eqfield's "ConfirmPassword" contains "Password")
		// and must not be rewritten.
		errorMessage = strings.Replace(errorMessage, field, fieldDisplayName, 1)
	}
	return errorMessage
}
// getFallbackErrorMessage 获取回退错误消息
func (v *RequestValidator) getFallbackErrorMessage(fieldError validator.FieldError, fieldDisplayName string) string {
tag := fieldError.Tag()
param := fieldError.Param()
switch tag {
case "required":
return fmt.Sprintf("%s不能为空", fieldDisplayName)
case "email":
return fmt.Sprintf("%s必须是有效的邮箱地址", fieldDisplayName)
case "min":
return fmt.Sprintf("%s长度不能少于%s位", fieldDisplayName, param)
case "max":
return fmt.Sprintf("%s长度不能超过%s位", fieldDisplayName, param)
case "len":
return fmt.Sprintf("%s长度必须为%s位", fieldDisplayName, param)
case "eqfield":
paramDisplayName := getFieldDisplayName(param)
return fmt.Sprintf("%s必须与%s一致", fieldDisplayName, paramDisplayName)
case "phone":
return fmt.Sprintf("%s必须是有效的手机号", fieldDisplayName)
case "username":
return fmt.Sprintf("%s格式不正确只能包含字母、数字、下划线且必须以字母开头长度3-20位", fieldDisplayName)
case "strong_password":
return fmt.Sprintf("%s强度不足必须包含大小写字母和数字且不少于8位", fieldDisplayName)
case "social_credit_code":
return fmt.Sprintf("%s格式不正确必须是18位统一社会信用代码", fieldDisplayName)
case "id_card":
return fmt.Sprintf("%s格式不正确必须是18位身份证号", fieldDisplayName)
case "price":
return fmt.Sprintf("%s必须是非负数", fieldDisplayName)
case "sort_order":
return fmt.Sprintf("%s必须是 asc 或 desc", fieldDisplayName)
case "product_code":
2025-08-18 18:18:04 +08:00
return fmt.Sprintf("%s格式不正确只能包含字母、数字、下划线、连字符、中英文括号长度3-50位", fieldDisplayName)
2025-07-20 20:53:26 +08:00
case "uuid":
return fmt.Sprintf("%s必须是有效的UUID格式", fieldDisplayName)
case "url":
return fmt.Sprintf("%s必须是有效的URL地址", fieldDisplayName)
case "oneof":
return fmt.Sprintf("%s必须是以下值之一: %s", fieldDisplayName, param)
case "gt":
return fmt.Sprintf("%s必须大于%s", fieldDisplayName, param)
default:
return fmt.Sprintf("%s格式不正确", fieldDisplayName)
}
}
// getFieldName returns the failing field's name in snake_case, which is the
// key used in the validation-error response map.
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
	return v.toSnakeCase(fieldError.Field())
}
// toSnakeCase converts a Go CamelCase identifier to snake_case and keeps
// acronyms together:
//
//	"UserName"   -> "user_name"
//	"UserID"     -> "user_id"
//	"IDCard"     -> "id_card"
//	"HTTPServer" -> "http_server"
//
// An underscore is inserted before an ASCII uppercase letter in two cases:
//   - It starts a new word: the previous rune is neither uppercase nor '_'.
//   - It is the last capital of an acronym and is followed by a lowercase
//     letter.
//
// The result is lower-cased with strings.ToLower.
func (v *RequestValidator) toSnakeCase(str string) string {
	runes := []rune(str)
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }

	var b strings.Builder
	b.Grow(len(str) + len(str)/2)
	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			startsWord := !isUpper(prev) && prev != '_'
			endsAcronym := isUpper(prev) && i+1 < len(runes) && isLower(runes[i+1])
			if startsWord || endsAcronym {
				b.WriteByte('_')
			}
		}
		b.WriteRune(r)
	}
	return strings.ToLower(b.String())
}
// GetValidator exposes the validator this RequestValidator was built with, so
// business code can run its own checks against the same registered rules.
func (v *RequestValidator) GetValidator() *validator.Validate {
	engine := v.validator
	return engine
}
// ValidateValue checks a single value against a validation tag expression
// (e.g. "required,email") for business-logic use. It returns the validator's
// error as-is.
func (v *RequestValidator) ValidateValue(field interface{}, tag string) error {
	engine := v.validator
	return engine.Var(field, tag)
}
// ValidateStruct 验证结构体(用于业务逻辑)
func (v *RequestValidator) ValidateStruct(s interface{}) error {
return v.validator.Struct(s)
2025-07-28 01:46:39 +08:00
}