package http import ( "fmt" "strings" "tyapi-server/internal/shared/interfaces" "github.com/gin-gonic/gin" "github.com/go-playground/validator/v10" ) // RequestValidator 请求验证器实现 type RequestValidator struct { validator *validator.Validate response interfaces.ResponseBuilder } // NewRequestValidator 创建请求验证器 func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator { v := validator.New() // 注册自定义验证器 registerCustomValidators(v) return &RequestValidator{ validator: v, response: response, } } // Validate 验证请求体 func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error { if err := v.validator.Struct(dto); err != nil { validationErrors := v.formatValidationErrors(err) v.response.BadRequest(c, "Validation failed", validationErrors) return err } return nil } // ValidateQuery 验证查询参数 func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error { if err := c.ShouldBindQuery(dto); err != nil { v.response.BadRequest(c, "查询参数格式错误", err.Error()) return err } if err := v.validator.Struct(dto); err != nil { validationErrors := v.formatValidationErrors(err) v.response.ValidationError(c, validationErrors) return err } return nil } // ValidateParam 验证路径参数 func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error { if err := c.ShouldBindUri(dto); err != nil { v.response.BadRequest(c, "路径参数格式错误", err.Error()) return err } if err := v.validator.Struct(dto); err != nil { validationErrors := v.formatValidationErrors(err) v.response.ValidationError(c, validationErrors) return err } return nil } // BindAndValidate 绑定并验证请求 func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error { // 绑定请求体 if err := c.ShouldBindJSON(dto); err != nil { v.response.BadRequest(c, "请求体格式错误", err.Error()) return err } // 验证数据 return v.Validate(c, dto) } // formatValidationErrors 格式化验证错误 func (v *RequestValidator) formatValidationErrors(err error) map[string][]string { errors := make(map[string][]string) if validationErrors, ok := err.(validator.ValidationErrors); ok { for _, fieldError := range validationErrors { fieldName := v.getFieldName(fieldError) errorMessage := v.getErrorMessage(fieldError) if _, exists := errors[fieldName]; !exists { errors[fieldName] = []string{} } errors[fieldName] = append(errors[fieldName], errorMessage) } } return errors } // getFieldName 获取字段名(JSON标签优先) func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string { // 可以通过反射获取JSON标签,这里简化处理 fieldName := fieldError.Field() // 转换为snake_case(可选) return v.toSnakeCase(fieldName) } // getErrorMessage 获取错误消息 func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string { field := fieldError.Field() tag := fieldError.Tag() param := fieldError.Param() fieldDisplayName := v.getFieldDisplayName(field) switch tag { case "required": return fmt.Sprintf("%s 不能为空", fieldDisplayName) case "email": return fmt.Sprintf("%s 必须是有效的邮箱地址", fieldDisplayName) case "min": return fmt.Sprintf("%s 长度不能少于 %s 位", fieldDisplayName, param) case "max": return fmt.Sprintf("%s 长度不能超过 %s 位", fieldDisplayName, param) case "len": return fmt.Sprintf("%s 长度必须为 %s 位", fieldDisplayName, param) case "gt": return fmt.Sprintf("%s 必须大于 %s", fieldDisplayName, param) case "gte": return fmt.Sprintf("%s 必须大于等于 %s", fieldDisplayName, param) case "lt": return fmt.Sprintf("%s 必须小于 %s", fieldDisplayName, param) case "lte": return fmt.Sprintf("%s 必须小于等于 %s", fieldDisplayName, param) case "oneof": return fmt.Sprintf("%s 必须是以下值之一:[%s]", fieldDisplayName, param) case "url": return fmt.Sprintf("%s 必须是有效的URL地址", fieldDisplayName) case "alpha": return fmt.Sprintf("%s 只能包含字母", fieldDisplayName) case "alphanum": return fmt.Sprintf("%s 只能包含字母和数字", fieldDisplayName) case "numeric": return fmt.Sprintf("%s 必须是数字", fieldDisplayName) case "phone": return fmt.Sprintf("%s 必须是有效的手机号", fieldDisplayName) case "username": return fmt.Sprintf("%s 格式不正确,只能包含字母、数字、下划线,且不能以数字开头", fieldDisplayName) case "strong_password": return fmt.Sprintf("%s 强度不足,必须包含大小写字母和数字,且不少于8位", fieldDisplayName) case "eqfield": return fmt.Sprintf("%s 必须与 %s 一致", fieldDisplayName, v.getFieldDisplayName(param)) default: return fmt.Sprintf("%s 格式不正确", fieldDisplayName) } } // getFieldDisplayName 获取字段显示名称(中文) func (v *RequestValidator) getFieldDisplayName(field string) string { fieldNames := map[string]string{ "phone": "手机号", "password": "密码", "confirm_password": "确认密码", "old_password": "原密码", "new_password": "新密码", "confirm_new_password": "确认新密码", "code": "验证码", "username": "用户名", "email": "邮箱", "display_name": "显示名称", "scene": "使用场景", "Password": "密码", "NewPassword": "新密码", } if displayName, exists := fieldNames[field]; exists { return displayName } return field } // toSnakeCase 转换为snake_case func (v *RequestValidator) toSnakeCase(str string) string { var result strings.Builder for i, r := range str { if i > 0 && (r >= 'A' && r <= 'Z') { result.WriteRune('_') } result.WriteRune(r) } return strings.ToLower(result.String()) } // registerCustomValidators 注册自定义验证器 func registerCustomValidators(v *validator.Validate) { // 注册手机号验证器 v.RegisterValidation("phone", validatePhone) // 注册用户名验证器 v.RegisterValidation("username", validateUsername) // 注册密码强度验证器 v.RegisterValidation("strong_password", validateStrongPassword) } // validatePhone 验证手机号 func validatePhone(fl validator.FieldLevel) bool { phone := fl.Field().String() if phone == "" { return true // 空值由required标签处理 } // 简单的手机号验证(可根据需要完善) if len(phone) < 10 || len(phone) > 15 { return false } // 检查是否以+开头或全是数字 if strings.HasPrefix(phone, "+") { phone = phone[1:] } for _, r := range phone { if r < '0' || r > '9' { if r != '-' && r != ' ' && r != '(' && r != ')' { return false } } } return true } // validateUsername 验证用户名 func validateUsername(fl validator.FieldLevel) bool { username := fl.Field().String() if username == "" { return true // 空值由required标签处理 } // 用户名规则:3-30个字符,只能包含字母、数字、下划线,不能以数字开头 if len(username) < 3 || len(username) > 30 { return false } // 不能以数字开头 if username[0] >= '0' && username[0] <= '9' { return false } // 只能包含字母、数字、下划线 for _, r := range username { if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') { return false } } return true } // validateStrongPassword 验证密码强度 func validateStrongPassword(fl validator.FieldLevel) bool { password := fl.Field().String() if password == "" { return true // 空值由required标签处理 } // 密码强度规则:至少8个字符,包含大小写字母、数字 if len(password) < 8 { return false } hasUpper := false hasLower := false hasDigit := false for _, r := range password { switch { case r >= 'A' && r <= 'Z': hasUpper = true case r >= 'a' && r <= 'z': hasLower = true case r >= '0' && r <= '9': hasDigit = true } } return hasUpper && hasLower && hasDigit } // ValidateStruct 直接验证结构体(不通过HTTP上下文) func (v *RequestValidator) ValidateStruct(dto interface{}) error { return v.validator.Struct(dto) } // GetValidator 获取原始验证器(用于特殊情况) func (v *RequestValidator) GetValidator() *validator.Validate { return v.validator }