Files
tyapi-server/internal/shared/middleware/cors.go

105 lines
2.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package middleware
import (
	"strings"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"

	"tyapi-server/internal/config"
)
// CORSMiddleware is the Cross-Origin Resource Sharing (CORS) middleware.
// Its behaviour is driven entirely by the Development section of the
// application configuration it holds.
type CORSMiddleware struct {
	// config is the application configuration; the Development.EnableCors
	// flag and the Development.Cors* strings are read from it.
	config *config.Config
}
// NewCORSMiddleware creates a CORSMiddleware bound to cfg.
// The pointer is stored as-is; cfg is neither copied nor validated here.
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
	mw := new(CORSMiddleware)
	mw.config = cfg
	return mw
}
// GetName returns the identifier of this middleware, "cors".
func (*CORSMiddleware) GetName() string {
	const middlewareName = "cors"
	return middlewareName
}
// GetPriority returns the ordering priority of this middleware (100).
// The value is intended as a high priority so that CORS runs first.
// NOTE(review): assumes the middleware registry runs larger priority values
// earlier — confirm against the registry implementation.
func (*CORSMiddleware) GetPriority() int {
	const corsPriority = 100
	return corsPriority
}
// Handle returns the gin handler function for this middleware.
//
// If Development.EnableCors is false, the returned handler simply calls
// c.Next(), so CORS is a no-op. Otherwise the handler is built with
// gin-contrib/cors. It uses the configured (or default) origins, methods
// and headers, exposes a fixed set of response headers, allows credentials,
// and lets browsers cache preflight responses for 24 hours.
//
// NOTE(review): gin-contrib/cors validates the configuration and panics on
// invalid settings (e.g. an origin without a scheme) — confirm for the
// pinned library version.
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
	if !m.config.Development.EnableCors {
		// CORS disabled: pass the request straight through.
		return func(c *gin.Context) {
			c.Next()
		}
	}

	// Named corsConfig (not config) so the imported config package is not
	// shadowed inside this method.
	corsConfig := cors.Config{
		AllowAllOrigins: false,
		AllowOrigins:    m.getAllowedOrigins(),
		AllowMethods:    m.getAllowedMethods(),
		AllowHeaders:    m.getAllowedHeaders(),
		ExposeHeaders: []string{
			"Content-Length",
			"Content-Type",
			"X-Request-ID",
			"X-Response-Time",
		},
		AllowCredentials: true,
		// MaxAge is a time.Duration. A bare 86400 would mean 86400ns, which
		// truncates to 0 whole seconds and disables preflight caching.
		MaxAge: 24 * time.Hour,
	}
	return cors.New(corsConfig)
}
// IsGlobal reports whether this is a global middleware; it always returns
// true. NOTE(review): presumably the registry then applies it to every
// route — confirm.
func (*CORSMiddleware) IsGlobal() bool {
	const global = true
	return global
}
// parseCorsList splits a comma-separated configuration value into its
// entries, trimming surrounding whitespace and dropping empty entries.
// It returns nil when raw yields no entries (e.g. "", "  ", ",,").
func parseCorsList(raw string) []string {
	var items []string
	for _, part := range strings.Split(raw, ",") {
		if item := strings.TrimSpace(part); item != "" {
			items = append(items, item)
		}
	}
	return items
}

// getAllowedOrigins returns the origins allowed by CORS.
//
// Development.CorsOrigins is treated as a comma-separated list (e.g.
// "https://a.example, https://b.example"). Each entry must be its own slice
// element, because origins are matched one entry at a time. When the
// setting yields no entries, the local development defaults are returned.
func (m *CORSMiddleware) getAllowedOrigins() []string {
	if origins := parseCorsList(m.config.Development.CorsOrigins); len(origins) > 0 {
		return origins
	}
	return []string{"http://localhost:3000", "http://localhost:8080"}
}

// getAllowedMethods returns the HTTP methods allowed by CORS.
//
// Development.CorsMethods is treated as a comma-separated list (e.g.
// "GET,POST"). When the setting yields no entries, the common REST methods
// are returned.
func (m *CORSMiddleware) getAllowedMethods() []string {
	if methods := parseCorsList(m.config.Development.CorsMethods); len(methods) > 0 {
		return methods
	}
	return []string{
		"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
	}
}

// getAllowedHeaders returns the request headers allowed by CORS.
//
// Development.CorsHeaders is treated as a comma-separated list. When the
// setting yields no entries, a default set is returned; it includes
// Authorization and X-Request-ID.
func (m *CORSMiddleware) getAllowedHeaders() []string {
	if headers := parseCorsList(m.config.Development.CorsHeaders); len(headers) > 0 {
		return headers
	}
	return []string{
		"Origin",
		"Content-Length",
		"Content-Type",
		"Authorization",
		"X-Requested-With",
		"Accept",
		"Accept-Encoding",
		"Accept-Language",
		"X-Request-ID",
	}
}