// File: tyapi-server/internal/shared/middleware/cors.go

package middleware
import (
	"strings"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"

	"tyapi-server/internal/config"
)
// CORSMiddleware applies the cross-origin resource sharing (CORS) policy
// configured in the Development section of the application config.
type CORSMiddleware struct {
	// config is the application configuration. This middleware reads
	// Development.EnableCors and the Development.Cors* list settings from it.
	config *config.Config
}
// NewCORSMiddleware returns a CORSMiddleware backed by cfg.
//
// cfg is stored by reference, not copied, and is dereferenced when the
// handler is built, so callers should pass a non-nil config.
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
	mw := new(CORSMiddleware)
	mw.config = cfg
	return mw
}
// GetName returns the middleware's name, "cors".
func (m *CORSMiddleware) GetName() string {
	const name = "cors"
	return name
}
// GetPriority 返回中间件优先级
func (m *CORSMiddleware) GetPriority() int {
2025-08-28 17:09:21 +08:00
return 95 // 在PanicRecovery(100)之后SecurityHeaders(85)之前执行
}
// Handle 返回中间件处理函数
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
if !m.config.Development.EnableCors {
// 如果没有启用CORS返回空处理函数
return func(c *gin.Context) {
c.Next()
}
}
2025-08-28 17:09:21 +08:00
// 获取CORS配置
origins := m.getAllowedOrigins()
methods := m.getAllowedMethods()
headers := m.getAllowedHeaders()
config := cors.Config{
AllowAllOrigins: false,
2025-08-28 17:09:21 +08:00
AllowOrigins: origins,
AllowMethods: methods,
AllowHeaders: headers,
ExposeHeaders: []string{
"Content-Length",
"Content-Type",
"X-Request-ID",
"X-Response-Time",
2025-08-28 17:09:21 +08:00
"Access-Control-Allow-Origin",
"Access-Control-Allow-Methods",
"Access-Control-Allow-Headers",
},
AllowCredentials: true,
MaxAge: 86400, // 24小时
2025-08-28 17:09:21 +08:00
// 增加Chrome兼容性
AllowWildcard: false,
AllowBrowserExtensions: false,
}
2025-08-28 17:09:21 +08:00
// 创建CORS中间件
corsMiddleware := cors.New(config)
// 返回包装后的中间件
return func(c *gin.Context) {
// 调用实际的CORS中间件
corsMiddleware(c)
// 继续处理下一个中间件或处理器
c.Next()
}
}
// IsGlobal reports whether this middleware is global. For CORS the answer is
// always true.
func (m *CORSMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// getAllowedOrigins 获取允许的来源
func (m *CORSMiddleware) getAllowedOrigins() []string {
if m.config.Development.CorsOrigins == "" {
return []string{"http://localhost:3000", "http://localhost:8080"}
}
2025-08-28 00:50:30 +08:00
// 解析配置中的origins字符串按逗号分隔
origins := strings.Split(m.config.Development.CorsOrigins, ",")
// 去除空格
for i, origin := range origins {
origins[i] = strings.TrimSpace(origin)
}
return origins
}
// getAllowedMethods 获取允许的方法
func (m *CORSMiddleware) getAllowedMethods() []string {
if m.config.Development.CorsMethods == "" {
return []string{
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
}
}
2025-08-28 00:50:30 +08:00
// 解析配置中的methods字符串按逗号分隔
methods := strings.Split(m.config.Development.CorsMethods, ",")
// 去除空格
for i, method := range methods {
methods[i] = strings.TrimSpace(method)
}
return methods
}
// getAllowedHeaders 获取允许的头部
func (m *CORSMiddleware) getAllowedHeaders() []string {
if m.config.Development.CorsHeaders == "" {
return []string{
"Origin",
"Content-Type",
2025-08-28 00:50:30 +08:00
"Content-Length",
"Accept",
"Accept-Encoding",
"Accept-Language",
2025-08-28 00:50:30 +08:00
"Authorization",
"X-Requested-With",
"X-Request-ID",
2025-08-28 00:50:30 +08:00
"Access-Id",
}
}
2025-08-28 00:50:30 +08:00
// 解析配置中的headers字符串按逗号分隔
headers := strings.Split(m.config.Development.CorsHeaders, ",")
// 去除空格
for i, header := range headers {
headers[i] = strings.TrimSpace(header)
}
return headers
}