254 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			254 lines
		
	
	
		
			7.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package validator
 | ||
| 
 | ||
import (
	"errors"
	"fmt"
	"strings"
	"sync"

	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"github.com/go-playground/locales/zh"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	zh_translations "github.com/go-playground/validator/v10/translations/zh"

	"tyapi-server/internal/shared/interfaces"
)
 | ||
| 
 | ||
| // 全局变量声明
 | ||
| var (
 | ||
| 	globalValidator  *validator.Validate
 | ||
| 	globalTranslator ut.Translator
 | ||
| 	once             sync.Once
 | ||
| )
 | ||
| 
 | ||
| // InitGlobalValidator 初始化全局校验器(线程安全)
 | ||
| func InitGlobalValidator() {
 | ||
| 	once.Do(func() {
 | ||
| 		// 1. 创建新的校验器实例
 | ||
| 		globalValidator = validator.New()
 | ||
| 		
 | ||
| 		// 2. 创建中文翻译器
 | ||
| 		zhLocale := zh.New()
 | ||
| 		uni := ut.New(zhLocale, zhLocale)
 | ||
| 		globalTranslator, _ = uni.GetTranslator("zh")
 | ||
| 		
 | ||
| 		// 3. 注册官方中文翻译
 | ||
| 		zh_translations.RegisterDefaultTranslations(globalValidator, globalTranslator)
 | ||
| 		
 | ||
| 		// 4. 注册自定义校验规则
 | ||
| 		RegisterCustomValidators(globalValidator)
 | ||
| 		
 | ||
| 		// 5. 注册自定义中文翻译
 | ||
| 		RegisterCustomTranslations(globalValidator, globalTranslator)
 | ||
| 		
 | ||
| 		// 6. 设置到Gin全局校验器(确保Gin使用我们的校验器)
 | ||
| 		if binding.Validator.Engine() != nil {
 | ||
| 			// 如果Gin已经初始化,则替换其校验器
 | ||
| 			ginValidator := binding.Validator.Engine().(*validator.Validate)
 | ||
| 			*ginValidator = *globalValidator
 | ||
| 		}
 | ||
| 	})
 | ||
| }
 | ||
| 
 | ||
| // GetGlobalValidator 获取全局校验器实例
 | ||
| func GetGlobalValidator() *validator.Validate {
 | ||
| 	if globalValidator == nil {
 | ||
| 		InitGlobalValidator()
 | ||
| 	}
 | ||
| 	return globalValidator
 | ||
| }
 | ||
| 
 | ||
| // GetGlobalTranslator 获取全局翻译器实例
 | ||
| func GetGlobalTranslator() ut.Translator {
 | ||
| 	if globalTranslator == nil {
 | ||
| 		InitGlobalValidator()
 | ||
| 	}
 | ||
| 	return globalTranslator
 | ||
| }
 | ||
| 
 | ||
| // RequestValidator HTTP请求验证器
 | ||
| type RequestValidator struct {
 | ||
| 	response   interfaces.ResponseBuilder
 | ||
| 	translator ut.Translator
 | ||
| 	validator  *validator.Validate
 | ||
| }
 | ||
| 
 | ||
| // NewRequestValidator 创建HTTP请求验证器
 | ||
| func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
 | ||
| 	// 确保全局校验器已初始化
 | ||
| 	InitGlobalValidator()
 | ||
| 	
 | ||
| 	return &RequestValidator{
 | ||
| 		response:   response,
 | ||
| 		translator: globalTranslator,  // 使用全局翻译器
 | ||
| 		validator:  globalValidator,   // 使用全局校验器
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // Validate 验证请求体
 | ||
| func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
 | ||
| 	return v.BindAndValidate(c, dto)
 | ||
| }
 | ||
| 
 | ||
| // ValidateQuery 验证查询参数
 | ||
| func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
 | ||
| 	if err := c.ShouldBindQuery(dto); err != nil {
 | ||
| 		if validationErrors, ok := err.(validator.ValidationErrors); ok {
 | ||
| 			validationErrorsMap := v.formatValidationErrors(validationErrors)
 | ||
| 			v.response.ValidationError(c, validationErrorsMap)
 | ||
| 		} else {
 | ||
| 			v.response.BadRequest(c, "查询参数格式错误")
 | ||
| 		}
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // ValidateParam 验证路径参数
 | ||
| func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
 | ||
| 	if err := c.ShouldBindUri(dto); err != nil {
 | ||
| 		if validationErrors, ok := err.(validator.ValidationErrors); ok {
 | ||
| 			validationErrorsMap := v.formatValidationErrors(validationErrors)
 | ||
| 			v.response.ValidationError(c, validationErrorsMap)
 | ||
| 		} else {
 | ||
| 			v.response.BadRequest(c, "路径参数格式错误")
 | ||
| 		}
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // BindAndValidate 绑定并验证请求
 | ||
| func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
 | ||
| 	if err := c.ShouldBindJSON(dto); err != nil {
 | ||
| 		if validationErrors, ok := err.(validator.ValidationErrors); ok {
 | ||
| 			validationErrorsMap := v.formatValidationErrors(validationErrors)
 | ||
| 			v.response.ValidationError(c, validationErrorsMap)
 | ||
| 		} else {
 | ||
| 			v.response.BadRequest(c, "请求体格式错误")
 | ||
| 		}
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // formatValidationErrors 格式化验证错误
 | ||
| func (v *RequestValidator) formatValidationErrors(validationErrors validator.ValidationErrors) map[string][]string {
 | ||
| 	errors := make(map[string][]string)
 | ||
| 
 | ||
| 	for _, fieldError := range validationErrors {
 | ||
| 		fieldName := v.getFieldName(fieldError)
 | ||
| 		errorMessage := v.getErrorMessage(fieldError)
 | ||
| 
 | ||
| 		if _, exists := errors[fieldName]; !exists {
 | ||
| 			errors[fieldName] = []string{}
 | ||
| 		}
 | ||
| 		errors[fieldName] = append(errors[fieldName], errorMessage)
 | ||
| 	}
 | ||
| 
 | ||
| 	return errors
 | ||
| }
 | ||
| 
 | ||
| // getErrorMessage 获取错误消息
 | ||
| func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
 | ||
| 	fieldDisplayName := getFieldDisplayName(fieldError.Field())
 | ||
| 
 | ||
| 	// 优先使用翻译器
 | ||
| 	errorMessage := fieldError.Translate(v.translator)
 | ||
| 	if errorMessage != fieldError.Error() {
 | ||
| 		// 替换字段名为中文
 | ||
| 		if fieldDisplayName != fieldError.Field() {
 | ||
| 			errorMessage = strings.ReplaceAll(errorMessage, fieldError.Field(), fieldDisplayName)
 | ||
| 		}
 | ||
| 		return errorMessage
 | ||
| 	}
 | ||
| 
 | ||
| 	// 回退到手动翻译
 | ||
| 	return v.getFallbackErrorMessage(fieldError, fieldDisplayName)
 | ||
| }
 | ||
| 
 | ||
| // getFallbackErrorMessage 获取回退错误消息
 | ||
| func (v *RequestValidator) getFallbackErrorMessage(fieldError validator.FieldError, fieldDisplayName string) string {
 | ||
| 	tag := fieldError.Tag()
 | ||
| 	param := fieldError.Param()
 | ||
| 
 | ||
| 	switch tag {
 | ||
| 	case "required":
 | ||
| 		return fmt.Sprintf("%s不能为空", fieldDisplayName)
 | ||
| 	case "email":
 | ||
| 		return fmt.Sprintf("%s必须是有效的邮箱地址", fieldDisplayName)
 | ||
| 	case "min":
 | ||
| 		return fmt.Sprintf("%s长度不能少于%s位", fieldDisplayName, param)
 | ||
| 	case "max":
 | ||
| 		return fmt.Sprintf("%s长度不能超过%s位", fieldDisplayName, param)
 | ||
| 	case "len":
 | ||
| 		return fmt.Sprintf("%s长度必须为%s位", fieldDisplayName, param)
 | ||
| 	case "eqfield":
 | ||
| 		paramDisplayName := getFieldDisplayName(param)
 | ||
| 		return fmt.Sprintf("%s必须与%s一致", fieldDisplayName, paramDisplayName)
 | ||
| 	case "phone":
 | ||
| 		return fmt.Sprintf("%s必须是有效的手机号", fieldDisplayName)
 | ||
| 	case "username":
 | ||
| 		return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线,且必须以字母开头,长度3-20位", fieldDisplayName)
 | ||
| 	case "strong_password":
 | ||
| 		return fmt.Sprintf("%s强度不足,必须包含大小写字母和数字,且不少于8位", fieldDisplayName)
 | ||
| 	case "social_credit_code":
 | ||
| 		return fmt.Sprintf("%s格式不正确,必须是18位统一社会信用代码", fieldDisplayName)
 | ||
| 	case "id_card":
 | ||
| 		return fmt.Sprintf("%s格式不正确,必须是18位身份证号", fieldDisplayName)
 | ||
| 	case "price":
 | ||
| 		return fmt.Sprintf("%s必须是非负数", fieldDisplayName)
 | ||
| 	case "sort_order":
 | ||
| 		return fmt.Sprintf("%s必须是 asc 或 desc", fieldDisplayName)
 | ||
| 	case "product_code":
 | ||
| 		return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线、连字符、中英文括号,长度3-50位", fieldDisplayName)
 | ||
| 	case "uuid":
 | ||
| 		return fmt.Sprintf("%s必须是有效的UUID格式", fieldDisplayName)
 | ||
| 	case "url":
 | ||
| 		return fmt.Sprintf("%s必须是有效的URL地址", fieldDisplayName)
 | ||
| 	case "oneof":
 | ||
| 		return fmt.Sprintf("%s必须是以下值之一: %s", fieldDisplayName, param)
 | ||
| 	case "gt":
 | ||
| 		return fmt.Sprintf("%s必须大于%s", fieldDisplayName, param)
 | ||
| 	default:
 | ||
| 		return fmt.Sprintf("%s格式不正确", fieldDisplayName)
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // getFieldName 获取字段名
 | ||
| func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
 | ||
| 	fieldName := fieldError.Field()
 | ||
| 	return v.toSnakeCase(fieldName)
 | ||
| }
 | ||
| 
 | ||
| // toSnakeCase 转换为snake_case
 | ||
| func (v *RequestValidator) toSnakeCase(str string) string {
 | ||
| 	var result strings.Builder
 | ||
| 	for i, r := range str {
 | ||
| 		if i > 0 && (r >= 'A' && r <= 'Z') {
 | ||
| 			result.WriteRune('_')
 | ||
| 		}
 | ||
| 		result.WriteRune(r)
 | ||
| 	}
 | ||
| 	return strings.ToLower(result.String())
 | ||
| }
 | ||
| 
 | ||
| // GetValidator 获取validator实例(用于业务逻辑)
 | ||
| func (v *RequestValidator) GetValidator() *validator.Validate {
 | ||
| 	return v.validator
 | ||
| }
 | ||
| 
 | ||
| // ValidateValue 验证单个值(用于业务逻辑)
 | ||
| func (v *RequestValidator) ValidateValue(field interface{}, tag string) error {
 | ||
| 	return v.validator.Var(field, tag)
 | ||
| }
 | ||
| 
 | ||
| // ValidateStruct 验证结构体(用于业务逻辑)
 | ||
| func (v *RequestValidator) ValidateStruct(s interface{}) error {
 | ||
| 	return v.validator.Struct(s)
 | ||
| }
 |