| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | package validator | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | import ( | 
					
						
							|  |  |  |  | 	"fmt" | 
					
						
							|  |  |  |  | 	"strings" | 
					
						
							| 
									
										
										
										
											2025-08-27 22:19:19 +08:00
										 |  |  |  | 	"sync" | 
					
						
							| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | 
 | 
					
						
							|  |  |  |  | 	"tyapi-server/internal/shared/interfaces" | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	"github.com/gin-gonic/gin" | 
					
						
							|  |  |  |  | 	"github.com/gin-gonic/gin/binding" | 
					
						
							|  |  |  |  | 	"github.com/go-playground/locales/zh" | 
					
						
							|  |  |  |  | 	ut "github.com/go-playground/universal-translator" | 
					
						
							|  |  |  |  | 	"github.com/go-playground/validator/v10" | 
					
						
							|  |  |  |  | 	zh_translations "github.com/go-playground/validator/v10/translations/zh" | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-08-27 22:19:19 +08:00
										 |  |  |  | // 全局变量声明 | 
					
						
							|  |  |  |  | var ( | 
					
						
							|  |  |  |  | 	globalValidator  *validator.Validate | 
					
						
							|  |  |  |  | 	globalTranslator ut.Translator | 
					
						
							|  |  |  |  | 	once             sync.Once | 
					
						
							|  |  |  |  | ) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // InitGlobalValidator 初始化全局校验器(线程安全) | 
					
						
							|  |  |  |  | func InitGlobalValidator() { | 
					
						
							|  |  |  |  | 	once.Do(func() { | 
					
						
							|  |  |  |  | 		// 1. 创建新的校验器实例 | 
					
						
							|  |  |  |  | 		globalValidator = validator.New() | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 2. 创建中文翻译器 | 
					
						
							|  |  |  |  | 		zhLocale := zh.New() | 
					
						
							|  |  |  |  | 		uni := ut.New(zhLocale, zhLocale) | 
					
						
							|  |  |  |  | 		globalTranslator, _ = uni.GetTranslator("zh") | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 3. 注册官方中文翻译 | 
					
						
							|  |  |  |  | 		zh_translations.RegisterDefaultTranslations(globalValidator, globalTranslator) | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 4. 注册自定义校验规则 | 
					
						
							|  |  |  |  | 		RegisterCustomValidators(globalValidator) | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 5. 注册自定义中文翻译 | 
					
						
							|  |  |  |  | 		RegisterCustomTranslations(globalValidator, globalTranslator) | 
					
						
							|  |  |  |  | 		 | 
					
						
							|  |  |  |  | 		// 6. 设置到Gin全局校验器(确保Gin使用我们的校验器) | 
					
						
							|  |  |  |  | 		if binding.Validator.Engine() != nil { | 
					
						
							|  |  |  |  | 			// 如果Gin已经初始化,则替换其校验器 | 
					
						
							|  |  |  |  | 			ginValidator := binding.Validator.Engine().(*validator.Validate) | 
					
						
							|  |  |  |  | 			*ginValidator = *globalValidator | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 	}) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetGlobalValidator 获取全局校验器实例 | 
					
						
							|  |  |  |  | func GetGlobalValidator() *validator.Validate { | 
					
						
							|  |  |  |  | 	if globalValidator == nil { | 
					
						
							|  |  |  |  | 		InitGlobalValidator() | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	return globalValidator | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetGlobalTranslator 获取全局翻译器实例 | 
					
						
							|  |  |  |  | func GetGlobalTranslator() ut.Translator { | 
					
						
							|  |  |  |  | 	if globalTranslator == nil { | 
					
						
							|  |  |  |  | 		InitGlobalValidator() | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	return globalTranslator | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | // RequestValidator HTTP请求验证器 | 
					
						
							|  |  |  |  | type RequestValidator struct { | 
					
						
							|  |  |  |  | 	response   interfaces.ResponseBuilder | 
					
						
							|  |  |  |  | 	translator ut.Translator | 
					
						
							|  |  |  |  | 	validator  *validator.Validate | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // NewRequestValidator 创建HTTP请求验证器 | 
					
						
							|  |  |  |  | func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator { | 
					
						
							| 
									
										
										
										
											2025-08-27 22:19:19 +08:00
										 |  |  |  | 	// 确保全局校验器已初始化 | 
					
						
							|  |  |  |  | 	InitGlobalValidator() | 
					
						
							|  |  |  |  | 	 | 
					
						
							| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | 	return &RequestValidator{ | 
					
						
							|  |  |  |  | 		response:   response, | 
					
						
							| 
									
										
										
										
											2025-08-27 22:19:19 +08:00
										 |  |  |  | 		translator: globalTranslator,  // 使用全局翻译器 | 
					
						
							|  |  |  |  | 		validator:  globalValidator,   // 使用全局校验器 | 
					
						
							| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // Validate 验证请求体 | 
					
						
							|  |  |  |  | func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error { | 
					
						
							|  |  |  |  | 	return v.BindAndValidate(c, dto) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // ValidateQuery 验证查询参数 | 
					
						
							|  |  |  |  | func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error { | 
					
						
							|  |  |  |  | 	if err := c.ShouldBindQuery(dto); err != nil { | 
					
						
							|  |  |  |  | 		if validationErrors, ok := err.(validator.ValidationErrors); ok { | 
					
						
							|  |  |  |  | 			validationErrorsMap := v.formatValidationErrors(validationErrors) | 
					
						
							|  |  |  |  | 			v.response.ValidationError(c, validationErrorsMap) | 
					
						
							|  |  |  |  | 		} else { | 
					
						
							|  |  |  |  | 			v.response.BadRequest(c, "查询参数格式错误") | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return err | 
					
						
							|  |  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2025-07-28 01:46:39 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // ValidateParam 验证路径参数 | 
					
						
							|  |  |  |  | func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error { | 
					
						
							|  |  |  |  | 	if err := c.ShouldBindUri(dto); err != nil { | 
					
						
							|  |  |  |  | 		if validationErrors, ok := err.(validator.ValidationErrors); ok { | 
					
						
							|  |  |  |  | 			validationErrorsMap := v.formatValidationErrors(validationErrors) | 
					
						
							|  |  |  |  | 			v.response.ValidationError(c, validationErrorsMap) | 
					
						
							|  |  |  |  | 		} else { | 
					
						
							|  |  |  |  | 			v.response.BadRequest(c, "路径参数格式错误") | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return err | 
					
						
							|  |  |  |  | 	} | 
					
						
							| 
									
										
										
										
											2025-07-28 01:46:39 +08:00
										 |  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // BindAndValidate 绑定并验证请求 | 
					
						
							|  |  |  |  | func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error { | 
					
						
							|  |  |  |  | 	if err := c.ShouldBindJSON(dto); err != nil { | 
					
						
							|  |  |  |  | 		if validationErrors, ok := err.(validator.ValidationErrors); ok { | 
					
						
							|  |  |  |  | 			validationErrorsMap := v.formatValidationErrors(validationErrors) | 
					
						
							|  |  |  |  | 			v.response.ValidationError(c, validationErrorsMap) | 
					
						
							|  |  |  |  | 		} else { | 
					
						
							|  |  |  |  | 			v.response.BadRequest(c, "请求体格式错误") | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return err | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	return nil | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // formatValidationErrors 格式化验证错误 | 
					
						
							|  |  |  |  | func (v *RequestValidator) formatValidationErrors(validationErrors validator.ValidationErrors) map[string][]string { | 
					
						
							|  |  |  |  | 	errors := make(map[string][]string) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	for _, fieldError := range validationErrors { | 
					
						
							|  |  |  |  | 		fieldName := v.getFieldName(fieldError) | 
					
						
							|  |  |  |  | 		errorMessage := v.getErrorMessage(fieldError) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 		if _, exists := errors[fieldName]; !exists { | 
					
						
							|  |  |  |  | 			errors[fieldName] = []string{} | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		errors[fieldName] = append(errors[fieldName], errorMessage) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	return errors | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getErrorMessage 获取错误消息 | 
					
						
							|  |  |  |  | func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string { | 
					
						
							|  |  |  |  | 	fieldDisplayName := getFieldDisplayName(fieldError.Field()) | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 优先使用翻译器 | 
					
						
							|  |  |  |  | 	errorMessage := fieldError.Translate(v.translator) | 
					
						
							|  |  |  |  | 	if errorMessage != fieldError.Error() { | 
					
						
							|  |  |  |  | 		// 替换字段名为中文 | 
					
						
							|  |  |  |  | 		if fieldDisplayName != fieldError.Field() { | 
					
						
							|  |  |  |  | 			errorMessage = strings.ReplaceAll(errorMessage, fieldError.Field(), fieldDisplayName) | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		return errorMessage | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	// 回退到手动翻译 | 
					
						
							|  |  |  |  | 	return v.getFallbackErrorMessage(fieldError, fieldDisplayName) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getFallbackErrorMessage 获取回退错误消息 | 
					
						
							|  |  |  |  | func (v *RequestValidator) getFallbackErrorMessage(fieldError validator.FieldError, fieldDisplayName string) string { | 
					
						
							|  |  |  |  | 	tag := fieldError.Tag() | 
					
						
							|  |  |  |  | 	param := fieldError.Param() | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | 	switch tag { | 
					
						
							|  |  |  |  | 	case "required": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s不能为空", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "email": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须是有效的邮箱地址", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "min": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s长度不能少于%s位", fieldDisplayName, param) | 
					
						
							|  |  |  |  | 	case "max": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s长度不能超过%s位", fieldDisplayName, param) | 
					
						
							|  |  |  |  | 	case "len": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s长度必须为%s位", fieldDisplayName, param) | 
					
						
							|  |  |  |  | 	case "eqfield": | 
					
						
							|  |  |  |  | 		paramDisplayName := getFieldDisplayName(param) | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须与%s一致", fieldDisplayName, paramDisplayName) | 
					
						
							|  |  |  |  | 	case "phone": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须是有效的手机号", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "username": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线,且必须以字母开头,长度3-20位", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "strong_password": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s强度不足,必须包含大小写字母和数字,且不少于8位", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "social_credit_code": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s格式不正确,必须是18位统一社会信用代码", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "id_card": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s格式不正确,必须是18位身份证号", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "price": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须是非负数", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "sort_order": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须是 asc 或 desc", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "product_code": | 
					
						
							| 
									
										
										
										
											2025-08-18 18:18:04 +08:00
										 |  |  |  | 		return fmt.Sprintf("%s格式不正确,只能包含字母、数字、下划线、连字符、中英文括号,长度3-50位", fieldDisplayName) | 
					
						
							| 
									
										
										
										
											2025-07-20 20:53:26 +08:00
										 |  |  |  | 	case "uuid": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须是有效的UUID格式", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "url": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须是有效的URL地址", fieldDisplayName) | 
					
						
							|  |  |  |  | 	case "oneof": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须是以下值之一: %s", fieldDisplayName, param) | 
					
						
							|  |  |  |  | 	case "gt": | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s必须大于%s", fieldDisplayName, param) | 
					
						
							|  |  |  |  | 	default: | 
					
						
							|  |  |  |  | 		return fmt.Sprintf("%s格式不正确", fieldDisplayName) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // getFieldName 获取字段名 | 
					
						
							|  |  |  |  | func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string { | 
					
						
							|  |  |  |  | 	fieldName := fieldError.Field() | 
					
						
							|  |  |  |  | 	return v.toSnakeCase(fieldName) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // toSnakeCase 转换为snake_case | 
					
						
							|  |  |  |  | func (v *RequestValidator) toSnakeCase(str string) string { | 
					
						
							|  |  |  |  | 	var result strings.Builder | 
					
						
							|  |  |  |  | 	for i, r := range str { | 
					
						
							|  |  |  |  | 		if i > 0 && (r >= 'A' && r <= 'Z') { | 
					
						
							|  |  |  |  | 			result.WriteRune('_') | 
					
						
							|  |  |  |  | 		} | 
					
						
							|  |  |  |  | 		result.WriteRune(r) | 
					
						
							|  |  |  |  | 	} | 
					
						
							|  |  |  |  | 	return strings.ToLower(result.String()) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // GetValidator 获取validator实例(用于业务逻辑) | 
					
						
							|  |  |  |  | func (v *RequestValidator) GetValidator() *validator.Validate { | 
					
						
							|  |  |  |  | 	return v.validator | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // ValidateValue 验证单个值(用于业务逻辑) | 
					
						
							|  |  |  |  | func (v *RequestValidator) ValidateValue(field interface{}, tag string) error { | 
					
						
							|  |  |  |  | 	return v.validator.Var(field, tag) | 
					
						
							|  |  |  |  | } | 
					
						
							|  |  |  |  | 
 | 
					
						
							|  |  |  |  | // ValidateStruct 验证结构体(用于业务逻辑) | 
					
						
							|  |  |  |  | func (v *RequestValidator) ValidateStruct(s interface{}) error { | 
					
						
							|  |  |  |  | 	return v.validator.Struct(s) | 
					
						
							| 
									
										
										
										
											2025-07-28 01:46:39 +08:00
										 |  |  |  | } |