package config import ( "fmt" "os" "strings" "time" "github.com/spf13/viper" ) // LoadConfig 加载应用程序配置 func LoadConfig() (*Config, error) { // 设置配置文件名和路径 viper.SetConfigName("config") viper.SetConfigType("yaml") viper.AddConfigPath(".") viper.AddConfigPath("./configs") viper.AddConfigPath("$HOME/.tyapi") // 设置环境变量前缀 viper.SetEnvPrefix("") viper.AutomaticEnv() // 配置环境变量键名映射 setupEnvKeyMapping() // 设置默认值 setDefaults() // 尝试读取配置文件(可选) if err := viper.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { return nil, fmt.Errorf("读取配置文件失败: %w", err) } // 配置文件不存在时使用环境变量和默认值 } var config Config if err := viper.Unmarshal(&config); err != nil { return nil, fmt.Errorf("解析配置失败: %w", err) } // 验证配置 if err := validateConfig(&config); err != nil { return nil, fmt.Errorf("配置验证失败: %w", err) } return &config, nil } // setupEnvKeyMapping 设置环境变量到配置键的映射 func setupEnvKeyMapping() { // 服务器配置 viper.BindEnv("server.port", "SERVER_PORT") viper.BindEnv("server.mode", "SERVER_MODE") viper.BindEnv("server.host", "SERVER_HOST") viper.BindEnv("server.read_timeout", "SERVER_READ_TIMEOUT") viper.BindEnv("server.write_timeout", "SERVER_WRITE_TIMEOUT") viper.BindEnv("server.idle_timeout", "SERVER_IDLE_TIMEOUT") // 数据库配置 viper.BindEnv("database.host", "DB_HOST") viper.BindEnv("database.port", "DB_PORT") viper.BindEnv("database.user", "DB_USER") viper.BindEnv("database.password", "DB_PASSWORD") viper.BindEnv("database.name", "DB_NAME") viper.BindEnv("database.sslmode", "DB_SSLMODE") viper.BindEnv("database.timezone", "DB_TIMEZONE") viper.BindEnv("database.max_open_conns", "DB_MAX_OPEN_CONNS") viper.BindEnv("database.max_idle_conns", "DB_MAX_IDLE_CONNS") viper.BindEnv("database.conn_max_lifetime", "DB_CONN_MAX_LIFETIME") // Redis配置 viper.BindEnv("redis.host", "REDIS_HOST") viper.BindEnv("redis.port", "REDIS_PORT") viper.BindEnv("redis.password", "REDIS_PASSWORD") viper.BindEnv("redis.db", "REDIS_DB") viper.BindEnv("redis.pool_size", "REDIS_POOL_SIZE") viper.BindEnv("redis.min_idle_conns", "REDIS_MIN_IDLE_CONNS") viper.BindEnv("redis.max_retries", "REDIS_MAX_RETRIES") viper.BindEnv("redis.dial_timeout", "REDIS_DIAL_TIMEOUT") viper.BindEnv("redis.read_timeout", "REDIS_READ_TIMEOUT") viper.BindEnv("redis.write_timeout", "REDIS_WRITE_TIMEOUT") // 缓存配置 viper.BindEnv("cache.default_ttl", "CACHE_DEFAULT_TTL") viper.BindEnv("cache.cleanup_interval", "CACHE_CLEANUP_INTERVAL") viper.BindEnv("cache.max_size", "CACHE_MAX_SIZE") // 日志配置 viper.BindEnv("logger.level", "LOG_LEVEL") viper.BindEnv("logger.format", "LOG_FORMAT") viper.BindEnv("logger.output", "LOG_OUTPUT") viper.BindEnv("logger.file_path", "LOG_FILE_PATH") viper.BindEnv("logger.max_size", "LOG_MAX_SIZE") viper.BindEnv("logger.max_backups", "LOG_MAX_BACKUPS") viper.BindEnv("logger.max_age", "LOG_MAX_AGE") viper.BindEnv("logger.compress", "LOG_COMPRESS") // JWT配置 viper.BindEnv("jwt.secret", "JWT_SECRET") viper.BindEnv("jwt.expires_in", "JWT_EXPIRES_IN") viper.BindEnv("jwt.refresh_expires_in", "JWT_REFRESH_EXPIRES_IN") // 限流配置 viper.BindEnv("ratelimit.requests", "RATE_LIMIT_REQUESTS") viper.BindEnv("ratelimit.window", "RATE_LIMIT_WINDOW") viper.BindEnv("ratelimit.burst", "RATE_LIMIT_BURST") // 监控配置 viper.BindEnv("monitoring.metrics_enabled", "METRICS_ENABLED") viper.BindEnv("monitoring.metrics_port", "METRICS_PORT") viper.BindEnv("monitoring.tracing_enabled", "TRACING_ENABLED") viper.BindEnv("monitoring.tracing_endpoint", "TRACING_ENDPOINT") viper.BindEnv("monitoring.sample_rate", "TRACING_SAMPLE_RATE") // 健康检查配置 viper.BindEnv("health.enabled", "HEALTH_CHECK_ENABLED") viper.BindEnv("health.interval", "HEALTH_CHECK_INTERVAL") viper.BindEnv("health.timeout", "HEALTH_CHECK_TIMEOUT") // 容错配置 viper.BindEnv("resilience.circuit_breaker_enabled", "CIRCUIT_BREAKER_ENABLED") viper.BindEnv("resilience.circuit_breaker_threshold", "CIRCUIT_BREAKER_THRESHOLD") viper.BindEnv("resilience.circuit_breaker_timeout", "CIRCUIT_BREAKER_TIMEOUT") viper.BindEnv("resilience.retry_max_attempts", "RETRY_MAX_ATTEMPTS") viper.BindEnv("resilience.retry_initial_delay", "RETRY_INITIAL_DELAY") viper.BindEnv("resilience.retry_max_delay", "RETRY_MAX_DELAY") // 开发配置 viper.BindEnv("development.debug", "DEBUG") viper.BindEnv("development.enable_profiler", "ENABLE_PROFILER") viper.BindEnv("development.enable_cors", "ENABLE_CORS") viper.BindEnv("development.cors_allowed_origins", "CORS_ALLOWED_ORIGINS") viper.BindEnv("development.cors_allowed_methods", "CORS_ALLOWED_METHODS") viper.BindEnv("development.cors_allowed_headers", "CORS_ALLOWED_HEADERS") // 应用程序配置 viper.BindEnv("app.name", "APP_NAME") viper.BindEnv("app.version", "APP_VERSION") viper.BindEnv("app.env", "ENV") } // setDefaults 设置默认配置值 func setDefaults() { // 服务器默认值 viper.SetDefault("server.port", "8080") viper.SetDefault("server.mode", "debug") viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.read_timeout", "30s") viper.SetDefault("server.write_timeout", "30s") viper.SetDefault("server.idle_timeout", "120s") // 数据库默认值 viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", "5432") viper.SetDefault("database.user", "postgres") viper.SetDefault("database.password", "password") viper.SetDefault("database.name", "tyapi_db") viper.SetDefault("database.sslmode", "disable") viper.SetDefault("database.timezone", "Asia/Shanghai") viper.SetDefault("database.max_open_conns", 100) viper.SetDefault("database.max_idle_conns", 10) viper.SetDefault("database.conn_max_lifetime", "300s") // Redis默认值 viper.SetDefault("redis.host", "localhost") viper.SetDefault("redis.port", "6379") viper.SetDefault("redis.password", "") viper.SetDefault("redis.db", 0) viper.SetDefault("redis.pool_size", 10) viper.SetDefault("redis.min_idle_conns", 5) viper.SetDefault("redis.max_retries", 3) viper.SetDefault("redis.dial_timeout", "5s") viper.SetDefault("redis.read_timeout", "3s") viper.SetDefault("redis.write_timeout", "3s") // 缓存默认值 viper.SetDefault("cache.default_ttl", "300s") viper.SetDefault("cache.cleanup_interval", "600s") viper.SetDefault("cache.max_size", 1000) // 日志默认值 viper.SetDefault("logger.level", "info") viper.SetDefault("logger.format", "json") viper.SetDefault("logger.output", "stdout") viper.SetDefault("logger.file_path", "logs/app.log") viper.SetDefault("logger.max_size", 100) viper.SetDefault("logger.max_backups", 5) viper.SetDefault("logger.max_age", 30) viper.SetDefault("logger.compress", true) // JWT默认值 viper.SetDefault("jwt.secret", "your-super-secret-jwt-key-change-this-in-production") viper.SetDefault("jwt.expires_in", "24h") viper.SetDefault("jwt.refresh_expires_in", "168h") // 限流默认值 viper.SetDefault("ratelimit.requests", 100) viper.SetDefault("ratelimit.window", "60s") viper.SetDefault("ratelimit.burst", 10) // 监控默认值 viper.SetDefault("monitoring.metrics_enabled", true) viper.SetDefault("monitoring.metrics_port", "9090") viper.SetDefault("monitoring.tracing_enabled", false) viper.SetDefault("monitoring.tracing_endpoint", "http://localhost:14268/api/traces") viper.SetDefault("monitoring.sample_rate", 0.1) // 健康检查默认值 viper.SetDefault("health.enabled", true) viper.SetDefault("health.interval", "30s") viper.SetDefault("health.timeout", "5s") // 容错默认值 viper.SetDefault("resilience.circuit_breaker_enabled", true) viper.SetDefault("resilience.circuit_breaker_threshold", 5) viper.SetDefault("resilience.circuit_breaker_timeout", "60s") viper.SetDefault("resilience.retry_max_attempts", 3) viper.SetDefault("resilience.retry_initial_delay", "100ms") viper.SetDefault("resilience.retry_max_delay", "2s") // 开发默认值 viper.SetDefault("development.debug", true) viper.SetDefault("development.enable_profiler", false) viper.SetDefault("development.enable_cors", true) viper.SetDefault("development.cors_allowed_origins", "*") viper.SetDefault("development.cors_allowed_methods", "GET,POST,PUT,DELETE,OPTIONS") viper.SetDefault("development.cors_allowed_headers", "*") // 应用程序默认值 viper.SetDefault("app.name", "tyapi-server") viper.SetDefault("app.version", "1.0.0") viper.SetDefault("app.env", "development") } // validateConfig 验证配置 func validateConfig(config *Config) error { // 验证必要的配置项 if config.Database.Host == "" { return fmt.Errorf("数据库主机地址不能为空") } if config.Database.User == "" { return fmt.Errorf("数据库用户名不能为空") } if config.Database.Name == "" { return fmt.Errorf("数据库名称不能为空") } if config.JWT.Secret == "" || config.JWT.Secret == "your-super-secret-jwt-key-change-this-in-production" { if config.App.IsProduction() { return fmt.Errorf("生产环境必须设置安全的JWT密钥") } } // 验证超时配置 if config.Server.ReadTimeout <= 0 { return fmt.Errorf("服务器读取超时时间必须大于0") } if config.Server.WriteTimeout <= 0 { return fmt.Errorf("服务器写入超时时间必须大于0") } // 验证数据库连接池配置 if config.Database.MaxOpenConns <= 0 { return fmt.Errorf("数据库最大连接数必须大于0") } if config.Database.MaxIdleConns <= 0 { return fmt.Errorf("数据库最大空闲连接数必须大于0") } if config.Database.MaxIdleConns > config.Database.MaxOpenConns { return fmt.Errorf("数据库最大空闲连接数不能大于最大连接数") } return nil } // GetEnv 获取环境变量,如果不存在则返回默认值 func GetEnv(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue } // ParseDuration 解析时间字符串 func ParseDuration(s string) time.Duration { d, err := time.ParseDuration(s) if err != nil { return 0 } return d } // SplitAndTrim 分割字符串并去除空格 func SplitAndTrim(s, sep string) []string { parts := strings.Split(s, sep) result := make([]string, 0, len(parts)) for _, part := range parts { if trimmed := strings.TrimSpace(part); trimmed != "" { result = append(result, trimmed) } } return result }