// File: tyapi-server/internal/config/loader.go
// (copied from a repository browser: 312 lines, 10 KiB, Go)

package config
import (
"fmt"
"os"
"strings"
"time"
"github.com/spf13/viper"
)
// LoadConfig builds the application configuration and validates it.
//
// Sources are registered on the package-global viper instance: an optional
// "config.yaml" searched for in ".", "./configs" and "$HOME/.tyapi", the
// environment bindings from setupEnvKeyMapping, and the defaults from
// setDefaults. Which source wins for a given key follows viper's documented
// precedence rules.
//
// A missing config file is tolerated. Any other read error, an unmarshalling
// error, or a validateConfig failure is returned wrapped.
//
// NOTE(review): this mutates the global viper instance, so it shares state
// with any other code in the process that uses viper's package-level API.
func LoadConfig() (*Config, error) {
	// Where to look for the optional YAML config file.
	viper.SetConfigName("config")
	viper.SetConfigType("yaml")
	for _, dir := range []string{".", "./configs", "$HOME/.tyapi"} {
		viper.AddConfigPath(dir)
	}

	// Environment variables: no prefix, plus explicit key -> variable bindings.
	viper.SetEnvPrefix("")
	viper.AutomaticEnv()
	setupEnvKeyMapping()

	setDefaults()

	// The config file is optional: only "not found" is ignored. When the file
	// is missing, values come from environment variables and defaults.
	readErr := viper.ReadInConfig()
	if _, notFound := readErr.(viper.ConfigFileNotFoundError); readErr != nil && !notFound {
		return nil, fmt.Errorf("读取配置文件失败: %w", readErr)
	}

	cfg := new(Config)
	if err := viper.Unmarshal(cfg); err != nil {
		return nil, fmt.Errorf("解析配置失败: %w", err)
	}
	if err := validateConfig(cfg); err != nil {
		return nil, fmt.Errorf("配置验证失败: %w", err)
	}
	return cfg, nil
}
// setupEnvKeyMapping binds each configuration key to the environment variable
// that can override it (for example "server.port" <- SERVER_PORT).
func setupEnvKeyMapping() {
	bindings := []struct {
		key string // viper configuration key
		env string // environment variable name
	}{
		// Server
		{"server.port", "SERVER_PORT"},
		{"server.mode", "SERVER_MODE"},
		{"server.host", "SERVER_HOST"},
		{"server.read_timeout", "SERVER_READ_TIMEOUT"},
		{"server.write_timeout", "SERVER_WRITE_TIMEOUT"},
		{"server.idle_timeout", "SERVER_IDLE_TIMEOUT"},
		// Database
		{"database.host", "DB_HOST"},
		{"database.port", "DB_PORT"},
		{"database.user", "DB_USER"},
		{"database.password", "DB_PASSWORD"},
		{"database.name", "DB_NAME"},
		{"database.sslmode", "DB_SSLMODE"},
		{"database.timezone", "DB_TIMEZONE"},
		{"database.max_open_conns", "DB_MAX_OPEN_CONNS"},
		{"database.max_idle_conns", "DB_MAX_IDLE_CONNS"},
		{"database.conn_max_lifetime", "DB_CONN_MAX_LIFETIME"},
		// Redis
		{"redis.host", "REDIS_HOST"},
		{"redis.port", "REDIS_PORT"},
		{"redis.password", "REDIS_PASSWORD"},
		{"redis.db", "REDIS_DB"},
		{"redis.pool_size", "REDIS_POOL_SIZE"},
		{"redis.min_idle_conns", "REDIS_MIN_IDLE_CONNS"},
		{"redis.max_retries", "REDIS_MAX_RETRIES"},
		{"redis.dial_timeout", "REDIS_DIAL_TIMEOUT"},
		{"redis.read_timeout", "REDIS_READ_TIMEOUT"},
		{"redis.write_timeout", "REDIS_WRITE_TIMEOUT"},
		// Cache
		{"cache.default_ttl", "CACHE_DEFAULT_TTL"},
		{"cache.cleanup_interval", "CACHE_CLEANUP_INTERVAL"},
		{"cache.max_size", "CACHE_MAX_SIZE"},
		// Logger
		{"logger.level", "LOG_LEVEL"},
		{"logger.format", "LOG_FORMAT"},
		{"logger.output", "LOG_OUTPUT"},
		{"logger.file_path", "LOG_FILE_PATH"},
		{"logger.max_size", "LOG_MAX_SIZE"},
		{"logger.max_backups", "LOG_MAX_BACKUPS"},
		{"logger.max_age", "LOG_MAX_AGE"},
		{"logger.compress", "LOG_COMPRESS"},
		// JWT
		{"jwt.secret", "JWT_SECRET"},
		{"jwt.expires_in", "JWT_EXPIRES_IN"},
		{"jwt.refresh_expires_in", "JWT_REFRESH_EXPIRES_IN"},
		// Rate limiting
		{"ratelimit.requests", "RATE_LIMIT_REQUESTS"},
		{"ratelimit.window", "RATE_LIMIT_WINDOW"},
		{"ratelimit.burst", "RATE_LIMIT_BURST"},
		// Monitoring / tracing
		{"monitoring.metrics_enabled", "METRICS_ENABLED"},
		{"monitoring.metrics_port", "METRICS_PORT"},
		{"monitoring.tracing_enabled", "TRACING_ENABLED"},
		{"monitoring.tracing_endpoint", "TRACING_ENDPOINT"},
		{"monitoring.sample_rate", "TRACING_SAMPLE_RATE"},
		// Health checks
		{"health.enabled", "HEALTH_CHECK_ENABLED"},
		{"health.interval", "HEALTH_CHECK_INTERVAL"},
		{"health.timeout", "HEALTH_CHECK_TIMEOUT"},
		// Resilience (circuit breaker / retry)
		{"resilience.circuit_breaker_enabled", "CIRCUIT_BREAKER_ENABLED"},
		{"resilience.circuit_breaker_threshold", "CIRCUIT_BREAKER_THRESHOLD"},
		{"resilience.circuit_breaker_timeout", "CIRCUIT_BREAKER_TIMEOUT"},
		{"resilience.retry_max_attempts", "RETRY_MAX_ATTEMPTS"},
		{"resilience.retry_initial_delay", "RETRY_INITIAL_DELAY"},
		{"resilience.retry_max_delay", "RETRY_MAX_DELAY"},
		// Development
		{"development.debug", "DEBUG"},
		{"development.enable_profiler", "ENABLE_PROFILER"},
		{"development.enable_cors", "ENABLE_CORS"},
		{"development.cors_allowed_origins", "CORS_ALLOWED_ORIGINS"},
		{"development.cors_allowed_methods", "CORS_ALLOWED_METHODS"},
		{"development.cors_allowed_headers", "CORS_ALLOWED_HEADERS"},
		// Application
		{"app.name", "APP_NAME"},
		{"app.version", "APP_VERSION"},
		{"app.env", "ENV"},
	}

	for _, b := range bindings {
		// viper's BindEnv only fails when called with no arguments. Every call
		// here passes a key and a variable name, so the error is discarded.
		_ = viper.BindEnv(b.key, b.env)
	}
}
// setDefaults registers the fallback value for every known configuration key.
// These values apply only when neither the config file nor an environment
// variable supplies the key.
//
// Each value keeps the dynamic type it is written with (string, int, float64
// or bool). Duration-like settings are given as strings such as "30s".
func setDefaults() {
	defaults := []struct {
		key   string
		value interface{}
	}{
		// Server
		{"server.port", "8080"},
		{"server.mode", "debug"},
		{"server.host", "0.0.0.0"},
		{"server.read_timeout", "30s"},
		{"server.write_timeout", "30s"},
		{"server.idle_timeout", "120s"},
		// Database
		{"database.host", "localhost"},
		{"database.port", "5432"},
		{"database.user", "postgres"},
		{"database.password", "password"},
		{"database.name", "tyapi_db"},
		{"database.sslmode", "disable"},
		{"database.timezone", "Asia/Shanghai"},
		{"database.max_open_conns", 100},
		{"database.max_idle_conns", 10},
		{"database.conn_max_lifetime", "300s"},
		// Redis
		{"redis.host", "localhost"},
		{"redis.port", "6379"},
		{"redis.password", ""},
		{"redis.db", 0},
		{"redis.pool_size", 10},
		{"redis.min_idle_conns", 5},
		{"redis.max_retries", 3},
		{"redis.dial_timeout", "5s"},
		{"redis.read_timeout", "3s"},
		{"redis.write_timeout", "3s"},
		// Cache
		{"cache.default_ttl", "300s"},
		{"cache.cleanup_interval", "600s"},
		{"cache.max_size", 1000},
		// Logger
		{"logger.level", "info"},
		{"logger.format", "json"},
		{"logger.output", "stdout"},
		{"logger.file_path", "logs/app.log"},
		{"logger.max_size", 100},
		{"logger.max_backups", 5},
		{"logger.max_age", 30},
		{"logger.compress", true},
		// JWT: validateConfig rejects this placeholder secret in production.
		{"jwt.secret", "your-super-secret-jwt-key-change-this-in-production"},
		{"jwt.expires_in", "24h"},
		{"jwt.refresh_expires_in", "168h"},
		// Rate limiting
		{"ratelimit.requests", 100},
		{"ratelimit.window", "60s"},
		{"ratelimit.burst", 10},
		// Monitoring / tracing
		{"monitoring.metrics_enabled", true},
		{"monitoring.metrics_port", "9090"},
		{"monitoring.tracing_enabled", false},
		{"monitoring.tracing_endpoint", "http://localhost:14268/api/traces"},
		{"monitoring.sample_rate", 0.1},
		// Health checks
		{"health.enabled", true},
		{"health.interval", "30s"},
		{"health.timeout", "5s"},
		// Resilience (circuit breaker / retry)
		{"resilience.circuit_breaker_enabled", true},
		{"resilience.circuit_breaker_threshold", 5},
		{"resilience.circuit_breaker_timeout", "60s"},
		{"resilience.retry_max_attempts", 3},
		{"resilience.retry_initial_delay", "100ms"},
		{"resilience.retry_max_delay", "2s"},
		// Development
		{"development.debug", true},
		{"development.enable_profiler", false},
		{"development.enable_cors", true},
		{"development.cors_allowed_origins", "*"},
		{"development.cors_allowed_methods", "GET,POST,PUT,DELETE,OPTIONS"},
		{"development.cors_allowed_headers", "*"},
		// Application
		{"app.name", "tyapi-server"},
		{"app.version", "1.0.0"},
		{"app.env", "development"},
	}

	for _, d := range defaults {
		viper.SetDefault(d.key, d.value)
	}
}
// validateConfig checks invariants on the loaded configuration. It returns an
// error for the first violated rule, or nil if every rule holds.
//
// Rules, in the order they are checked:
//   - the database host, user and name are non-empty;
//   - when App.IsProduction() reports true, the JWT secret is neither empty nor
//     the built-in placeholder (weak secrets are tolerated otherwise);
//   - the server read and write timeouts are positive;
//   - the database max open/idle connection counts are positive, and the idle
//     count does not exceed the open count.
func validateConfig(config *Config) error {
	// Placeholder secret shipped as the default; it must never reach production.
	const placeholderJWTSecret = "your-super-secret-jwt-key-change-this-in-production"

	switch {
	// Required database settings.
	case config.Database.Host == "":
		return fmt.Errorf("数据库主机地址不能为空")
	case config.Database.User == "":
		return fmt.Errorf("数据库用户名不能为空")
	case config.Database.Name == "":
		return fmt.Errorf("数据库名称不能为空")

	// IsProduction is consulted only when the secret is weak.
	case (config.JWT.Secret == "" || config.JWT.Secret == placeholderJWTSecret) && config.App.IsProduction():
		return fmt.Errorf("生产环境必须设置安全的JWT密钥")

	// Server timeouts.
	case config.Server.ReadTimeout <= 0:
		return fmt.Errorf("服务器读取超时时间必须大于0")
	case config.Server.WriteTimeout <= 0:
		return fmt.Errorf("服务器写入超时时间必须大于0")

	// Database connection pool.
	case config.Database.MaxOpenConns <= 0:
		return fmt.Errorf("数据库最大连接数必须大于0")
	case config.Database.MaxIdleConns <= 0:
		return fmt.Errorf("数据库最大空闲连接数必须大于0")
	case config.Database.MaxIdleConns > config.Database.MaxOpenConns:
		return fmt.Errorf("数据库最大空闲连接数不能大于最大连接数")

	default:
		return nil
	}
}
// GetEnv returns the value of the environment variable key, or defaultValue
// when that variable is unset or set to the empty string. The two cases are
// deliberately not distinguished.
func GetEnv(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}
// ParseDuration parses s using time.ParseDuration syntax (e.g. "300ms",
// "1h30m"). Any parse error, including an empty string, yields 0. Callers
// therefore cannot tell a malformed value from an explicit "0s".
func ParseDuration(s string) time.Duration {
	if d, err := time.ParseDuration(s); err == nil {
		return d
	}
	return 0
}
// SplitAndTrim splits s on sep, trims surrounding whitespace from each piece
// and drops pieces that end up empty. The result is never nil: if nothing
// survives, it is an empty (non-nil) slice.
func SplitAndTrim(s, sep string) []string {
	pieces := strings.Split(s, sep)
	out := make([]string, 0, len(pieces))
	for i := range pieces {
		piece := strings.TrimSpace(pieces[i])
		if piece == "" {
			continue
		}
		out = append(out, piece)
	}
	return out
}