Files
tyapi-server/internal/config/loader.go
2025-07-11 21:05:58 +08:00

268 lines
7.3 KiB
Go
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/spf13/viper"
)
// LoadConfig 加载应用程序配置
func LoadConfig() (*Config, error) {
// 1⃣ 获取环境变量决定配置文件
env := getEnvironment()
fmt.Printf("🔧 当前运行环境: %s\n", env)
// 2⃣ 加载基础配置文件
baseConfig := viper.New()
baseConfig.SetConfigName("config")
baseConfig.SetConfigType("yaml")
baseConfig.AddConfigPath(".")
baseConfig.AddConfigPath("./configs")
baseConfig.AddConfigPath("$HOME/.tyapi")
// 读取基础配置文件
if err := baseConfig.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("读取基础配置文件失败: %w", err)
}
return nil, fmt.Errorf("未找到 config.yaml 文件,请确保配置文件存在")
}
fmt.Printf("✅ 已加载配置文件: %s\n", baseConfig.ConfigFileUsed())
// 3⃣ 加载环境特定配置文件
envConfigFile := findEnvConfigFile(env)
if envConfigFile != "" {
// 创建一个新的viper实例来读取环境配置
envConfig := viper.New()
envConfig.SetConfigFile(envConfigFile)
if err := envConfig.ReadInConfig(); err != nil {
fmt.Printf("⚠️ 环境配置文件加载警告: %v\n", err)
} else {
fmt.Printf("✅ 已加载环境配置: %s\n", envConfigFile)
// 将环境配置合并到基础配置中
if err := mergeConfigs(baseConfig, envConfig.AllSettings()); err != nil {
return nil, fmt.Errorf("合并配置失败: %w", err)
}
}
} else {
fmt.Printf(" 未找到环境配置文件 configs/env.%s.yaml将使用基础配置\n", env)
}
// 4⃣ 手动处理环境变量覆盖,避免空值覆盖配置文件
// overrideWithEnvVars(baseConfig)
// 5⃣ 解析配置到结构体
var config Config
if err := baseConfig.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("解析配置失败: %w", err)
}
// 6⃣ 验证配置
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err)
}
// 7⃣ 输出配置摘要
printConfigSummary(&config, env)
return &config, nil
}
// mergeConfigs 递归合并配置
func mergeConfigs(baseConfig *viper.Viper, overrideSettings map[string]interface{}) error {
for key, val := range overrideSettings {
// 如果值是一个嵌套的map则递归合并
if subMap, ok := val.(map[string]interface{}); ok {
// 创建子键路径
subKey := key
// 递归合并子配置
for subK, subV := range subMap {
fullKey := fmt.Sprintf("%s.%s", subKey, subK)
baseConfig.Set(fullKey, subV)
}
} else {
// 直接设置值
baseConfig.Set(key, val)
}
}
return nil
}
// findEnvConfigFile 查找环境特定的配置文件
func findEnvConfigFile(env string) string {
// 只查找 configs 目录下的环境配置文件
possiblePaths := []string{
fmt.Sprintf("configs/env.%s.yaml", env),
fmt.Sprintf("configs/env.%s.yml", env),
}
for _, path := range possiblePaths {
if _, err := os.Stat(path); err == nil {
absPath, _ := filepath.Abs(path)
return absPath
}
}
return ""
}
// getEnvironment 获取当前环境
func getEnvironment() string {
var env string
var source string
// 优先级CONFIG_ENV > ENV > APP_ENV > 默认值
if env = os.Getenv("CONFIG_ENV"); env != "" {
source = "CONFIG_ENV"
} else if env = os.Getenv("ENV"); env != "" {
source = "ENV"
} else if env = os.Getenv("APP_ENV"); env != "" {
source = "APP_ENV"
} else {
env = "development"
source = "默认值"
}
fmt.Printf("🌍 环境检测: %s (来源: %s)\n", env, source)
// 验证环境值
validEnvs := []string{"development", "production", "testing"}
isValid := false
for _, validEnv := range validEnvs {
if env == validEnv {
isValid = true
break
}
}
if !isValid {
fmt.Printf("⚠️ 警告: 未识别的环境 '%s',将使用默认环境 'development'\n", env)
return "development"
}
return env
}
// printConfigSummary 打印配置摘要
func printConfigSummary(config *Config, env string) {
fmt.Printf("\n🔧 配置摘要:\n")
fmt.Printf(" 🌍 环境: %s\n", env)
fmt.Printf(" 📄 基础配置: config.yaml\n")
fmt.Printf(" 📁 环境配置: configs/env.%s.yaml\n", env)
fmt.Printf(" 📱 应用名称: %s\n", config.App.Name)
fmt.Printf(" 🔖 版本: %s\n", config.App.Version)
fmt.Printf(" 🌐 服务端口: %s\n", config.Server.Port)
fmt.Printf(" 🗄️ 数据库: %s@%s:%s/%s\n",
config.Database.User,
config.Database.Host,
config.Database.Port,
config.Database.Name)
fmt.Printf(" 📊 追踪状态: %v (端点: %s)\n",
config.Monitoring.TracingEnabled,
config.Monitoring.TracingEndpoint)
fmt.Printf(" 📈 采样率: %.1f%%\n", config.Monitoring.SampleRate*100)
fmt.Printf("\n")
}
// validateConfig 验证配置
func validateConfig(config *Config) error {
// 验证必要的配置项
if config.Database.Host == "" {
return fmt.Errorf("数据库主机地址不能为空")
}
if config.Database.User == "" {
return fmt.Errorf("数据库用户名不能为空")
}
if config.Database.Name == "" {
return fmt.Errorf("数据库名称不能为空")
}
if config.JWT.Secret == "" || config.JWT.Secret == "your-super-secret-jwt-key-change-this-in-production" {
if config.App.IsProduction() {
return fmt.Errorf("生产环境必须设置安全的JWT密钥")
}
}
// 验证超时配置
if config.Server.ReadTimeout <= 0 {
return fmt.Errorf("服务器读取超时时间必须大于0")
}
if config.Server.WriteTimeout <= 0 {
return fmt.Errorf("服务器写入超时时间必须大于0")
}
// 验证数据库连接池配置
if config.Database.MaxOpenConns <= 0 {
return fmt.Errorf("数据库最大连接数必须大于0")
}
if config.Database.MaxIdleConns <= 0 {
return fmt.Errorf("数据库最大空闲连接数必须大于0")
}
if config.Database.MaxIdleConns > config.Database.MaxOpenConns {
return fmt.Errorf("数据库最大空闲连接数不能大于最大连接数")
}
return nil
}
// GetEnv 获取环境变量,如果不存在则返回默认值
func GetEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// ParseDuration 解析时间字符串
func ParseDuration(s string) time.Duration {
d, err := time.ParseDuration(s)
if err != nil {
return 0
}
return d
}
// overrideWithEnvVars 手动处理环境变量覆盖,避免空值覆盖配置文件
func overrideWithEnvVars(config *viper.Viper) {
// 定义需要环境变量覆盖的敏感配置项
sensitiveConfigs := map[string]string{
"database.password": "DATABASE_PASSWORD",
"jwt.secret": "JWT_SECRET",
"redis.password": "REDIS_PASSWORD",
"wechat_work.webhook_url": "WECHAT_WORK_WEBHOOK_URL",
"wechat_work.secret": "WECHAT_WORK_SECRET",
}
// 只覆盖明确设置的环境变量
for configKey, envKey := range sensitiveConfigs {
if envValue := os.Getenv(envKey); envValue != "" {
config.Set(configKey, envValue)
fmt.Printf("🔐 已从环境变量覆盖配置: %s\n", configKey)
}
}
}
// SplitAndTrim 分割字符串并去除空格
func SplitAndTrim(s, sep string) []string {
parts := strings.Split(s, sep)
result := make([]string, 0, len(parts))
for _, part := range parts {
if trimmed := strings.TrimSpace(part); trimmed != "" {
result = append(result, trimmed)
}
}
return result
}