105 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			105 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package middleware | |||
|  | 
 | |||
|  | import ( | |||
|  | 	"tyapi-server/internal/config" | |||
|  | 
 | |||
|  | 	"github.com/gin-contrib/cors" | |||
|  | 	"github.com/gin-gonic/gin" | |||
|  | ) | |||
|  | 
 | |||
|  | // CORSMiddleware CORS中间件 | |||
|  | type CORSMiddleware struct { | |||
|  | 	config *config.Config | |||
|  | } | |||
|  | 
 | |||
|  | // NewCORSMiddleware 创建CORS中间件 | |||
|  | func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware { | |||
|  | 	return &CORSMiddleware{ | |||
|  | 		config: cfg, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // GetName 返回中间件名称 | |||
|  | func (m *CORSMiddleware) GetName() string { | |||
|  | 	return "cors" | |||
|  | } | |||
|  | 
 | |||
|  | // GetPriority 返回中间件优先级 | |||
|  | func (m *CORSMiddleware) GetPriority() int { | |||
|  | 	return 100 // 高优先级,最先执行 | |||
|  | } | |||
|  | 
 | |||
|  | // Handle 返回中间件处理函数 | |||
|  | func (m *CORSMiddleware) Handle() gin.HandlerFunc { | |||
|  | 	if !m.config.Development.EnableCors { | |||
|  | 		// 如果没有启用CORS,返回空处理函数 | |||
|  | 		return func(c *gin.Context) { | |||
|  | 			c.Next() | |||
|  | 		} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	config := cors.Config{ | |||
|  | 		AllowAllOrigins: false, | |||
|  | 		AllowOrigins:    m.getAllowedOrigins(), | |||
|  | 		AllowMethods:    m.getAllowedMethods(), | |||
|  | 		AllowHeaders:    m.getAllowedHeaders(), | |||
|  | 		ExposeHeaders: []string{ | |||
|  | 			"Content-Length", | |||
|  | 			"Content-Type", | |||
|  | 			"X-Request-ID", | |||
|  | 			"X-Response-Time", | |||
|  | 		}, | |||
|  | 		AllowCredentials: true, | |||
|  | 		MaxAge:           86400, // 24小时 | |||
|  | 	} | |||
|  | 
 | |||
|  | 	return cors.New(config) | |||
|  | } | |||
|  | 
 | |||
|  | // IsGlobal 是否为全局中间件 | |||
|  | func (m *CORSMiddleware) IsGlobal() bool { | |||
|  | 	return true | |||
|  | } | |||
|  | 
 | |||
|  | // getAllowedOrigins 获取允许的来源 | |||
|  | func (m *CORSMiddleware) getAllowedOrigins() []string { | |||
|  | 	if m.config.Development.CorsOrigins == "" { | |||
|  | 		return []string{"http://localhost:3000", "http://localhost:8080"} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// TODO: 解析配置中的origins字符串 | |||
|  | 	return []string{m.config.Development.CorsOrigins} | |||
|  | } | |||
|  | 
 | |||
|  | // getAllowedMethods 获取允许的方法 | |||
|  | func (m *CORSMiddleware) getAllowedMethods() []string { | |||
|  | 	if m.config.Development.CorsMethods == "" { | |||
|  | 		return []string{ | |||
|  | 			"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS", | |||
|  | 		} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// TODO: 解析配置中的methods字符串 | |||
|  | 	return []string{m.config.Development.CorsMethods} | |||
|  | } | |||
|  | 
 | |||
|  | // getAllowedHeaders 获取允许的头部 | |||
|  | func (m *CORSMiddleware) getAllowedHeaders() []string { | |||
|  | 	if m.config.Development.CorsHeaders == "" { | |||
|  | 		return []string{ | |||
|  | 			"Origin", | |||
|  | 			"Content-Length", | |||
|  | 			"Content-Type", | |||
|  | 			"Authorization", | |||
|  | 			"X-Requested-With", | |||
|  | 			"Accept", | |||
|  | 			"Accept-Encoding", | |||
|  | 			"Accept-Language", | |||
|  | 			"X-Request-ID", | |||
|  | 		} | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// TODO: 解析配置中的headers字符串 | |||
|  | 	return []string{m.config.Development.CorsHeaders} | |||
|  | } |