105 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			105 lines
		
	
	
		
			2.3 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package middleware
 | ||
| 
 | ||
| import (
 | ||
| 	"tyapi-server/internal/config"
 | ||
| 
 | ||
| 	"github.com/gin-contrib/cors"
 | ||
| 	"github.com/gin-gonic/gin"
 | ||
| )
 | ||
| 
 | ||
| // CORSMiddleware CORS中间件
 | ||
| type CORSMiddleware struct {
 | ||
| 	config *config.Config
 | ||
| }
 | ||
| 
 | ||
| // NewCORSMiddleware 创建CORS中间件
 | ||
| func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
 | ||
| 	return &CORSMiddleware{
 | ||
| 		config: cfg,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // GetName 返回中间件名称
 | ||
| func (m *CORSMiddleware) GetName() string {
 | ||
| 	return "cors"
 | ||
| }
 | ||
| 
 | ||
| // GetPriority 返回中间件优先级
 | ||
| func (m *CORSMiddleware) GetPriority() int {
 | ||
| 	return 100 // 高优先级,最先执行
 | ||
| }
 | ||
| 
 | ||
| // Handle 返回中间件处理函数
 | ||
| func (m *CORSMiddleware) Handle() gin.HandlerFunc {
 | ||
| 	if !m.config.Development.EnableCors {
 | ||
| 		// 如果没有启用CORS,返回空处理函数
 | ||
| 		return func(c *gin.Context) {
 | ||
| 			c.Next()
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	config := cors.Config{
 | ||
| 		AllowAllOrigins: false,
 | ||
| 		AllowOrigins:    m.getAllowedOrigins(),
 | ||
| 		AllowMethods:    m.getAllowedMethods(),
 | ||
| 		AllowHeaders:    m.getAllowedHeaders(),
 | ||
| 		ExposeHeaders: []string{
 | ||
| 			"Content-Length",
 | ||
| 			"Content-Type",
 | ||
| 			"X-Request-ID",
 | ||
| 			"X-Response-Time",
 | ||
| 		},
 | ||
| 		AllowCredentials: true,
 | ||
| 		MaxAge:           86400, // 24小时
 | ||
| 	}
 | ||
| 
 | ||
| 	return cors.New(config)
 | ||
| }
 | ||
| 
 | ||
| // IsGlobal 是否为全局中间件
 | ||
| func (m *CORSMiddleware) IsGlobal() bool {
 | ||
| 	return true
 | ||
| }
 | ||
| 
 | ||
| // getAllowedOrigins 获取允许的来源
 | ||
| func (m *CORSMiddleware) getAllowedOrigins() []string {
 | ||
| 	if m.config.Development.CorsOrigins == "" {
 | ||
| 		return []string{"http://localhost:3000", "http://localhost:8080"}
 | ||
| 	}
 | ||
| 
 | ||
| 	// TODO: 解析配置中的origins字符串
 | ||
| 	return []string{m.config.Development.CorsOrigins}
 | ||
| }
 | ||
| 
 | ||
| // getAllowedMethods 获取允许的方法
 | ||
| func (m *CORSMiddleware) getAllowedMethods() []string {
 | ||
| 	if m.config.Development.CorsMethods == "" {
 | ||
| 		return []string{
 | ||
| 			"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	// TODO: 解析配置中的methods字符串
 | ||
| 	return []string{m.config.Development.CorsMethods}
 | ||
| }
 | ||
| 
 | ||
| // getAllowedHeaders 获取允许的头部
 | ||
| func (m *CORSMiddleware) getAllowedHeaders() []string {
 | ||
| 	if m.config.Development.CorsHeaders == "" {
 | ||
| 		return []string{
 | ||
| 			"Origin",
 | ||
| 			"Content-Length",
 | ||
| 			"Content-Type",
 | ||
| 			"Authorization",
 | ||
| 			"X-Requested-With",
 | ||
| 			"Accept",
 | ||
| 			"Accept-Encoding",
 | ||
| 			"Accept-Language",
 | ||
| 			"X-Request-ID",
 | ||
| 		}
 | ||
| 	}
 | ||
| 
 | ||
| 	// TODO: 解析配置中的headers字符串
 | ||
| 	return []string{m.config.Development.CorsHeaders}
 | ||
| }
 |