f
This commit is contained in:
267
internal/config/loader.go
Normal file
267
internal/config/loader.go
Normal file
@@ -0,0 +1,267 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// LoadConfig 加载应用程序配置
|
||||
func LoadConfig() (*Config, error) {
|
||||
// 1️⃣ 获取环境变量决定配置文件
|
||||
env := getEnvironment()
|
||||
fmt.Printf("🔧 当前运行环境: %s\n", env)
|
||||
|
||||
// 2️⃣ 加载基础配置文件
|
||||
baseConfig := viper.New()
|
||||
baseConfig.SetConfigName("config")
|
||||
baseConfig.SetConfigType("yaml")
|
||||
baseConfig.AddConfigPath(".")
|
||||
baseConfig.AddConfigPath("./configs")
|
||||
baseConfig.AddConfigPath("$HOME/.hyapi")
|
||||
|
||||
// 读取基础配置文件
|
||||
if err := baseConfig.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("读取基础配置文件失败: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("未找到 config.yaml 文件,请确保配置文件存在")
|
||||
}
|
||||
fmt.Printf("✅ 已加载配置文件: %s\n", baseConfig.ConfigFileUsed())
|
||||
|
||||
// 3️⃣ 加载环境特定配置文件
|
||||
envConfigFile := findEnvConfigFile(env)
|
||||
if envConfigFile != "" {
|
||||
// 创建一个新的viper实例来读取环境配置
|
||||
envConfig := viper.New()
|
||||
envConfig.SetConfigFile(envConfigFile)
|
||||
|
||||
if err := envConfig.ReadInConfig(); err != nil {
|
||||
fmt.Printf("⚠️ 环境配置文件加载警告: %v\n", err)
|
||||
} else {
|
||||
fmt.Printf("✅ 已加载环境配置: %s\n", envConfigFile)
|
||||
|
||||
// 将环境配置合并到基础配置中
|
||||
if err := mergeConfigs(baseConfig, envConfig.AllSettings()); err != nil {
|
||||
return nil, fmt.Errorf("合并配置失败: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("ℹ️ 未找到环境配置文件 configs/env.%s.yaml,将使用基础配置\n", env)
|
||||
}
|
||||
|
||||
// 4️⃣ 手动处理环境变量覆盖,避免空值覆盖配置文件
|
||||
// overrideWithEnvVars(baseConfig)
|
||||
|
||||
// 5️⃣ 解析配置到结构体
|
||||
var config Config
|
||||
if err := baseConfig.Unmarshal(&config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 6️⃣ 验证配置
|
||||
if err := validateConfig(&config); err != nil {
|
||||
return nil, fmt.Errorf("配置验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 7️⃣ 输出配置摘要
|
||||
printConfigSummary(&config, env)
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// mergeConfigs 递归合并配置
|
||||
func mergeConfigs(baseConfig *viper.Viper, overrideSettings map[string]interface{}) error {
|
||||
for key, val := range overrideSettings {
|
||||
// 如果值是一个嵌套的map,则递归合并
|
||||
if subMap, ok := val.(map[string]interface{}); ok {
|
||||
// 创建子键路径
|
||||
subKey := key
|
||||
|
||||
// 递归合并子配置
|
||||
for subK, subV := range subMap {
|
||||
fullKey := fmt.Sprintf("%s.%s", subKey, subK)
|
||||
baseConfig.Set(fullKey, subV)
|
||||
}
|
||||
} else {
|
||||
// 直接设置值
|
||||
baseConfig.Set(key, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// findEnvConfigFile 查找环境特定的配置文件
|
||||
func findEnvConfigFile(env string) string {
|
||||
// 只查找 configs 目录下的环境配置文件
|
||||
possiblePaths := []string{
|
||||
fmt.Sprintf("configs/env.%s.yaml", env),
|
||||
fmt.Sprintf("configs/env.%s.yml", env),
|
||||
}
|
||||
|
||||
for _, path := range possiblePaths {
|
||||
if _, err := os.Stat(path); err == nil {
|
||||
absPath, _ := filepath.Abs(path)
|
||||
return absPath
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// getEnvironment 获取当前环境
|
||||
func getEnvironment() string {
|
||||
var env string
|
||||
var source string
|
||||
|
||||
// 优先级:CONFIG_ENV > ENV > APP_ENV > 默认值
|
||||
if env = os.Getenv("CONFIG_ENV"); env != "" {
|
||||
source = "CONFIG_ENV"
|
||||
} else if env = os.Getenv("ENV"); env != "" {
|
||||
source = "ENV"
|
||||
} else if env = os.Getenv("APP_ENV"); env != "" {
|
||||
source = "APP_ENV"
|
||||
} else {
|
||||
env = "development"
|
||||
source = "默认值"
|
||||
}
|
||||
|
||||
fmt.Printf("🌍 环境检测: %s (来源: %s)\n", env, source)
|
||||
|
||||
// 验证环境值
|
||||
validEnvs := []string{"development", "production", "testing"}
|
||||
isValid := false
|
||||
for _, validEnv := range validEnvs {
|
||||
if env == validEnv {
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isValid {
|
||||
fmt.Printf("⚠️ 警告: 未识别的环境 '%s',将使用默认环境 'development'\n", env)
|
||||
return "development"
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
|
||||
// printConfigSummary 打印配置摘要
|
||||
func printConfigSummary(config *Config, env string) {
|
||||
fmt.Printf("\n🔧 配置摘要:\n")
|
||||
fmt.Printf(" 🌍 环境: %s\n", env)
|
||||
fmt.Printf(" 📄 基础配置: config.yaml\n")
|
||||
fmt.Printf(" 📁 环境配置: configs/env.%s.yaml\n", env)
|
||||
fmt.Printf(" 📱 应用名称: %s\n", config.App.Name)
|
||||
fmt.Printf(" 🔖 版本: %s\n", config.App.Version)
|
||||
fmt.Printf(" 🌐 服务端口: %s\n", config.Server.Port)
|
||||
fmt.Printf(" 🗄️ 数据库: %s@%s:%s/%s\n",
|
||||
config.Database.User,
|
||||
config.Database.Host,
|
||||
config.Database.Port,
|
||||
config.Database.Name)
|
||||
fmt.Printf(" 📊 追踪状态: %v (端点: %s)\n",
|
||||
config.Monitoring.TracingEnabled,
|
||||
config.Monitoring.TracingEndpoint)
|
||||
fmt.Printf(" 📈 采样率: %.1f%%\n", config.Monitoring.SampleRate*100)
|
||||
fmt.Printf("\n")
|
||||
}
|
||||
|
||||
// validateConfig 验证配置
|
||||
func validateConfig(config *Config) error {
|
||||
// 验证必要的配置项
|
||||
if config.Database.Host == "" {
|
||||
return fmt.Errorf("数据库主机地址不能为空")
|
||||
}
|
||||
|
||||
if config.Database.User == "" {
|
||||
return fmt.Errorf("数据库用户名不能为空")
|
||||
}
|
||||
|
||||
if config.Database.Name == "" {
|
||||
return fmt.Errorf("数据库名称不能为空")
|
||||
}
|
||||
|
||||
if config.JWT.Secret == "" || config.JWT.Secret == "your-super-secret-jwt-key-change-this-in-production" {
|
||||
if config.App.IsProduction() {
|
||||
return fmt.Errorf("生产环境必须设置安全的JWT密钥")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证超时配置
|
||||
if config.Server.ReadTimeout <= 0 {
|
||||
return fmt.Errorf("服务器读取超时时间必须大于0")
|
||||
}
|
||||
|
||||
if config.Server.WriteTimeout <= 0 {
|
||||
return fmt.Errorf("服务器写入超时时间必须大于0")
|
||||
}
|
||||
|
||||
// 验证数据库连接池配置
|
||||
if config.Database.MaxOpenConns <= 0 {
|
||||
return fmt.Errorf("数据库最大连接数必须大于0")
|
||||
}
|
||||
|
||||
if config.Database.MaxIdleConns <= 0 {
|
||||
return fmt.Errorf("数据库最大空闲连接数必须大于0")
|
||||
}
|
||||
|
||||
if config.Database.MaxIdleConns > config.Database.MaxOpenConns {
|
||||
return fmt.Errorf("数据库最大空闲连接数不能大于最大连接数")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEnv 获取环境变量,如果不存在则返回默认值
|
||||
func GetEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// ParseDuration 解析时间字符串
|
||||
func ParseDuration(s string) time.Duration {
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// overrideWithEnvVars 手动处理环境变量覆盖,避免空值覆盖配置文件
|
||||
func overrideWithEnvVars(config *viper.Viper) {
|
||||
// 定义需要环境变量覆盖的敏感配置项
|
||||
sensitiveConfigs := map[string]string{
|
||||
"database.password": "DATABASE_PASSWORD",
|
||||
"jwt.secret": "JWT_SECRET",
|
||||
"redis.password": "REDIS_PASSWORD",
|
||||
"wechat_work.webhook_url": "WECHAT_WORK_WEBHOOK_URL",
|
||||
"wechat_work.secret": "WECHAT_WORK_SECRET",
|
||||
}
|
||||
|
||||
// 只覆盖明确设置的环境变量
|
||||
for configKey, envKey := range sensitiveConfigs {
|
||||
if envValue := os.Getenv(envKey); envValue != "" {
|
||||
config.Set(configKey, envValue)
|
||||
fmt.Printf("🔐 已从环境变量覆盖配置: %s\n", configKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SplitAndTrim 分割字符串并去除空格
|
||||
func SplitAndTrim(s, sep string) []string {
|
||||
parts := strings.Split(s, sep)
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if trimmed := strings.TrimSpace(part); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user