Initial commit: Basic project structure and dependencies
This commit is contained in:
260
internal/shared/http/response.go
Normal file
260
internal/shared/http/response.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ResponseBuilder 响应构建器实现
|
||||
type ResponseBuilder struct{}
|
||||
|
||||
// NewResponseBuilder 创建响应构建器
|
||||
func NewResponseBuilder() interfaces.ResponseBuilder {
|
||||
return &ResponseBuilder{}
|
||||
}
|
||||
|
||||
// Success 成功响应
|
||||
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Success"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Created 创建成功响应
|
||||
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Created successfully"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
|
||||
// 根据错误类型确定状态码
|
||||
statusCode := http.StatusInternalServerError
|
||||
message := "Internal server error"
|
||||
errorDetail := err.Error()
|
||||
|
||||
// 这里可以根据不同的错误类型设置不同的状态码
|
||||
// 例如:ValidationError -> 400, NotFoundError -> 404, etc.
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
Errors: errorDetail,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// BadRequest 400错误响应
|
||||
func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
response.Errors = errors[0]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Unauthorized 401错误响应
|
||||
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
|
||||
msg := "Unauthorized"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, response)
|
||||
}
|
||||
|
||||
// Forbidden 403错误响应
|
||||
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
|
||||
msg := "Forbidden"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusForbidden, response)
|
||||
}
|
||||
|
||||
// NotFound 404错误响应
|
||||
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
|
||||
msg := "Resource not found"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Conflict 409错误响应
|
||||
func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusConflict, response)
|
||||
}
|
||||
|
||||
// InternalError 500错误响应
|
||||
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
|
||||
msg := "Internal server error"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Paginated 分页响应
|
||||
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: "Success",
|
||||
Data: data,
|
||||
Pagination: &pagination,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// getRequestID 从上下文获取请求ID
|
||||
func (r *ResponseBuilder) getRequestID(c *gin.Context) string {
|
||||
if requestID, exists := c.Get("request_id"); exists {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildPagination 构建分页元数据
|
||||
func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
|
||||
totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
return interfaces.PaginationMeta{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
HasNext: page < totalPages,
|
||||
HasPrev: page > 1,
|
||||
}
|
||||
}
|
||||
|
||||
// CustomResponse 自定义响应
|
||||
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: statusCode >= 200 && statusCode < 300,
|
||||
Message: http.StatusText(statusCode),
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// ValidationError 验证错误响应
|
||||
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: "Validation failed",
|
||||
Errors: errors,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnprocessableEntity, response)
|
||||
}
|
||||
|
||||
// TooManyRequests 限流错误响应
|
||||
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
|
||||
msg := "Too many requests"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Meta: map[string]interface{}{
|
||||
"retry_after": "60s",
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, response)
|
||||
}
|
||||
258
internal/shared/http/router.go
Normal file
258
internal/shared/http/router.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GinRouter Gin路由器实现
|
||||
type GinRouter struct {
|
||||
engine *gin.Engine
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
middlewares []interfaces.Middleware
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// NewGinRouter 创建Gin路由器
|
||||
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *GinRouter {
|
||||
// 设置Gin模式
|
||||
if cfg.App.IsProduction() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
} else {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
}
|
||||
|
||||
// 创建Gin引擎
|
||||
engine := gin.New()
|
||||
|
||||
return &GinRouter{
|
||||
engine: engine,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
middlewares: make([]interfaces.Middleware, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册处理器
|
||||
func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
|
||||
// 应用处理器中间件
|
||||
middlewares := handler.GetMiddlewares()
|
||||
|
||||
// 注册路由
|
||||
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
|
||||
|
||||
r.logger.Info("Registered HTTP handler",
|
||||
zap.String("method", handler.GetMethod()),
|
||||
zap.String("path", handler.GetPath()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterMiddleware 注册中间件
|
||||
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
|
||||
r.middlewares = append(r.middlewares, middleware)
|
||||
|
||||
r.logger.Info("Registered middleware",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterGroup 注册路由组
|
||||
func (r *GinRouter) RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes {
|
||||
return r.engine.Group(prefix, middlewares...)
|
||||
}
|
||||
|
||||
// GetRoutes 获取路由信息
|
||||
func (r *GinRouter) GetRoutes() gin.RoutesInfo {
|
||||
return r.engine.Routes()
|
||||
}
|
||||
|
||||
// Start 启动路由器
|
||||
func (r *GinRouter) Start(addr string) error {
|
||||
// 应用中间件(按优先级排序)
|
||||
r.applyMiddlewares()
|
||||
|
||||
// 创建HTTP服务器
|
||||
r.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: r.engine,
|
||||
ReadTimeout: r.config.Server.ReadTimeout,
|
||||
WriteTimeout: r.config.Server.WriteTimeout,
|
||||
IdleTimeout: r.config.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
|
||||
|
||||
// 启动服务器
|
||||
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止路由器
|
||||
func (r *GinRouter) Stop(ctx context.Context) error {
|
||||
if r.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Info("Stopping HTTP server...")
|
||||
|
||||
// 优雅关闭服务器
|
||||
if err := r.server.Shutdown(ctx); err != nil {
|
||||
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("HTTP server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEngine 获取Gin引擎
|
||||
func (r *GinRouter) GetEngine() *gin.Engine {
|
||||
return r.engine
|
||||
}
|
||||
|
||||
// applyMiddlewares 应用中间件
|
||||
func (r *GinRouter) applyMiddlewares() {
|
||||
// 按优先级排序中间件
|
||||
sort.Slice(r.middlewares, func(i, j int) bool {
|
||||
return r.middlewares[i].GetPriority() > r.middlewares[j].GetPriority()
|
||||
})
|
||||
|
||||
// 应用全局中间件
|
||||
for _, middleware := range r.middlewares {
|
||||
if middleware.IsGlobal() {
|
||||
r.engine.Use(middleware.Handle())
|
||||
r.logger.Debug("Applied global middleware",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupDefaultRoutes 设置默认路由
|
||||
func (r *GinRouter) SetupDefaultRoutes() {
|
||||
// 健康检查
|
||||
r.engine.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"service": r.config.App.Name,
|
||||
"version": r.config.App.Version,
|
||||
})
|
||||
})
|
||||
|
||||
// API信息
|
||||
r.engine.GET("/info", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": r.config.App.Name,
|
||||
"version": r.config.App.Version,
|
||||
"environment": r.config.App.Env,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// 404处理
|
||||
r.engine.NoRoute(func(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Route not found",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// 405处理
|
||||
r.engine.NoMethod(func(c *gin.Context) {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"success": false,
|
||||
"message": "Method not allowed",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// PrintRoutes 打印路由信息
|
||||
func (r *GinRouter) PrintRoutes() {
|
||||
routes := r.GetRoutes()
|
||||
|
||||
r.logger.Info("Registered routes:")
|
||||
for _, route := range routes {
|
||||
r.logger.Info("Route",
|
||||
zap.String("method", route.Method),
|
||||
zap.String("path", route.Path),
|
||||
zap.String("handler", route.Handler))
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取路由器统计信息
|
||||
func (r *GinRouter) GetStats() map[string]interface{} {
|
||||
routes := r.GetRoutes()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_routes": len(routes),
|
||||
"total_middlewares": len(r.middlewares),
|
||||
"server_config": map[string]interface{}{
|
||||
"read_timeout": r.config.Server.ReadTimeout,
|
||||
"write_timeout": r.config.Server.WriteTimeout,
|
||||
"idle_timeout": r.config.Server.IdleTimeout,
|
||||
},
|
||||
}
|
||||
|
||||
// 按方法统计路由数量
|
||||
methodStats := make(map[string]int)
|
||||
for _, route := range routes {
|
||||
methodStats[route.Method]++
|
||||
}
|
||||
stats["routes_by_method"] = methodStats
|
||||
|
||||
// 中间件统计
|
||||
middlewareStats := make([]map[string]interface{}, 0, len(r.middlewares))
|
||||
for _, middleware := range r.middlewares {
|
||||
middlewareStats = append(middlewareStats, map[string]interface{}{
|
||||
"name": middleware.GetName(),
|
||||
"priority": middleware.GetPriority(),
|
||||
"global": middleware.IsGlobal(),
|
||||
})
|
||||
}
|
||||
stats["middlewares"] = middlewareStats
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// EnableMetrics 启用指标收集
|
||||
func (r *GinRouter) EnableMetrics(collector interfaces.MetricsCollector) {
|
||||
r.engine.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
duration := time.Since(start).Seconds()
|
||||
collector.RecordHTTPRequest(c.Request.Method, c.FullPath(), c.Writer.Status(), duration)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableProfiling 启用性能分析
|
||||
func (r *GinRouter) EnableProfiling() {
|
||||
if r.config.Development.EnableProfiler {
|
||||
// 这里可以集成pprof
|
||||
r.logger.Info("Profiling enabled")
|
||||
}
|
||||
}
|
||||
273
internal/shared/http/validator.go
Normal file
273
internal/shared/http/validator.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// RequestValidator 请求验证器实现
|
||||
type RequestValidator struct {
|
||||
validator *validator.Validate
|
||||
response interfaces.ResponseBuilder
|
||||
}
|
||||
|
||||
// NewRequestValidator 创建请求验证器
|
||||
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
|
||||
v := validator.New()
|
||||
|
||||
// 注册自定义验证器
|
||||
registerCustomValidators(v)
|
||||
|
||||
return &RequestValidator{
|
||||
validator: v,
|
||||
response: response,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate 验证请求体
|
||||
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateQuery 验证查询参数
|
||||
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindQuery(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid query parameters", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateParam 验证路径参数
|
||||
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindUri(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid path parameters", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindAndValidate 绑定并验证请求
|
||||
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
|
||||
// 绑定请求体
|
||||
if err := c.ShouldBindJSON(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid request body", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证数据
|
||||
return v.Validate(c, dto)
|
||||
}
|
||||
|
||||
// formatValidationErrors 格式化验证错误
|
||||
func (v *RequestValidator) formatValidationErrors(err error) map[string][]string {
|
||||
errors := make(map[string][]string)
|
||||
|
||||
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
||||
for _, fieldError := range validationErrors {
|
||||
fieldName := v.getFieldName(fieldError)
|
||||
errorMessage := v.getErrorMessage(fieldError)
|
||||
|
||||
if _, exists := errors[fieldName]; !exists {
|
||||
errors[fieldName] = []string{}
|
||||
}
|
||||
errors[fieldName] = append(errors[fieldName], errorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// getFieldName 获取字段名(JSON标签优先)
|
||||
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
|
||||
// 可以通过反射获取JSON标签,这里简化处理
|
||||
fieldName := fieldError.Field()
|
||||
|
||||
// 转换为snake_case(可选)
|
||||
return v.toSnakeCase(fieldName)
|
||||
}
|
||||
|
||||
// getErrorMessage 获取错误消息
|
||||
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
|
||||
field := fieldError.Field()
|
||||
tag := fieldError.Tag()
|
||||
param := fieldError.Param()
|
||||
|
||||
switch tag {
|
||||
case "required":
|
||||
return fmt.Sprintf("%s is required", field)
|
||||
case "email":
|
||||
return fmt.Sprintf("%s must be a valid email address", field)
|
||||
case "min":
|
||||
return fmt.Sprintf("%s must be at least %s characters", field, param)
|
||||
case "max":
|
||||
return fmt.Sprintf("%s must be at most %s characters", field, param)
|
||||
case "len":
|
||||
return fmt.Sprintf("%s must be exactly %s characters", field, param)
|
||||
case "gt":
|
||||
return fmt.Sprintf("%s must be greater than %s", field, param)
|
||||
case "gte":
|
||||
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
|
||||
case "lt":
|
||||
return fmt.Sprintf("%s must be less than %s", field, param)
|
||||
case "lte":
|
||||
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
|
||||
case "oneof":
|
||||
return fmt.Sprintf("%s must be one of [%s]", field, param)
|
||||
case "url":
|
||||
return fmt.Sprintf("%s must be a valid URL", field)
|
||||
case "alpha":
|
||||
return fmt.Sprintf("%s must contain only alphabetic characters", field)
|
||||
case "alphanum":
|
||||
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
|
||||
case "numeric":
|
||||
return fmt.Sprintf("%s must be numeric", field)
|
||||
case "phone":
|
||||
return fmt.Sprintf("%s must be a valid phone number", field)
|
||||
case "username":
|
||||
return fmt.Sprintf("%s must be a valid username", field)
|
||||
default:
|
||||
return fmt.Sprintf("%s is invalid", field)
|
||||
}
|
||||
}
|
||||
|
||||
// toSnakeCase 转换为snake_case
|
||||
func (v *RequestValidator) toSnakeCase(str string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range str {
|
||||
if i > 0 && (r >= 'A' && r <= 'Z') {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// registerCustomValidators 注册自定义验证器
|
||||
func registerCustomValidators(v *validator.Validate) {
|
||||
// 注册手机号验证器
|
||||
v.RegisterValidation("phone", validatePhone)
|
||||
|
||||
// 注册用户名验证器
|
||||
v.RegisterValidation("username", validateUsername)
|
||||
|
||||
// 注册密码强度验证器
|
||||
v.RegisterValidation("strong_password", validateStrongPassword)
|
||||
}
|
||||
|
||||
// validatePhone 验证手机号
|
||||
func validatePhone(fl validator.FieldLevel) bool {
|
||||
phone := fl.Field().String()
|
||||
if phone == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 简单的手机号验证(可根据需要完善)
|
||||
if len(phone) < 10 || len(phone) > 15 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否以+开头或全是数字
|
||||
if strings.HasPrefix(phone, "+") {
|
||||
phone = phone[1:]
|
||||
}
|
||||
|
||||
for _, r := range phone {
|
||||
if r < '0' || r > '9' {
|
||||
if r != '-' && r != ' ' && r != '(' && r != ')' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateUsername 验证用户名
|
||||
func validateUsername(fl validator.FieldLevel) bool {
|
||||
username := fl.Field().String()
|
||||
if username == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 用户名规则:3-30个字符,只能包含字母、数字、下划线,不能以数字开头
|
||||
if len(username) < 3 || len(username) > 30 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不能以数字开头
|
||||
if username[0] >= '0' && username[0] <= '9' {
|
||||
return false
|
||||
}
|
||||
|
||||
// 只能包含字母、数字、下划线
|
||||
for _, r := range username {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateStrongPassword 验证密码强度
|
||||
func validateStrongPassword(fl validator.FieldLevel) bool {
|
||||
password := fl.Field().String()
|
||||
if password == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 密码强度规则:至少8个字符,包含大小写字母、数字
|
||||
if len(password) < 8 {
|
||||
return false
|
||||
}
|
||||
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
|
||||
for _, r := range password {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
hasUpper = true
|
||||
case r >= 'a' && r <= 'z':
|
||||
hasLower = true
|
||||
case r >= '0' && r <= '9':
|
||||
hasDigit = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasUpper && hasLower && hasDigit
|
||||
}
|
||||
|
||||
// ValidateStruct 直接验证结构体(不通过HTTP上下文)
|
||||
func (v *RequestValidator) ValidateStruct(dto interface{}) error {
|
||||
return v.validator.Struct(dto)
|
||||
}
|
||||
|
||||
// GetValidator 获取原始验证器(用于特殊情况)
|
||||
func (v *RequestValidator) GetValidator() *validator.Validate {
|
||||
return v.validator
|
||||
}
|
||||
Reference in New Issue
Block a user