From 50a4fa86ce7906059843986009a13a1c6e60d34c Mon Sep 17 00:00:00 2001 From: liangzai <2440983361@qq.com> Date: Thu, 28 Aug 2025 00:50:30 +0800 Subject: [PATCH] add cors --- config.yaml | 4 ++-- internal/shared/middleware/cors.go | 35 ++++++++++++++++++++++-------- 2 files changed, 28 insertions(+), 11 deletions(-) diff --git a/config.yaml b/config.yaml index 3a8cc60..8b7c83b 100644 --- a/config.yaml +++ b/config.yaml @@ -228,9 +228,9 @@ development: debug: true enable_profiler: true enable_cors: true - cors_allowed_origins: "http://localhost:3000,http://localhost:3001" + cors_allowed_origins: "https://consoletest.tianyuanapi.com,https://console.tianyuanapi.com" cors_allowed_methods: "GET,POST,PUT,PATCH,DELETE,OPTIONS" - cors_allowed_headers: "Origin,Content-Type,Accept,Authorization,X-Requested-With" + cors_allowed_headers: "Origin,Content-Type,Accept,Authorization,X-Requested-With,Access-Id" # 企业微信配置 wechat_work: diff --git a/internal/shared/middleware/cors.go b/internal/shared/middleware/cors.go index a4a846d..046f617 100644 --- a/internal/shared/middleware/cors.go +++ b/internal/shared/middleware/cors.go @@ -1,6 +1,7 @@ package middleware import ( + "strings" "tyapi-server/internal/config" "github.com/gin-contrib/cors" @@ -67,8 +68,13 @@ func (m *CORSMiddleware) getAllowedOrigins() []string { return []string{"http://localhost:3000", "http://localhost:8080"} } - // TODO: 解析配置中的origins字符串 - return []string{m.config.Development.CorsOrigins} + // 解析配置中的origins字符串,按逗号分隔 + origins := strings.Split(m.config.Development.CorsOrigins, ",") + // 去除空格 + for i, origin := range origins { + origins[i] = strings.TrimSpace(origin) + } + return origins } // getAllowedMethods 获取允许的方法 @@ -79,8 +85,13 @@ func (m *CORSMiddleware) getAllowedMethods() []string { } } - // TODO: 解析配置中的methods字符串 - return []string{m.config.Development.CorsMethods} + // 解析配置中的methods字符串,按逗号分隔 + methods := strings.Split(m.config.Development.CorsMethods, ",") + // 去除空格 + for i, method := range methods { + methods[i] = strings.TrimSpace(method) + } + return methods } // getAllowedHeaders 获取允许的头部 @@ -88,17 +99,23 @@ func (m *CORSMiddleware) getAllowedHeaders() []string { if m.config.Development.CorsHeaders == "" { return []string{ "Origin", - "Content-Length", "Content-Type", - "Authorization", - "X-Requested-With", + "Content-Length", "Accept", "Accept-Encoding", "Accept-Language", + "Authorization", + "X-Requested-With", "X-Request-ID", + "Access-Id", } } - // TODO: 解析配置中的headers字符串 - return []string{m.config.Development.CorsHeaders} + // 解析配置中的headers字符串,按逗号分隔 + headers := strings.Split(m.config.Development.CorsHeaders, ",") + // 去除空格 + for i, header := range headers { + headers[i] = strings.TrimSpace(header) + } + return headers }