package middleware

import (
	"strings"
	"time"

	"tyapi-server/internal/config"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
)

// CORSMiddleware is the CORS middleware. Its behavior is driven by the
// Development section of the application config (EnableCors, CorsOrigins,
// CorsMethods, CorsHeaders).
type CORSMiddleware struct {
	config *config.Config
}

// NewCORSMiddleware creates a CORS middleware backed by cfg.
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
	return &CORSMiddleware{
		config: cfg,
	}
}

// GetName returns the middleware name.
func (m *CORSMiddleware) GetName() string {
	return "cors"
}

// GetPriority returns the middleware priority.
func (m *CORSMiddleware) GetPriority() int {
	// Intended as a high priority so CORS runs first; the actual ordering rule
	// lives in the middleware registry (not shown here) — confirm there.
	return 100
}

// Handle returns the gin handler implementing CORS.
//
// When CORS is disabled in config, the returned handler simply calls
// c.Next(). Otherwise it builds a gin-contrib/cors handler from the
// configured (or default) origins, methods and headers. Note that cors.New
// panics if the resulting configuration fails validation.
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
	if !m.config.Development.EnableCors {
		// CORS disabled: pass requests through untouched.
		return func(c *gin.Context) {
			c.Next()
		}
	}

	// Named corsCfg (not "config") to avoid shadowing the imported config package.
	corsCfg := cors.Config{
		AllowAllOrigins: false,
		AllowOrigins:    m.getAllowedOrigins(),
		AllowMethods:    m.getAllowedMethods(),
		AllowHeaders:    m.getAllowedHeaders(),
		ExposeHeaders: []string{
			"Content-Length",
			"Content-Type",
			"X-Request-ID",
			"X-Response-Time",
		},
		// NOTE(review): if CorsOrigins is configured as "*", browsers will not
		// accept a wildcard Allow-Origin together with credentials — confirm
		// the deployed config lists explicit origins.
		AllowCredentials: true,
		// MaxAge is a time.Duration. A bare 86400 would mean 86.4µs, which is
		// emitted in whole seconds as "Access-Control-Max-Age: 0" and
		// disables preflight caching entirely.
		MaxAge: 24 * time.Hour,
	}

	return cors.New(corsCfg)
}

// IsGlobal reports whether this middleware is applied globally.
func (m *CORSMiddleware) IsGlobal() bool {
	return true
}

// getAllowedOrigins returns the allowed origins from the comma-separated
// CorsOrigins setting, or local development defaults when it is unset/blank.
func (m *CORSMiddleware) getAllowedOrigins() []string {
	return splitCORSList(m.config.Development.CorsOrigins, []string{
		"http://localhost:3000",
		"http://localhost:8080",
	})
}

// getAllowedMethods returns the allowed HTTP methods from the comma-separated
// CorsMethods setting, or the common REST methods when it is unset/blank.
func (m *CORSMiddleware) getAllowedMethods() []string {
	return splitCORSList(m.config.Development.CorsMethods, []string{
		"GET",
		"POST",
		"PUT",
		"PATCH",
		"DELETE",
		"HEAD",
		"OPTIONS",
	})
}

// getAllowedHeaders returns the allowed request headers from the
// comma-separated CorsHeaders setting, or a default set when it is
// unset/blank.
func (m *CORSMiddleware) getAllowedHeaders() []string {
	return splitCORSList(m.config.Development.CorsHeaders, []string{
		"Origin",
		"Content-Type",
		"Content-Length",
		"Accept",
		"Accept-Encoding",
		"Accept-Language",
		"Authorization",
		"X-Requested-With",
		"X-Request-ID",
		"Access-Id",
	})
}

// splitCORSList splits a comma-separated configuration value into its
// whitespace-trimmed, non-empty entries. Empty entries (e.g. from a trailing
// comma) are dropped because an empty origin makes cors.New panic and empty
// methods/headers produce malformed preflight headers. If raw yields no
// usable entries at all, fallback is returned.
func splitCORSList(raw string, fallback []string) []string {
	var items []string
	for _, part := range strings.Split(raw, ",") {
		if item := strings.TrimSpace(part); item != "" {
			items = append(items, item)
		}
	}
	if len(items) == 0 {
		return fallback
	}
	return items
}