package middleware import ( "tyapi-server/internal/config" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" ) // CORSMiddleware CORS中间件 type CORSMiddleware struct { config *config.Config } // NewCORSMiddleware 创建CORS中间件 func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware { return &CORSMiddleware{ config: cfg, } } // GetName 返回中间件名称 func (m *CORSMiddleware) GetName() string { return "cors" } // GetPriority 返回中间件优先级 func (m *CORSMiddleware) GetPriority() int { return 100 // 高优先级,最先执行 } // Handle 返回中间件处理函数 func (m *CORSMiddleware) Handle() gin.HandlerFunc { if !m.config.Development.EnableCors { // 如果没有启用CORS,返回空处理函数 return func(c *gin.Context) { c.Next() } } config := cors.Config{ AllowAllOrigins: false, AllowOrigins: m.getAllowedOrigins(), AllowMethods: m.getAllowedMethods(), AllowHeaders: m.getAllowedHeaders(), ExposeHeaders: []string{ "Content-Length", "Content-Type", "X-Request-ID", "X-Response-Time", }, AllowCredentials: true, MaxAge: 86400, // 24小时 } return cors.New(config) } // IsGlobal 是否为全局中间件 func (m *CORSMiddleware) IsGlobal() bool { return true } // getAllowedOrigins 获取允许的来源 func (m *CORSMiddleware) getAllowedOrigins() []string { if m.config.Development.CorsOrigins == "" { return []string{"http://localhost:3000", "http://localhost:8080"} } // TODO: 解析配置中的origins字符串 return []string{m.config.Development.CorsOrigins} } // getAllowedMethods 获取允许的方法 func (m *CORSMiddleware) getAllowedMethods() []string { if m.config.Development.CorsMethods == "" { return []string{ "GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS", } } // TODO: 解析配置中的methods字符串 return []string{m.config.Development.CorsMethods} } // getAllowedHeaders 获取允许的头部 func (m *CORSMiddleware) getAllowedHeaders() []string { if m.config.Development.CorsHeaders == "" { return []string{ "Origin", "Content-Length", "Content-Type", "Authorization", "X-Requested-With", "Accept", "Accept-Encoding", "Accept-Language", "X-Request-ID", } } // TODO: 解析配置中的headers字符串 return []string{m.config.Development.CorsHeaders} }