Initial commit: Basic project structure and dependencies
This commit is contained in:
104
internal/shared/middleware/cors.go
Normal file
104
internal/shared/middleware/cors.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware CORS中间件
|
||||
type CORSMiddleware struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewCORSMiddleware 创建CORS中间件
|
||||
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
|
||||
return &CORSMiddleware{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *CORSMiddleware) GetName() string {
|
||||
return "cors"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *CORSMiddleware) GetPriority() int {
|
||||
return 100 // 高优先级,最先执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.config.Development.EnableCors {
|
||||
// 如果没有启用CORS,返回空处理函数
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
config := cors.Config{
|
||||
AllowAllOrigins: false,
|
||||
AllowOrigins: m.getAllowedOrigins(),
|
||||
AllowMethods: m.getAllowedMethods(),
|
||||
AllowHeaders: m.getAllowedHeaders(),
|
||||
ExposeHeaders: []string{
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 86400, // 24小时
|
||||
}
|
||||
|
||||
return cors.New(config)
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *CORSMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getAllowedOrigins 获取允许的来源
|
||||
func (m *CORSMiddleware) getAllowedOrigins() []string {
|
||||
if m.config.Development.CorsOrigins == "" {
|
||||
return []string{"http://localhost:3000", "http://localhost:8080"}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的origins字符串
|
||||
return []string{m.config.Development.CorsOrigins}
|
||||
}
|
||||
|
||||
// getAllowedMethods 获取允许的方法
|
||||
func (m *CORSMiddleware) getAllowedMethods() []string {
|
||||
if m.config.Development.CorsMethods == "" {
|
||||
return []string{
|
||||
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的methods字符串
|
||||
return []string{m.config.Development.CorsMethods}
|
||||
}
|
||||
|
||||
// getAllowedHeaders 获取允许的头部
|
||||
func (m *CORSMiddleware) getAllowedHeaders() []string {
|
||||
if m.config.Development.CorsHeaders == "" {
|
||||
return []string{
|
||||
"Origin",
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"Accept",
|
||||
"Accept-Encoding",
|
||||
"Accept-Language",
|
||||
"X-Request-ID",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的headers字符串
|
||||
return []string{m.config.Development.CorsHeaders}
|
||||
}
|
||||
Reference in New Issue
Block a user