package validator

import (
	"errors"
	"fmt"
	"reflect"
	"strings"
	"unicode"

	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"github.com/go-playground/locales/zh"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	zh_translations "github.com/go-playground/validator/v10/translations/zh"

	"tyapi-server/internal/shared/interfaces"
)

// RequestValidator binds gin request data (JSON body, query string, URI
// parameters) into DTOs, validates them with the validator engine shared with
// gin's binding package, and writes an error response through the configured
// ResponseBuilder when binding or validation fails.
type RequestValidator struct {
	response   interfaces.ResponseBuilder
	translator ut.Translator
	validator  *validator.Validate
}

// NewRequestValidator creates a RequestValidator backed by gin's global
// validator engine. It registers the built-in Chinese translations plus the
// package's custom validators and translations on that engine, so the same
// rules and messages apply to gin's own binding as well.
//
// Note: the engine is process-wide; every call re-registers on the same
// instance.
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
	zhLocale := zh.New()
	uni := ut.New(zhLocale, zhLocale)

	// "zh" is the fallback locale registered just above, so the lookup always
	// finds it and the found flag carries no information.
	trans, _ := uni.GetTranslator("zh")

	// gin's default binding validator exposes a *validator.Validate engine.
	ginValidator := binding.Validator.Engine().(*validator.Validate)

	// Best effort: if the built-in translations fail to register,
	// FieldError.Translate returns the raw error text and getErrorMessage
	// falls back to its hand-written messages, so the error is not fatal.
	_ = zh_translations.RegisterDefaultTranslations(ginValidator, trans)

	RegisterCustomValidators(ginValidator)
	RegisterCustomTranslations(ginValidator, trans)

	return &RequestValidator{
		response:   response,
		translator: trans,
		validator:  ginValidator,
	}
}

// Validate binds the JSON request body into dto and validates it.
// It is equivalent to BindAndValidate.
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
	return v.BindAndValidate(c, dto)
}

// ValidateQuery binds the URL query string into dto and validates it.
// On failure it writes an error response and returns the binding error.
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
	return v.handleBindError(c, c.ShouldBindQuery(dto), "查询参数格式错误")
}

// ValidateParam binds the URI path parameters into dto and validates them.
// On failure it writes an error response and returns the binding error.
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
	return v.handleBindError(c, c.ShouldBindUri(dto), "路径参数格式错误")
}

// BindAndValidate binds the JSON request body into dto and validates it.
// On failure it writes an error response and returns the binding error.
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
	return v.handleBindError(c, c.ShouldBindJSON(dto), "请求体格式错误")
}

// handleBindError turns the result of a gin ShouldBind* call into an HTTP
// response. It returns nil (writing nothing) when err is nil. Otherwise it
// writes a per-field ValidationError response if err is, or wraps,
// validator.ValidationErrors, and a BadRequest with badRequestMsg for any
// other error. The original err is always returned unchanged.
func (v *RequestValidator) handleBindError(c *gin.Context, err error, badRequestMsg string) error {
	if err == nil {
		return nil
	}

	var validationErrors validator.ValidationErrors
	if errors.As(err, &validationErrors) {
		v.response.ValidationError(c, v.formatValidationErrors(validationErrors))
	} else {
		v.response.BadRequest(c, badRequestMsg)
	}
	return err
}

// formatValidationErrors maps each failing field (as a snake_case key) to its
// messages. Errors whose fields map to the same key are appended in the order
// the validator reported them.
func (v *RequestValidator) formatValidationErrors(validationErrors validator.ValidationErrors) map[string][]string {
	result := make(map[string][]string, len(validationErrors))
	for _, fieldError := range validationErrors {
		key := v.getFieldName(fieldError)
		result[key] = append(result[key], v.getErrorMessage(fieldError))
	}
	return result
}

// getErrorMessage returns the user-facing message for one field error.
// It prefers the registered translation (with the Go field name rewritten to
// its display name). When no translation exists for the tag, it falls back to
// getFallbackErrorMessage.
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
	field := fieldError.Field()
	fieldDisplayName := getFieldDisplayName(field)

	// Translate returns fieldError.Error() verbatim when no translation is
	// registered for the tag; that equality is how a miss is detected.
	errorMessage := fieldError.Translate(v.translator)
	if errorMessage == fieldError.Error() {
		return v.getFallbackErrorMessage(fieldError, fieldDisplayName)
	}

	// Rewrite only the first occurrence, which is where the translations
	// normally render the field name. Cross-field messages (e.g. eqfield) also
	// embed another field's name, which may contain this field's name as a
	// substring and must be left intact. An empty field name is skipped
	// because replacing "" would inject text into the message.
	if field != "" && fieldDisplayName != field {
		errorMessage = strings.Replace(errorMessage, field, fieldDisplayName, 1)
	}
	return errorMessage
}

// getFallbackErrorMessage builds a hand-written Chinese message for tags that
// have no registered translation. fieldDisplayName is the human-readable name
// of the failing field. Unknown tags yield a generic "format incorrect"
// message.
func (v *RequestValidator) getFallbackErrorMessage(fieldError validator.FieldError, fieldDisplayName string) string {
	tag := fieldError.Tag()
	param := fieldError.Param()

	// For numeric fields, min/max/len constrain the value itself rather than
	// a length, so the length-based wording would be misleading.
	numeric := false
	switch fieldError.Kind() {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr,
		reflect.Float32, reflect.Float64:
		numeric = true
	}

	switch tag {
	case "required":
		return fmt.Sprintf("%s不能为空", fieldDisplayName)
	case "email":
		return fmt.Sprintf("%s必须是有效的邮箱地址", fieldDisplayName)
	case "min":
		if numeric {
			return fmt.Sprintf("%s不能小于%s", fieldDisplayName, param)
		}
		return fmt.Sprintf("%s长度不能少于%s位", fieldDisplayName, param)
	case "max":
		if numeric {
			return fmt.Sprintf("%s不能大于%s", fieldDisplayName, param)
		}
		return fmt.Sprintf("%s长度不能超过%s位", fieldDisplayName, param)
	case "len":
		if numeric {
			return fmt.Sprintf("%s必须等于%s", fieldDisplayName, param)
		}
		return fmt.Sprintf("%s长度必须为%s位", fieldDisplayName, param)
	case "eqfield":
		paramDisplayName := getFieldDisplayName(param)
		return fmt.Sprintf("%s必须与%s一致", fieldDisplayName, paramDisplayName)
	case "phone":
		return fmt.Sprintf("%s必须是有效的手机号", fieldDisplayName)
	case "username":
		return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线,且必须以字母开头,长度3-20位", fieldDisplayName)
	case "strong_password":
		return fmt.Sprintf("%s强度不足,必须包含大小写字母和数字,且不少于8位", fieldDisplayName)
	case "social_credit_code":
		return fmt.Sprintf("%s格式不正确,必须是18位统一社会信用代码", fieldDisplayName)
	case "id_card":
		return fmt.Sprintf("%s格式不正确,必须是18位身份证号", fieldDisplayName)
	case "price":
		return fmt.Sprintf("%s必须是非负数", fieldDisplayName)
	case "sort_order":
		return fmt.Sprintf("%s必须是 asc 或 desc", fieldDisplayName)
	case "product_code":
		return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线、连字符、中英文括号,长度3-50位", fieldDisplayName)
	case "uuid":
		return fmt.Sprintf("%s必须是有效的UUID格式", fieldDisplayName)
	case "url":
		return fmt.Sprintf("%s必须是有效的URL地址", fieldDisplayName)
	case "oneof":
		return fmt.Sprintf("%s必须是以下值之一: %s", fieldDisplayName, param)
	case "gt":
		return fmt.Sprintf("%s必须大于%s", fieldDisplayName, param)
	default:
		return fmt.Sprintf("%s格式不正确", fieldDisplayName)
	}
}

// getFieldName returns the response key for a field error: the validator's
// field name converted to snake_case.
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
	return v.toSnakeCase(fieldError.Field())
}

// toSnakeCase converts an identifier to snake_case, keeping acronyms together:
// "UserName" -> "user_name", "UserID" -> "user_id", "IDCard" -> "id_card".
// Already-lowercase input (e.g. "id_card") is returned unchanged.
func (v *RequestValidator) toSnakeCase(str string) string {
	runes := []rune(str)
	var b strings.Builder
	for i, r := range runes {
		if i > 0 && unicode.IsUpper(r) {
			prevUpper := unicode.IsUpper(runes[i-1])
			nextLower := i+1 < len(runes) && unicode.IsLower(runes[i+1])
			// Start a new word at a lower->upper transition, or at the last
			// capital of an acronym that is followed by a lowercase word.
			if !prevUpper || nextLower {
				b.WriteByte('_')
			}
		}
		b.WriteRune(unicode.ToLower(r))
	}
	return b.String()
}

// GetValidator returns the underlying validator engine for use in business
// logic. It is the same engine gin's binding uses.
func (v *RequestValidator) GetValidator() *validator.Validate {
	return v.validator
}

// ValidateValue validates a single value against a validator tag expression
// (e.g. "required,email"), for use in business logic.
func (v *RequestValidator) ValidateValue(field interface{}, tag string) error {
	return v.validator.Var(field, tag)
}

// ValidateStruct validates s using its struct tags, for use in business logic.
func (v *RequestValidator) ValidateStruct(s interface{}) error {
	return v.validator.Struct(s)
}