312 lines
10 KiB
Go
312 lines
10 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/spf13/viper"
|
|
)
|
|
|
|
// LoadConfig 加载应用程序配置
|
|
func LoadConfig() (*Config, error) {
|
|
// 设置配置文件名和路径
|
|
viper.SetConfigName("config")
|
|
viper.SetConfigType("yaml")
|
|
viper.AddConfigPath(".")
|
|
viper.AddConfigPath("./configs")
|
|
viper.AddConfigPath("$HOME/.tyapi")
|
|
|
|
// 设置环境变量前缀
|
|
viper.SetEnvPrefix("")
|
|
viper.AutomaticEnv()
|
|
|
|
// 配置环境变量键名映射
|
|
setupEnvKeyMapping()
|
|
|
|
// 设置默认值
|
|
setDefaults()
|
|
|
|
// 尝试读取配置文件(可选)
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
|
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
|
}
|
|
// 配置文件不存在时使用环境变量和默认值
|
|
}
|
|
|
|
var config Config
|
|
if err := viper.Unmarshal(&config); err != nil {
|
|
return nil, fmt.Errorf("解析配置失败: %w", err)
|
|
}
|
|
|
|
// 验证配置
|
|
if err := validateConfig(&config); err != nil {
|
|
return nil, fmt.Errorf("配置验证失败: %w", err)
|
|
}
|
|
|
|
return &config, nil
|
|
}
|
|
|
|
// setupEnvKeyMapping 设置环境变量到配置键的映射
|
|
func setupEnvKeyMapping() {
|
|
// 服务器配置
|
|
viper.BindEnv("server.port", "SERVER_PORT")
|
|
viper.BindEnv("server.mode", "SERVER_MODE")
|
|
viper.BindEnv("server.host", "SERVER_HOST")
|
|
viper.BindEnv("server.read_timeout", "SERVER_READ_TIMEOUT")
|
|
viper.BindEnv("server.write_timeout", "SERVER_WRITE_TIMEOUT")
|
|
viper.BindEnv("server.idle_timeout", "SERVER_IDLE_TIMEOUT")
|
|
|
|
// 数据库配置
|
|
viper.BindEnv("database.host", "DB_HOST")
|
|
viper.BindEnv("database.port", "DB_PORT")
|
|
viper.BindEnv("database.user", "DB_USER")
|
|
viper.BindEnv("database.password", "DB_PASSWORD")
|
|
viper.BindEnv("database.name", "DB_NAME")
|
|
viper.BindEnv("database.sslmode", "DB_SSLMODE")
|
|
viper.BindEnv("database.timezone", "DB_TIMEZONE")
|
|
viper.BindEnv("database.max_open_conns", "DB_MAX_OPEN_CONNS")
|
|
viper.BindEnv("database.max_idle_conns", "DB_MAX_IDLE_CONNS")
|
|
viper.BindEnv("database.conn_max_lifetime", "DB_CONN_MAX_LIFETIME")
|
|
|
|
// Redis配置
|
|
viper.BindEnv("redis.host", "REDIS_HOST")
|
|
viper.BindEnv("redis.port", "REDIS_PORT")
|
|
viper.BindEnv("redis.password", "REDIS_PASSWORD")
|
|
viper.BindEnv("redis.db", "REDIS_DB")
|
|
viper.BindEnv("redis.pool_size", "REDIS_POOL_SIZE")
|
|
viper.BindEnv("redis.min_idle_conns", "REDIS_MIN_IDLE_CONNS")
|
|
viper.BindEnv("redis.max_retries", "REDIS_MAX_RETRIES")
|
|
viper.BindEnv("redis.dial_timeout", "REDIS_DIAL_TIMEOUT")
|
|
viper.BindEnv("redis.read_timeout", "REDIS_READ_TIMEOUT")
|
|
viper.BindEnv("redis.write_timeout", "REDIS_WRITE_TIMEOUT")
|
|
|
|
// 缓存配置
|
|
viper.BindEnv("cache.default_ttl", "CACHE_DEFAULT_TTL")
|
|
viper.BindEnv("cache.cleanup_interval", "CACHE_CLEANUP_INTERVAL")
|
|
viper.BindEnv("cache.max_size", "CACHE_MAX_SIZE")
|
|
|
|
// 日志配置
|
|
viper.BindEnv("logger.level", "LOG_LEVEL")
|
|
viper.BindEnv("logger.format", "LOG_FORMAT")
|
|
viper.BindEnv("logger.output", "LOG_OUTPUT")
|
|
viper.BindEnv("logger.file_path", "LOG_FILE_PATH")
|
|
viper.BindEnv("logger.max_size", "LOG_MAX_SIZE")
|
|
viper.BindEnv("logger.max_backups", "LOG_MAX_BACKUPS")
|
|
viper.BindEnv("logger.max_age", "LOG_MAX_AGE")
|
|
viper.BindEnv("logger.compress", "LOG_COMPRESS")
|
|
|
|
// JWT配置
|
|
viper.BindEnv("jwt.secret", "JWT_SECRET")
|
|
viper.BindEnv("jwt.expires_in", "JWT_EXPIRES_IN")
|
|
viper.BindEnv("jwt.refresh_expires_in", "JWT_REFRESH_EXPIRES_IN")
|
|
|
|
// 限流配置
|
|
viper.BindEnv("ratelimit.requests", "RATE_LIMIT_REQUESTS")
|
|
viper.BindEnv("ratelimit.window", "RATE_LIMIT_WINDOW")
|
|
viper.BindEnv("ratelimit.burst", "RATE_LIMIT_BURST")
|
|
|
|
// 监控配置
|
|
viper.BindEnv("monitoring.metrics_enabled", "METRICS_ENABLED")
|
|
viper.BindEnv("monitoring.metrics_port", "METRICS_PORT")
|
|
viper.BindEnv("monitoring.tracing_enabled", "TRACING_ENABLED")
|
|
viper.BindEnv("monitoring.tracing_endpoint", "TRACING_ENDPOINT")
|
|
viper.BindEnv("monitoring.sample_rate", "TRACING_SAMPLE_RATE")
|
|
|
|
// 健康检查配置
|
|
viper.BindEnv("health.enabled", "HEALTH_CHECK_ENABLED")
|
|
viper.BindEnv("health.interval", "HEALTH_CHECK_INTERVAL")
|
|
viper.BindEnv("health.timeout", "HEALTH_CHECK_TIMEOUT")
|
|
|
|
// 容错配置
|
|
viper.BindEnv("resilience.circuit_breaker_enabled", "CIRCUIT_BREAKER_ENABLED")
|
|
viper.BindEnv("resilience.circuit_breaker_threshold", "CIRCUIT_BREAKER_THRESHOLD")
|
|
viper.BindEnv("resilience.circuit_breaker_timeout", "CIRCUIT_BREAKER_TIMEOUT")
|
|
viper.BindEnv("resilience.retry_max_attempts", "RETRY_MAX_ATTEMPTS")
|
|
viper.BindEnv("resilience.retry_initial_delay", "RETRY_INITIAL_DELAY")
|
|
viper.BindEnv("resilience.retry_max_delay", "RETRY_MAX_DELAY")
|
|
|
|
// 开发配置
|
|
viper.BindEnv("development.debug", "DEBUG")
|
|
viper.BindEnv("development.enable_profiler", "ENABLE_PROFILER")
|
|
viper.BindEnv("development.enable_cors", "ENABLE_CORS")
|
|
viper.BindEnv("development.cors_allowed_origins", "CORS_ALLOWED_ORIGINS")
|
|
viper.BindEnv("development.cors_allowed_methods", "CORS_ALLOWED_METHODS")
|
|
viper.BindEnv("development.cors_allowed_headers", "CORS_ALLOWED_HEADERS")
|
|
|
|
// 应用程序配置
|
|
viper.BindEnv("app.name", "APP_NAME")
|
|
viper.BindEnv("app.version", "APP_VERSION")
|
|
viper.BindEnv("app.env", "ENV")
|
|
}
|
|
|
|
// setDefaults 设置默认配置值
|
|
func setDefaults() {
|
|
// 服务器默认值
|
|
viper.SetDefault("server.port", "8080")
|
|
viper.SetDefault("server.mode", "debug")
|
|
viper.SetDefault("server.host", "0.0.0.0")
|
|
viper.SetDefault("server.read_timeout", "30s")
|
|
viper.SetDefault("server.write_timeout", "30s")
|
|
viper.SetDefault("server.idle_timeout", "120s")
|
|
|
|
// 数据库默认值
|
|
viper.SetDefault("database.host", "localhost")
|
|
viper.SetDefault("database.port", "5432")
|
|
viper.SetDefault("database.user", "postgres")
|
|
viper.SetDefault("database.password", "password")
|
|
viper.SetDefault("database.name", "tyapi_db")
|
|
viper.SetDefault("database.sslmode", "disable")
|
|
viper.SetDefault("database.timezone", "Asia/Shanghai")
|
|
viper.SetDefault("database.max_open_conns", 100)
|
|
viper.SetDefault("database.max_idle_conns", 10)
|
|
viper.SetDefault("database.conn_max_lifetime", "300s")
|
|
|
|
// Redis默认值
|
|
viper.SetDefault("redis.host", "localhost")
|
|
viper.SetDefault("redis.port", "6379")
|
|
viper.SetDefault("redis.password", "")
|
|
viper.SetDefault("redis.db", 0)
|
|
viper.SetDefault("redis.pool_size", 10)
|
|
viper.SetDefault("redis.min_idle_conns", 5)
|
|
viper.SetDefault("redis.max_retries", 3)
|
|
viper.SetDefault("redis.dial_timeout", "5s")
|
|
viper.SetDefault("redis.read_timeout", "3s")
|
|
viper.SetDefault("redis.write_timeout", "3s")
|
|
|
|
// 缓存默认值
|
|
viper.SetDefault("cache.default_ttl", "300s")
|
|
viper.SetDefault("cache.cleanup_interval", "600s")
|
|
viper.SetDefault("cache.max_size", 1000)
|
|
|
|
// 日志默认值
|
|
viper.SetDefault("logger.level", "info")
|
|
viper.SetDefault("logger.format", "json")
|
|
viper.SetDefault("logger.output", "stdout")
|
|
viper.SetDefault("logger.file_path", "logs/app.log")
|
|
viper.SetDefault("logger.max_size", 100)
|
|
viper.SetDefault("logger.max_backups", 5)
|
|
viper.SetDefault("logger.max_age", 30)
|
|
viper.SetDefault("logger.compress", true)
|
|
|
|
// JWT默认值
|
|
viper.SetDefault("jwt.secret", "your-super-secret-jwt-key-change-this-in-production")
|
|
viper.SetDefault("jwt.expires_in", "24h")
|
|
viper.SetDefault("jwt.refresh_expires_in", "168h")
|
|
|
|
// 限流默认值
|
|
viper.SetDefault("ratelimit.requests", 100)
|
|
viper.SetDefault("ratelimit.window", "60s")
|
|
viper.SetDefault("ratelimit.burst", 10)
|
|
|
|
// 监控默认值
|
|
viper.SetDefault("monitoring.metrics_enabled", true)
|
|
viper.SetDefault("monitoring.metrics_port", "9090")
|
|
viper.SetDefault("monitoring.tracing_enabled", false)
|
|
viper.SetDefault("monitoring.tracing_endpoint", "http://localhost:14268/api/traces")
|
|
viper.SetDefault("monitoring.sample_rate", 0.1)
|
|
|
|
// 健康检查默认值
|
|
viper.SetDefault("health.enabled", true)
|
|
viper.SetDefault("health.interval", "30s")
|
|
viper.SetDefault("health.timeout", "5s")
|
|
|
|
// 容错默认值
|
|
viper.SetDefault("resilience.circuit_breaker_enabled", true)
|
|
viper.SetDefault("resilience.circuit_breaker_threshold", 5)
|
|
viper.SetDefault("resilience.circuit_breaker_timeout", "60s")
|
|
viper.SetDefault("resilience.retry_max_attempts", 3)
|
|
viper.SetDefault("resilience.retry_initial_delay", "100ms")
|
|
viper.SetDefault("resilience.retry_max_delay", "2s")
|
|
|
|
// 开发默认值
|
|
viper.SetDefault("development.debug", true)
|
|
viper.SetDefault("development.enable_profiler", false)
|
|
viper.SetDefault("development.enable_cors", true)
|
|
viper.SetDefault("development.cors_allowed_origins", "*")
|
|
viper.SetDefault("development.cors_allowed_methods", "GET,POST,PUT,DELETE,OPTIONS")
|
|
viper.SetDefault("development.cors_allowed_headers", "*")
|
|
|
|
// 应用程序默认值
|
|
viper.SetDefault("app.name", "tyapi-server")
|
|
viper.SetDefault("app.version", "1.0.0")
|
|
viper.SetDefault("app.env", "development")
|
|
}
|
|
|
|
// validateConfig 验证配置
|
|
func validateConfig(config *Config) error {
|
|
// 验证必要的配置项
|
|
if config.Database.Host == "" {
|
|
return fmt.Errorf("数据库主机地址不能为空")
|
|
}
|
|
|
|
if config.Database.User == "" {
|
|
return fmt.Errorf("数据库用户名不能为空")
|
|
}
|
|
|
|
if config.Database.Name == "" {
|
|
return fmt.Errorf("数据库名称不能为空")
|
|
}
|
|
|
|
if config.JWT.Secret == "" || config.JWT.Secret == "your-super-secret-jwt-key-change-this-in-production" {
|
|
if config.App.IsProduction() {
|
|
return fmt.Errorf("生产环境必须设置安全的JWT密钥")
|
|
}
|
|
}
|
|
|
|
// 验证超时配置
|
|
if config.Server.ReadTimeout <= 0 {
|
|
return fmt.Errorf("服务器读取超时时间必须大于0")
|
|
}
|
|
|
|
if config.Server.WriteTimeout <= 0 {
|
|
return fmt.Errorf("服务器写入超时时间必须大于0")
|
|
}
|
|
|
|
// 验证数据库连接池配置
|
|
if config.Database.MaxOpenConns <= 0 {
|
|
return fmt.Errorf("数据库最大连接数必须大于0")
|
|
}
|
|
|
|
if config.Database.MaxIdleConns <= 0 {
|
|
return fmt.Errorf("数据库最大空闲连接数必须大于0")
|
|
}
|
|
|
|
if config.Database.MaxIdleConns > config.Database.MaxOpenConns {
|
|
return fmt.Errorf("数据库最大空闲连接数不能大于最大连接数")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetEnv returns the value of the environment variable named key.
// When the variable is unset or set to the empty string, defaultValue is
// returned instead.
func GetEnv(key, defaultValue string) string {
	current := os.Getenv(key)
	if current == "" {
		return defaultValue
	}
	return current
}
|
|
|
|
// ParseDuration converts a duration string such as "30s" or "100ms" into a
// time.Duration. Input that time.ParseDuration rejects yields 0; the parse
// error itself is not reported to the caller.
func ParseDuration(s string) time.Duration {
	if parsed, err := time.ParseDuration(s); err == nil {
		return parsed
	}
	return 0
}
|
|
|
|
// SplitAndTrim splits s around sep, strips leading and trailing whitespace
// from each piece, and drops pieces that are empty after trimming. The
// returned slice is never nil, even when no pieces survive.
func SplitAndTrim(s, sep string) []string {
	pieces := strings.Split(s, sep)
	kept := make([]string, 0, len(pieces))
	for i := range pieces {
		piece := strings.TrimSpace(pieces[i])
		if piece == "" {
			continue
		}
		kept = append(kept, piece)
	}
	return kept
}
|