2025-06-30 19:21:56 +08:00
|
|
|
|
package middleware
|
|
|
|
|
|
|
|
|
|
|
|
import (
	"strings"
	"time"

	"tyapi-server/internal/config"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
)
|
|
|
|
|
|
|
|
|
|
|
|
// CORSMiddleware CORS中间件
|
|
|
|
|
|
type CORSMiddleware struct {
|
|
|
|
|
|
config *config.Config
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// NewCORSMiddleware 创建CORS中间件
|
|
|
|
|
|
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
|
|
|
|
|
|
return &CORSMiddleware{
|
|
|
|
|
|
config: cfg,
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GetName 返回中间件名称
|
|
|
|
|
|
func (m *CORSMiddleware) GetName() string {
|
|
|
|
|
|
return "cors"
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// GetPriority 返回中间件优先级
|
|
|
|
|
|
func (m *CORSMiddleware) GetPriority() int {
|
2025-08-28 17:09:21 +08:00
|
|
|
|
return 95 // 在PanicRecovery(100)之后,SecurityHeaders(85)之前执行
|
2025-06-30 19:21:56 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Handle 返回中间件处理函数
|
|
|
|
|
|
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
|
|
|
|
|
|
if !m.config.Development.EnableCors {
|
|
|
|
|
|
// 如果没有启用CORS,返回空处理函数
|
|
|
|
|
|
return func(c *gin.Context) {
|
|
|
|
|
|
c.Next()
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-28 17:09:21 +08:00
|
|
|
|
// 获取CORS配置
|
|
|
|
|
|
origins := m.getAllowedOrigins()
|
|
|
|
|
|
methods := m.getAllowedMethods()
|
|
|
|
|
|
headers := m.getAllowedHeaders()
|
|
|
|
|
|
|
2025-06-30 19:21:56 +08:00
|
|
|
|
config := cors.Config{
|
|
|
|
|
|
AllowAllOrigins: false,
|
2025-08-28 17:09:21 +08:00
|
|
|
|
AllowOrigins: origins,
|
|
|
|
|
|
AllowMethods: methods,
|
|
|
|
|
|
AllowHeaders: headers,
|
2025-06-30 19:21:56 +08:00
|
|
|
|
ExposeHeaders: []string{
|
|
|
|
|
|
"Content-Length",
|
|
|
|
|
|
"Content-Type",
|
|
|
|
|
|
"X-Request-ID",
|
|
|
|
|
|
"X-Response-Time",
|
2025-08-28 17:09:21 +08:00
|
|
|
|
"Access-Control-Allow-Origin",
|
|
|
|
|
|
"Access-Control-Allow-Methods",
|
|
|
|
|
|
"Access-Control-Allow-Headers",
|
2025-06-30 19:21:56 +08:00
|
|
|
|
},
|
|
|
|
|
|
AllowCredentials: true,
|
|
|
|
|
|
MaxAge: 86400, // 24小时
|
2025-08-28 17:09:21 +08:00
|
|
|
|
// 增加Chrome兼容性
|
|
|
|
|
|
AllowWildcard: false,
|
|
|
|
|
|
AllowBrowserExtensions: false,
|
2025-06-30 19:21:56 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-28 17:09:21 +08:00
|
|
|
|
// 创建CORS中间件
|
|
|
|
|
|
corsMiddleware := cors.New(config)
|
|
|
|
|
|
|
|
|
|
|
|
// 返回包装后的中间件
|
|
|
|
|
|
return func(c *gin.Context) {
|
|
|
|
|
|
// 调用实际的CORS中间件
|
|
|
|
|
|
corsMiddleware(c)
|
|
|
|
|
|
|
|
|
|
|
|
// 继续处理下一个中间件或处理器
|
|
|
|
|
|
c.Next()
|
|
|
|
|
|
}
|
2025-06-30 19:21:56 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// IsGlobal 是否为全局中间件
|
|
|
|
|
|
func (m *CORSMiddleware) IsGlobal() bool {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getAllowedOrigins 获取允许的来源
|
|
|
|
|
|
func (m *CORSMiddleware) getAllowedOrigins() []string {
|
|
|
|
|
|
if m.config.Development.CorsOrigins == "" {
|
|
|
|
|
|
return []string{"http://localhost:3000", "http://localhost:8080"}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-28 00:50:30 +08:00
|
|
|
|
// 解析配置中的origins字符串,按逗号分隔
|
|
|
|
|
|
origins := strings.Split(m.config.Development.CorsOrigins, ",")
|
|
|
|
|
|
// 去除空格
|
|
|
|
|
|
for i, origin := range origins {
|
|
|
|
|
|
origins[i] = strings.TrimSpace(origin)
|
|
|
|
|
|
}
|
|
|
|
|
|
return origins
|
2025-06-30 19:21:56 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getAllowedMethods 获取允许的方法
|
|
|
|
|
|
func (m *CORSMiddleware) getAllowedMethods() []string {
|
|
|
|
|
|
if m.config.Development.CorsMethods == "" {
|
|
|
|
|
|
return []string{
|
|
|
|
|
|
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-28 00:50:30 +08:00
|
|
|
|
// 解析配置中的methods字符串,按逗号分隔
|
|
|
|
|
|
methods := strings.Split(m.config.Development.CorsMethods, ",")
|
|
|
|
|
|
// 去除空格
|
|
|
|
|
|
for i, method := range methods {
|
|
|
|
|
|
methods[i] = strings.TrimSpace(method)
|
|
|
|
|
|
}
|
|
|
|
|
|
return methods
|
2025-06-30 19:21:56 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getAllowedHeaders 获取允许的头部
|
|
|
|
|
|
func (m *CORSMiddleware) getAllowedHeaders() []string {
|
|
|
|
|
|
if m.config.Development.CorsHeaders == "" {
|
|
|
|
|
|
return []string{
|
|
|
|
|
|
"Origin",
|
|
|
|
|
|
"Content-Type",
|
2025-08-28 00:50:30 +08:00
|
|
|
|
"Content-Length",
|
2025-06-30 19:21:56 +08:00
|
|
|
|
"Accept",
|
|
|
|
|
|
"Accept-Encoding",
|
|
|
|
|
|
"Accept-Language",
|
2025-08-28 00:50:30 +08:00
|
|
|
|
"Authorization",
|
|
|
|
|
|
"X-Requested-With",
|
2025-06-30 19:21:56 +08:00
|
|
|
|
"X-Request-ID",
|
2025-08-28 00:50:30 +08:00
|
|
|
|
"Access-Id",
|
2025-06-30 19:21:56 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-08-28 00:50:30 +08:00
|
|
|
|
// 解析配置中的headers字符串,按逗号分隔
|
|
|
|
|
|
headers := strings.Split(m.config.Development.CorsHeaders, ",")
|
|
|
|
|
|
// 去除空格
|
|
|
|
|
|
for i, header := range headers {
|
|
|
|
|
|
headers[i] = strings.TrimSpace(header)
|
|
|
|
|
|
}
|
|
|
|
|
|
return headers
|
2025-06-30 19:21:56 +08:00
|
|
|
|
}
|