Files
tyapi-server/internal/shared/middleware/cors.go
2025-08-28 17:09:21 +08:00

143 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"strings"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"

	"tyapi-server/internal/config"
)
// CORSMiddleware is the CORS (Cross-Origin Resource Sharing) middleware.
// Its behavior is driven entirely by the Development section of the
// application configuration it holds.
type CORSMiddleware struct {
	// config is the application configuration the CORS settings are read from.
	config *config.Config
}
// NewCORSMiddleware constructs a CORSMiddleware that reads its settings from cfg.
// cfg is stored as-is (not copied).
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
	mw := new(CORSMiddleware)
	mw.config = cfg
	return mw
}
// GetName reports the middleware's name, which is always "cors".
func (m *CORSMiddleware) GetName() string {
	const name = "cors"
	return name
}
// GetPriority reports the middleware's ordering priority. The value 95 is
// chosen so that CORS runs after PanicRecovery (100) and before
// SecurityHeaders (85).
func (m *CORSMiddleware) GetPriority() int {
	const priority = 95
	return priority
}
// Handle returns the gin handler that enforces the CORS policy.
//
// When Development.EnableCors is false the returned handler is a no-op that
// simply continues the chain. Otherwise the allowed origins, methods and
// headers are resolved from configuration once per call to Handle, and a
// gin-contrib/cors handler is built from them. cors.New panics when the
// resulting configuration is invalid (e.g. an origin without an http:// or
// https:// scheme), so a misconfiguration surfaces when Handle is called
// rather than on individual requests.
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
	if !m.config.Development.EnableCors {
		// CORS is disabled: pass every request straight through.
		return func(c *gin.Context) {
			c.Next()
		}
	}

	// Named corsConfig so it does not shadow the imported config package.
	corsConfig := cors.Config{
		AllowAllOrigins: false,
		AllowOrigins:    m.getAllowedOrigins(),
		AllowMethods:    m.getAllowedMethods(),
		AllowHeaders:    m.getAllowedHeaders(),
		ExposeHeaders: []string{
			"Content-Length",
			"Content-Type",
			"X-Request-ID",
			"X-Response-Time",
			"Access-Control-Allow-Origin",
			"Access-Control-Allow-Methods",
			"Access-Control-Allow-Headers",
		},
		AllowCredentials: true,
		// MaxAge is a time.Duration and is sent as whole seconds in
		// Access-Control-Max-Age. A bare 86400 would mean 86.4µs and be
		// sent as "0", disabling preflight caching; 24 hours is intended.
		MaxAge: 24 * time.Hour,
		// Wildcard origin patterns and browser-extension origins are not accepted.
		AllowWildcard:          false,
		AllowBrowserExtensions: false,
	}

	// No wrapper is needed: gin runs the next handler automatically once this
	// one returns, and the cors handler aborts the chain itself for preflight
	// requests and disallowed origins.
	return cors.New(corsConfig)
}
// IsGlobal reports whether this is a global middleware. It always returns true.
func (m *CORSMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// getAllowedOrigins returns the origins allowed to make cross-origin requests.
// It parses the comma-separated Development.CorsOrigins setting and falls back
// to the local development origins when the setting yields no usable entries.
func (m *CORSMiddleware) getAllowedOrigins() []string {
	return splitCORSList(m.config.Development.CorsOrigins, []string{
		"http://localhost:3000",
		"http://localhost:8080",
	})
}

// getAllowedMethods returns the HTTP methods allowed for cross-origin
// requests. It parses the comma-separated Development.CorsMethods setting and
// falls back to the common REST methods when the setting yields no usable
// entries.
func (m *CORSMiddleware) getAllowedMethods() []string {
	return splitCORSList(m.config.Development.CorsMethods, []string{
		"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
	})
}

// getAllowedHeaders returns the request headers allowed on cross-origin
// requests. It parses the comma-separated Development.CorsHeaders setting and
// falls back to a default header set when the setting yields no usable
// entries.
func (m *CORSMiddleware) getAllowedHeaders() []string {
	return splitCORSList(m.config.Development.CorsHeaders, []string{
		"Origin",
		"Content-Type",
		"Content-Length",
		"Accept",
		"Accept-Encoding",
		"Accept-Language",
		"Authorization",
		"X-Requested-With",
		"X-Request-ID",
		"Access-Id",
	})
}

// splitCORSList splits raw on commas, trims surrounding whitespace from each
// entry and drops entries that are empty after trimming. If no entries remain
// (raw is empty, whitespace-only, or only commas), defaults is returned.
//
// Dropping empty entries matters because a trailing comma or "a,,b" in the
// configuration would otherwise produce "" items. An empty origin fails
// gin-contrib/cors validation and makes cors.New panic. Empty methods and
// headers produce malformed Access-Control-Allow-* lists.
func splitCORSList(raw string, defaults []string) []string {
	parts := strings.Split(raw, ",")
	items := make([]string, 0, len(parts))
	for _, part := range parts {
		if part = strings.TrimSpace(part); part != "" {
			items = append(items, part)
		}
	}
	if len(items) == 0 {
		return defaults
	}
	return items
}