Initial commit: Basic project structure and dependencies
This commit is contained in:
235
internal/app/app.go
Normal file
235
internal/app/app.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/container"
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
)
|
||||
|
||||
// Application 应用程序结构
|
||||
type Application struct {
|
||||
container *container.Container
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewApplication 创建新的应用程序实例
|
||||
func NewApplication() (*Application, error) {
|
||||
// 加载配置
|
||||
cfg, err := config.LoadConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to load config: %w", err)
|
||||
}
|
||||
|
||||
// 创建日志器
|
||||
logger, err := createLogger(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create logger: %w", err)
|
||||
}
|
||||
|
||||
// 创建容器
|
||||
cont := container.NewContainer()
|
||||
|
||||
return &Application{
|
||||
container: cont,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run 运行应用程序
|
||||
func (a *Application) Run() error {
|
||||
// 打印启动信息
|
||||
a.printBanner()
|
||||
|
||||
// 启动容器
|
||||
a.logger.Info("Starting application container...")
|
||||
if err := a.container.Start(); err != nil {
|
||||
a.logger.Error("Failed to start container", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置优雅关闭
|
||||
a.setupGracefulShutdown()
|
||||
|
||||
a.logger.Info("Application started successfully",
|
||||
zap.String("version", a.config.App.Version),
|
||||
zap.String("environment", a.config.App.Env),
|
||||
zap.String("port", a.config.Server.Port))
|
||||
|
||||
// 等待信号
|
||||
return a.waitForShutdown()
|
||||
}
|
||||
|
||||
// RunMigrations 运行数据库迁移
|
||||
func (a *Application) RunMigrations() error {
|
||||
a.logger.Info("Running database migrations...")
|
||||
|
||||
// 创建数据库连接
|
||||
db, err := a.createDatabaseConnection()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database connection: %w", err)
|
||||
}
|
||||
|
||||
// 自动迁移
|
||||
if err := a.autoMigrate(db); err != nil {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
a.logger.Info("Database migrations completed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// printBanner 打印启动横幅
|
||||
func (a *Application) printBanner() {
|
||||
banner := fmt.Sprintf(`
|
||||
╔══════════════════════════════════════════════════════════════╗
|
||||
║ %s ║
|
||||
║ Version: %s ║
|
||||
║ Environment: %s ║
|
||||
║ Port: %s ║
|
||||
╚══════════════════════════════════════════════════════════════╝
|
||||
`,
|
||||
a.config.App.Name,
|
||||
a.config.App.Version,
|
||||
a.config.App.Env,
|
||||
a.config.Server.Port,
|
||||
)
|
||||
|
||||
fmt.Println(banner)
|
||||
}
|
||||
|
||||
// setupGracefulShutdown 设置优雅关闭
|
||||
func (a *Application) setupGracefulShutdown() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
go func() {
|
||||
<-c
|
||||
a.logger.Info("Received shutdown signal, starting graceful shutdown...")
|
||||
|
||||
// 停止容器
|
||||
if err := a.container.Stop(); err != nil {
|
||||
a.logger.Error("Error during container shutdown", zap.Error(err))
|
||||
}
|
||||
|
||||
a.logger.Info("Application shutdown completed")
|
||||
os.Exit(0)
|
||||
}()
|
||||
}
|
||||
|
||||
// waitForShutdown 等待关闭信号
|
||||
func (a *Application) waitForShutdown() error {
|
||||
// 创建一个通道来等待关闭
|
||||
done := make(chan bool, 1)
|
||||
|
||||
// 启动一个协程来监听信号
|
||||
go func() {
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
|
||||
<-c
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// 等待关闭信号
|
||||
<-done
|
||||
return nil
|
||||
}
|
||||
|
||||
// createDatabaseConnection 创建数据库连接
|
||||
func (a *Application) createDatabaseConnection() (*gorm.DB, error) {
|
||||
return container.NewDatabase(a.config, a.logger)
|
||||
}
|
||||
|
||||
// autoMigrate 自动迁移
|
||||
func (a *Application) autoMigrate(db *gorm.DB) error {
|
||||
// 迁移用户相关表
|
||||
return db.AutoMigrate(
|
||||
&entities.User{},
|
||||
// 后续可以添加其他实体
|
||||
)
|
||||
}
|
||||
|
||||
// createLogger 创建日志器
|
||||
func createLogger(cfg *config.Config) (*zap.Logger, error) {
|
||||
level, err := zap.ParseAtomicLevel(cfg.Logger.Level)
|
||||
if err != nil {
|
||||
level = zap.NewAtomicLevelAt(zap.InfoLevel)
|
||||
}
|
||||
|
||||
config := zap.Config{
|
||||
Level: level,
|
||||
Development: cfg.App.IsDevelopment(),
|
||||
Encoding: cfg.Logger.Format,
|
||||
EncoderConfig: zap.NewProductionEncoderConfig(),
|
||||
OutputPaths: []string{cfg.Logger.Output},
|
||||
ErrorOutputPaths: []string{"stderr"},
|
||||
}
|
||||
|
||||
if cfg.Logger.Format == "" {
|
||||
config.Encoding = "json"
|
||||
}
|
||||
if cfg.Logger.Output == "" {
|
||||
config.OutputPaths = []string{"stdout"}
|
||||
}
|
||||
|
||||
return config.Build()
|
||||
}
|
||||
|
||||
// GetConfig 获取配置
|
||||
func (a *Application) GetConfig() *config.Config {
|
||||
return a.config
|
||||
}
|
||||
|
||||
// GetLogger 获取日志器
|
||||
func (a *Application) GetLogger() *zap.Logger {
|
||||
return a.logger
|
||||
}
|
||||
|
||||
// HealthCheck 应用程序健康检查
|
||||
func (a *Application) HealthCheck() error {
|
||||
// 这里可以添加应用程序级别的健康检查逻辑
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetVersion 获取版本信息
|
||||
func (a *Application) GetVersion() map[string]string {
|
||||
return map[string]string{
|
||||
"name": a.config.App.Name,
|
||||
"version": a.config.App.Version,
|
||||
"environment": a.config.App.Env,
|
||||
"go_version": "1.23.4+",
|
||||
}
|
||||
}
|
||||
|
||||
// RunCommand 运行特定命令
|
||||
func (a *Application) RunCommand(command string, args ...string) error {
|
||||
switch command {
|
||||
case "migrate":
|
||||
return a.RunMigrations()
|
||||
case "version":
|
||||
version := a.GetVersion()
|
||||
fmt.Printf("Name: %s\n", version["name"])
|
||||
fmt.Printf("Version: %s\n", version["version"])
|
||||
fmt.Printf("Environment: %s\n", version["environment"])
|
||||
fmt.Printf("Go Version: %s\n", version["go_version"])
|
||||
return nil
|
||||
case "health":
|
||||
if err := a.HealthCheck(); err != nil {
|
||||
fmt.Printf("Health check failed: %v\n", err)
|
||||
return err
|
||||
}
|
||||
fmt.Println("Application is healthy")
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown command: %s", command)
|
||||
}
|
||||
}
|
||||
166
internal/config/config.go
Normal file
166
internal/config/config.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config 应用程序总配置
|
||||
type Config struct {
|
||||
Server ServerConfig `mapstructure:"server"`
|
||||
Database DatabaseConfig `mapstructure:"database"`
|
||||
Redis RedisConfig `mapstructure:"redis"`
|
||||
Cache CacheConfig `mapstructure:"cache"`
|
||||
Logger LoggerConfig `mapstructure:"logger"`
|
||||
JWT JWTConfig `mapstructure:"jwt"`
|
||||
RateLimit RateLimitConfig `mapstructure:"ratelimit"`
|
||||
Monitoring MonitoringConfig `mapstructure:"monitoring"`
|
||||
Health HealthConfig `mapstructure:"health"`
|
||||
Resilience ResilienceConfig `mapstructure:"resilience"`
|
||||
Development DevelopmentConfig `mapstructure:"development"`
|
||||
App AppConfig `mapstructure:"app"`
|
||||
}
|
||||
|
||||
// ServerConfig HTTP服务器配置
|
||||
type ServerConfig struct {
|
||||
Port string `mapstructure:"port"`
|
||||
Mode string `mapstructure:"mode"`
|
||||
Host string `mapstructure:"host"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout"`
|
||||
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port string `mapstructure:"port"`
|
||||
User string `mapstructure:"user"`
|
||||
Password string `mapstructure:"password"`
|
||||
Name string `mapstructure:"name"`
|
||||
SSLMode string `mapstructure:"sslmode"`
|
||||
Timezone string `mapstructure:"timezone"`
|
||||
MaxOpenConns int `mapstructure:"max_open_conns"`
|
||||
MaxIdleConns int `mapstructure:"max_idle_conns"`
|
||||
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
// RedisConfig Redis配置
|
||||
type RedisConfig struct {
|
||||
Host string `mapstructure:"host"`
|
||||
Port string `mapstructure:"port"`
|
||||
Password string `mapstructure:"password"`
|
||||
DB int `mapstructure:"db"`
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
MinIdleConns int `mapstructure:"min_idle_conns"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
DialTimeout time.Duration `mapstructure:"dial_timeout"`
|
||||
ReadTimeout time.Duration `mapstructure:"read_timeout"`
|
||||
WriteTimeout time.Duration `mapstructure:"write_timeout"`
|
||||
}
|
||||
|
||||
// CacheConfig 缓存配置
|
||||
type CacheConfig struct {
|
||||
DefaultTTL time.Duration `mapstructure:"default_ttl"`
|
||||
CleanupInterval time.Duration `mapstructure:"cleanup_interval"`
|
||||
MaxSize int `mapstructure:"max_size"`
|
||||
}
|
||||
|
||||
// LoggerConfig 日志配置
|
||||
type LoggerConfig struct {
|
||||
Level string `mapstructure:"level"`
|
||||
Format string `mapstructure:"format"`
|
||||
Output string `mapstructure:"output"`
|
||||
FilePath string `mapstructure:"file_path"`
|
||||
MaxSize int `mapstructure:"max_size"`
|
||||
MaxBackups int `mapstructure:"max_backups"`
|
||||
MaxAge int `mapstructure:"max_age"`
|
||||
Compress bool `mapstructure:"compress"`
|
||||
}
|
||||
|
||||
// JWTConfig JWT配置
|
||||
type JWTConfig struct {
|
||||
Secret string `mapstructure:"secret"`
|
||||
ExpiresIn time.Duration `mapstructure:"expires_in"`
|
||||
RefreshExpiresIn time.Duration `mapstructure:"refresh_expires_in"`
|
||||
}
|
||||
|
||||
// RateLimitConfig 限流配置
|
||||
type RateLimitConfig struct {
|
||||
Requests int `mapstructure:"requests"`
|
||||
Window time.Duration `mapstructure:"window"`
|
||||
Burst int `mapstructure:"burst"`
|
||||
}
|
||||
|
||||
// MonitoringConfig 监控配置
|
||||
type MonitoringConfig struct {
|
||||
MetricsEnabled bool `mapstructure:"metrics_enabled"`
|
||||
MetricsPort string `mapstructure:"metrics_port"`
|
||||
TracingEnabled bool `mapstructure:"tracing_enabled"`
|
||||
TracingEndpoint string `mapstructure:"tracing_endpoint"`
|
||||
SampleRate float64 `mapstructure:"sample_rate"`
|
||||
}
|
||||
|
||||
// HealthConfig 健康检查配置
|
||||
type HealthConfig struct {
|
||||
Enabled bool `mapstructure:"enabled"`
|
||||
Interval time.Duration `mapstructure:"interval"`
|
||||
Timeout time.Duration `mapstructure:"timeout"`
|
||||
}
|
||||
|
||||
// ResilienceConfig 容错配置
|
||||
type ResilienceConfig struct {
|
||||
CircuitBreakerEnabled bool `mapstructure:"circuit_breaker_enabled"`
|
||||
CircuitBreakerThreshold int `mapstructure:"circuit_breaker_threshold"`
|
||||
CircuitBreakerTimeout time.Duration `mapstructure:"circuit_breaker_timeout"`
|
||||
RetryMaxAttempts int `mapstructure:"retry_max_attempts"`
|
||||
RetryInitialDelay time.Duration `mapstructure:"retry_initial_delay"`
|
||||
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
|
||||
}
|
||||
|
||||
// DevelopmentConfig 开发配置
|
||||
type DevelopmentConfig struct {
|
||||
Debug bool `mapstructure:"debug"`
|
||||
EnableProfiler bool `mapstructure:"enable_profiler"`
|
||||
EnableCors bool `mapstructure:"enable_cors"`
|
||||
CorsOrigins string `mapstructure:"cors_allowed_origins"`
|
||||
CorsMethods string `mapstructure:"cors_allowed_methods"`
|
||||
CorsHeaders string `mapstructure:"cors_allowed_headers"`
|
||||
}
|
||||
|
||||
// AppConfig 应用程序配置
|
||||
type AppConfig struct {
|
||||
Name string `mapstructure:"name"`
|
||||
Version string `mapstructure:"version"`
|
||||
Env string `mapstructure:"env"`
|
||||
}
|
||||
|
||||
// GetDSN 获取数据库DSN连接字符串
|
||||
func (d DatabaseConfig) GetDSN() string {
|
||||
return "host=" + d.Host +
|
||||
" user=" + d.User +
|
||||
" password=" + d.Password +
|
||||
" dbname=" + d.Name +
|
||||
" port=" + d.Port +
|
||||
" sslmode=" + d.SSLMode +
|
||||
" TimeZone=" + d.Timezone
|
||||
}
|
||||
|
||||
// GetRedisAddr 获取Redis地址
|
||||
func (r RedisConfig) GetRedisAddr() string {
|
||||
return r.Host + ":" + r.Port
|
||||
}
|
||||
|
||||
// IsProduction 检查是否为生产环境
|
||||
func (a AppConfig) IsProduction() bool {
|
||||
return a.Env == "production"
|
||||
}
|
||||
|
||||
// IsDevelopment 检查是否为开发环境
|
||||
func (a AppConfig) IsDevelopment() bool {
|
||||
return a.Env == "development"
|
||||
}
|
||||
|
||||
// IsStaging 检查是否为测试环境
|
||||
func (a AppConfig) IsStaging() bool {
|
||||
return a.Env == "staging"
|
||||
}
|
||||
311
internal/config/loader.go
Normal file
311
internal/config/loader.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// LoadConfig 加载应用程序配置
|
||||
func LoadConfig() (*Config, error) {
|
||||
// 设置配置文件名和路径
|
||||
viper.SetConfigName("config")
|
||||
viper.SetConfigType("yaml")
|
||||
viper.AddConfigPath(".")
|
||||
viper.AddConfigPath("./configs")
|
||||
viper.AddConfigPath("$HOME/.tyapi")
|
||||
|
||||
// 设置环境变量前缀
|
||||
viper.SetEnvPrefix("")
|
||||
viper.AutomaticEnv()
|
||||
|
||||
// 配置环境变量键名映射
|
||||
setupEnvKeyMapping()
|
||||
|
||||
// 设置默认值
|
||||
setDefaults()
|
||||
|
||||
// 尝试读取配置文件(可选)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
// 配置文件不存在时使用环境变量和默认值
|
||||
}
|
||||
|
||||
var config Config
|
||||
if err := viper.Unmarshal(&config); err != nil {
|
||||
return nil, fmt.Errorf("解析配置失败: %w", err)
|
||||
}
|
||||
|
||||
// 验证配置
|
||||
if err := validateConfig(&config); err != nil {
|
||||
return nil, fmt.Errorf("配置验证失败: %w", err)
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
// setupEnvKeyMapping 设置环境变量到配置键的映射
|
||||
func setupEnvKeyMapping() {
|
||||
// 服务器配置
|
||||
viper.BindEnv("server.port", "SERVER_PORT")
|
||||
viper.BindEnv("server.mode", "SERVER_MODE")
|
||||
viper.BindEnv("server.host", "SERVER_HOST")
|
||||
viper.BindEnv("server.read_timeout", "SERVER_READ_TIMEOUT")
|
||||
viper.BindEnv("server.write_timeout", "SERVER_WRITE_TIMEOUT")
|
||||
viper.BindEnv("server.idle_timeout", "SERVER_IDLE_TIMEOUT")
|
||||
|
||||
// 数据库配置
|
||||
viper.BindEnv("database.host", "DB_HOST")
|
||||
viper.BindEnv("database.port", "DB_PORT")
|
||||
viper.BindEnv("database.user", "DB_USER")
|
||||
viper.BindEnv("database.password", "DB_PASSWORD")
|
||||
viper.BindEnv("database.name", "DB_NAME")
|
||||
viper.BindEnv("database.sslmode", "DB_SSLMODE")
|
||||
viper.BindEnv("database.timezone", "DB_TIMEZONE")
|
||||
viper.BindEnv("database.max_open_conns", "DB_MAX_OPEN_CONNS")
|
||||
viper.BindEnv("database.max_idle_conns", "DB_MAX_IDLE_CONNS")
|
||||
viper.BindEnv("database.conn_max_lifetime", "DB_CONN_MAX_LIFETIME")
|
||||
|
||||
// Redis配置
|
||||
viper.BindEnv("redis.host", "REDIS_HOST")
|
||||
viper.BindEnv("redis.port", "REDIS_PORT")
|
||||
viper.BindEnv("redis.password", "REDIS_PASSWORD")
|
||||
viper.BindEnv("redis.db", "REDIS_DB")
|
||||
viper.BindEnv("redis.pool_size", "REDIS_POOL_SIZE")
|
||||
viper.BindEnv("redis.min_idle_conns", "REDIS_MIN_IDLE_CONNS")
|
||||
viper.BindEnv("redis.max_retries", "REDIS_MAX_RETRIES")
|
||||
viper.BindEnv("redis.dial_timeout", "REDIS_DIAL_TIMEOUT")
|
||||
viper.BindEnv("redis.read_timeout", "REDIS_READ_TIMEOUT")
|
||||
viper.BindEnv("redis.write_timeout", "REDIS_WRITE_TIMEOUT")
|
||||
|
||||
// 缓存配置
|
||||
viper.BindEnv("cache.default_ttl", "CACHE_DEFAULT_TTL")
|
||||
viper.BindEnv("cache.cleanup_interval", "CACHE_CLEANUP_INTERVAL")
|
||||
viper.BindEnv("cache.max_size", "CACHE_MAX_SIZE")
|
||||
|
||||
// 日志配置
|
||||
viper.BindEnv("logger.level", "LOG_LEVEL")
|
||||
viper.BindEnv("logger.format", "LOG_FORMAT")
|
||||
viper.BindEnv("logger.output", "LOG_OUTPUT")
|
||||
viper.BindEnv("logger.file_path", "LOG_FILE_PATH")
|
||||
viper.BindEnv("logger.max_size", "LOG_MAX_SIZE")
|
||||
viper.BindEnv("logger.max_backups", "LOG_MAX_BACKUPS")
|
||||
viper.BindEnv("logger.max_age", "LOG_MAX_AGE")
|
||||
viper.BindEnv("logger.compress", "LOG_COMPRESS")
|
||||
|
||||
// JWT配置
|
||||
viper.BindEnv("jwt.secret", "JWT_SECRET")
|
||||
viper.BindEnv("jwt.expires_in", "JWT_EXPIRES_IN")
|
||||
viper.BindEnv("jwt.refresh_expires_in", "JWT_REFRESH_EXPIRES_IN")
|
||||
|
||||
// 限流配置
|
||||
viper.BindEnv("ratelimit.requests", "RATE_LIMIT_REQUESTS")
|
||||
viper.BindEnv("ratelimit.window", "RATE_LIMIT_WINDOW")
|
||||
viper.BindEnv("ratelimit.burst", "RATE_LIMIT_BURST")
|
||||
|
||||
// 监控配置
|
||||
viper.BindEnv("monitoring.metrics_enabled", "METRICS_ENABLED")
|
||||
viper.BindEnv("monitoring.metrics_port", "METRICS_PORT")
|
||||
viper.BindEnv("monitoring.tracing_enabled", "TRACING_ENABLED")
|
||||
viper.BindEnv("monitoring.tracing_endpoint", "TRACING_ENDPOINT")
|
||||
viper.BindEnv("monitoring.sample_rate", "TRACING_SAMPLE_RATE")
|
||||
|
||||
// 健康检查配置
|
||||
viper.BindEnv("health.enabled", "HEALTH_CHECK_ENABLED")
|
||||
viper.BindEnv("health.interval", "HEALTH_CHECK_INTERVAL")
|
||||
viper.BindEnv("health.timeout", "HEALTH_CHECK_TIMEOUT")
|
||||
|
||||
// 容错配置
|
||||
viper.BindEnv("resilience.circuit_breaker_enabled", "CIRCUIT_BREAKER_ENABLED")
|
||||
viper.BindEnv("resilience.circuit_breaker_threshold", "CIRCUIT_BREAKER_THRESHOLD")
|
||||
viper.BindEnv("resilience.circuit_breaker_timeout", "CIRCUIT_BREAKER_TIMEOUT")
|
||||
viper.BindEnv("resilience.retry_max_attempts", "RETRY_MAX_ATTEMPTS")
|
||||
viper.BindEnv("resilience.retry_initial_delay", "RETRY_INITIAL_DELAY")
|
||||
viper.BindEnv("resilience.retry_max_delay", "RETRY_MAX_DELAY")
|
||||
|
||||
// 开发配置
|
||||
viper.BindEnv("development.debug", "DEBUG")
|
||||
viper.BindEnv("development.enable_profiler", "ENABLE_PROFILER")
|
||||
viper.BindEnv("development.enable_cors", "ENABLE_CORS")
|
||||
viper.BindEnv("development.cors_allowed_origins", "CORS_ALLOWED_ORIGINS")
|
||||
viper.BindEnv("development.cors_allowed_methods", "CORS_ALLOWED_METHODS")
|
||||
viper.BindEnv("development.cors_allowed_headers", "CORS_ALLOWED_HEADERS")
|
||||
|
||||
// 应用程序配置
|
||||
viper.BindEnv("app.name", "APP_NAME")
|
||||
viper.BindEnv("app.version", "APP_VERSION")
|
||||
viper.BindEnv("app.env", "ENV")
|
||||
}
|
||||
|
||||
// setDefaults 设置默认配置值
|
||||
func setDefaults() {
|
||||
// 服务器默认值
|
||||
viper.SetDefault("server.port", "8080")
|
||||
viper.SetDefault("server.mode", "debug")
|
||||
viper.SetDefault("server.host", "0.0.0.0")
|
||||
viper.SetDefault("server.read_timeout", "30s")
|
||||
viper.SetDefault("server.write_timeout", "30s")
|
||||
viper.SetDefault("server.idle_timeout", "120s")
|
||||
|
||||
// 数据库默认值
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", "5432")
|
||||
viper.SetDefault("database.user", "postgres")
|
||||
viper.SetDefault("database.password", "password")
|
||||
viper.SetDefault("database.name", "tyapi_db")
|
||||
viper.SetDefault("database.sslmode", "disable")
|
||||
viper.SetDefault("database.timezone", "Asia/Shanghai")
|
||||
viper.SetDefault("database.max_open_conns", 100)
|
||||
viper.SetDefault("database.max_idle_conns", 10)
|
||||
viper.SetDefault("database.conn_max_lifetime", "300s")
|
||||
|
||||
// Redis默认值
|
||||
viper.SetDefault("redis.host", "localhost")
|
||||
viper.SetDefault("redis.port", "6379")
|
||||
viper.SetDefault("redis.password", "")
|
||||
viper.SetDefault("redis.db", 0)
|
||||
viper.SetDefault("redis.pool_size", 10)
|
||||
viper.SetDefault("redis.min_idle_conns", 5)
|
||||
viper.SetDefault("redis.max_retries", 3)
|
||||
viper.SetDefault("redis.dial_timeout", "5s")
|
||||
viper.SetDefault("redis.read_timeout", "3s")
|
||||
viper.SetDefault("redis.write_timeout", "3s")
|
||||
|
||||
// 缓存默认值
|
||||
viper.SetDefault("cache.default_ttl", "300s")
|
||||
viper.SetDefault("cache.cleanup_interval", "600s")
|
||||
viper.SetDefault("cache.max_size", 1000)
|
||||
|
||||
// 日志默认值
|
||||
viper.SetDefault("logger.level", "info")
|
||||
viper.SetDefault("logger.format", "json")
|
||||
viper.SetDefault("logger.output", "stdout")
|
||||
viper.SetDefault("logger.file_path", "logs/app.log")
|
||||
viper.SetDefault("logger.max_size", 100)
|
||||
viper.SetDefault("logger.max_backups", 5)
|
||||
viper.SetDefault("logger.max_age", 30)
|
||||
viper.SetDefault("logger.compress", true)
|
||||
|
||||
// JWT默认值
|
||||
viper.SetDefault("jwt.secret", "your-super-secret-jwt-key-change-this-in-production")
|
||||
viper.SetDefault("jwt.expires_in", "24h")
|
||||
viper.SetDefault("jwt.refresh_expires_in", "168h")
|
||||
|
||||
// 限流默认值
|
||||
viper.SetDefault("ratelimit.requests", 100)
|
||||
viper.SetDefault("ratelimit.window", "60s")
|
||||
viper.SetDefault("ratelimit.burst", 10)
|
||||
|
||||
// 监控默认值
|
||||
viper.SetDefault("monitoring.metrics_enabled", true)
|
||||
viper.SetDefault("monitoring.metrics_port", "9090")
|
||||
viper.SetDefault("monitoring.tracing_enabled", false)
|
||||
viper.SetDefault("monitoring.tracing_endpoint", "http://localhost:14268/api/traces")
|
||||
viper.SetDefault("monitoring.sample_rate", 0.1)
|
||||
|
||||
// 健康检查默认值
|
||||
viper.SetDefault("health.enabled", true)
|
||||
viper.SetDefault("health.interval", "30s")
|
||||
viper.SetDefault("health.timeout", "5s")
|
||||
|
||||
// 容错默认值
|
||||
viper.SetDefault("resilience.circuit_breaker_enabled", true)
|
||||
viper.SetDefault("resilience.circuit_breaker_threshold", 5)
|
||||
viper.SetDefault("resilience.circuit_breaker_timeout", "60s")
|
||||
viper.SetDefault("resilience.retry_max_attempts", 3)
|
||||
viper.SetDefault("resilience.retry_initial_delay", "100ms")
|
||||
viper.SetDefault("resilience.retry_max_delay", "2s")
|
||||
|
||||
// 开发默认值
|
||||
viper.SetDefault("development.debug", true)
|
||||
viper.SetDefault("development.enable_profiler", false)
|
||||
viper.SetDefault("development.enable_cors", true)
|
||||
viper.SetDefault("development.cors_allowed_origins", "*")
|
||||
viper.SetDefault("development.cors_allowed_methods", "GET,POST,PUT,DELETE,OPTIONS")
|
||||
viper.SetDefault("development.cors_allowed_headers", "*")
|
||||
|
||||
// 应用程序默认值
|
||||
viper.SetDefault("app.name", "tyapi-server")
|
||||
viper.SetDefault("app.version", "1.0.0")
|
||||
viper.SetDefault("app.env", "development")
|
||||
}
|
||||
|
||||
// validateConfig 验证配置
|
||||
func validateConfig(config *Config) error {
|
||||
// 验证必要的配置项
|
||||
if config.Database.Host == "" {
|
||||
return fmt.Errorf("数据库主机地址不能为空")
|
||||
}
|
||||
|
||||
if config.Database.User == "" {
|
||||
return fmt.Errorf("数据库用户名不能为空")
|
||||
}
|
||||
|
||||
if config.Database.Name == "" {
|
||||
return fmt.Errorf("数据库名称不能为空")
|
||||
}
|
||||
|
||||
if config.JWT.Secret == "" || config.JWT.Secret == "your-super-secret-jwt-key-change-this-in-production" {
|
||||
if config.App.IsProduction() {
|
||||
return fmt.Errorf("生产环境必须设置安全的JWT密钥")
|
||||
}
|
||||
}
|
||||
|
||||
// 验证超时配置
|
||||
if config.Server.ReadTimeout <= 0 {
|
||||
return fmt.Errorf("服务器读取超时时间必须大于0")
|
||||
}
|
||||
|
||||
if config.Server.WriteTimeout <= 0 {
|
||||
return fmt.Errorf("服务器写入超时时间必须大于0")
|
||||
}
|
||||
|
||||
// 验证数据库连接池配置
|
||||
if config.Database.MaxOpenConns <= 0 {
|
||||
return fmt.Errorf("数据库最大连接数必须大于0")
|
||||
}
|
||||
|
||||
if config.Database.MaxIdleConns <= 0 {
|
||||
return fmt.Errorf("数据库最大空闲连接数必须大于0")
|
||||
}
|
||||
|
||||
if config.Database.MaxIdleConns > config.Database.MaxOpenConns {
|
||||
return fmt.Errorf("数据库最大空闲连接数不能大于最大连接数")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEnv 获取环境变量,如果不存在则返回默认值
|
||||
func GetEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// ParseDuration 解析时间字符串
|
||||
func ParseDuration(s string) time.Duration {
|
||||
d, err := time.ParseDuration(s)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// SplitAndTrim 分割字符串并去除空格
|
||||
func SplitAndTrim(s, sep string) []string {
|
||||
parts := strings.Split(s, sep)
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if trimmed := strings.TrimSpace(part); trimmed != "" {
|
||||
result = append(result, trimmed)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
441
internal/container/container.go
Normal file
441
internal/container/container.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package container
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/fx"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/domains/user/handlers"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/domains/user/routes"
|
||||
"tyapi-server/internal/domains/user/services"
|
||||
"tyapi-server/internal/shared/cache"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/events"
|
||||
"tyapi-server/internal/shared/health"
|
||||
"tyapi-server/internal/shared/http"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
)
|
||||
|
||||
// Container 应用容器
|
||||
type Container struct {
|
||||
App *fx.App
|
||||
}
|
||||
|
||||
// NewContainer 创建新的应用容器
|
||||
func NewContainer() *Container {
|
||||
app := fx.New(
|
||||
// 配置模块
|
||||
fx.Provide(
|
||||
config.LoadConfig,
|
||||
),
|
||||
|
||||
// 基础设施模块
|
||||
fx.Provide(
|
||||
NewLogger,
|
||||
NewDatabase,
|
||||
NewRedisClient,
|
||||
NewRedisCache,
|
||||
NewEventBus,
|
||||
NewHealthChecker,
|
||||
),
|
||||
|
||||
// HTTP基础组件
|
||||
fx.Provide(
|
||||
NewResponseBuilder,
|
||||
NewRequestValidator,
|
||||
NewGinRouter,
|
||||
),
|
||||
|
||||
// 中间件组件
|
||||
fx.Provide(
|
||||
NewRequestIDMiddleware,
|
||||
NewSecurityHeadersMiddleware,
|
||||
NewResponseTimeMiddleware,
|
||||
NewCORSMiddleware,
|
||||
NewRateLimitMiddleware,
|
||||
NewRequestLoggerMiddleware,
|
||||
NewJWTAuthMiddleware,
|
||||
NewOptionalAuthMiddleware,
|
||||
),
|
||||
|
||||
// 用户域组件
|
||||
fx.Provide(
|
||||
NewUserRepository,
|
||||
NewUserService,
|
||||
NewUserHandler,
|
||||
NewUserRoutes,
|
||||
),
|
||||
|
||||
// 应用生命周期
|
||||
fx.Invoke(
|
||||
RegisterLifecycleHooks,
|
||||
RegisterMiddlewares,
|
||||
RegisterRoutes,
|
||||
),
|
||||
)
|
||||
|
||||
return &Container{App: app}
|
||||
}
|
||||
|
||||
// Start 启动容器
|
||||
func (c *Container) Start() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return c.App.Start(ctx)
|
||||
}
|
||||
|
||||
// Stop 停止容器
|
||||
func (c *Container) Stop() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
return c.App.Stop(ctx)
|
||||
}
|
||||
|
||||
// 基础设施构造函数
|
||||
|
||||
// NewLogger 创建日志器
|
||||
func NewLogger(cfg *config.Config) (*zap.Logger, error) {
|
||||
level, err := zap.ParseAtomicLevel(cfg.Logger.Level)
|
||||
if err != nil {
|
||||
level = zap.NewAtomicLevelAt(zap.InfoLevel)
|
||||
}
|
||||
|
||||
config := zap.Config{
|
||||
Level: level,
|
||||
Development: cfg.App.IsDevelopment(),
|
||||
Encoding: cfg.Logger.Format,
|
||||
EncoderConfig: zap.NewProductionEncoderConfig(),
|
||||
OutputPaths: []string{cfg.Logger.Output},
|
||||
ErrorOutputPaths: []string{"stderr"},
|
||||
}
|
||||
|
||||
if cfg.Logger.Format == "" {
|
||||
config.Encoding = "json"
|
||||
}
|
||||
if cfg.Logger.Output == "" {
|
||||
config.OutputPaths = []string{"stdout"}
|
||||
}
|
||||
|
||||
return config.Build()
|
||||
}
|
||||
|
||||
// NewDatabase 创建数据库连接
|
||||
func NewDatabase(cfg *config.Config, logger *zap.Logger) (*gorm.DB, error) {
|
||||
dbConfig := database.Config{
|
||||
Host: cfg.Database.Host,
|
||||
Port: cfg.Database.Port,
|
||||
User: cfg.Database.User,
|
||||
Password: cfg.Database.Password,
|
||||
Name: cfg.Database.Name,
|
||||
SSLMode: cfg.Database.SSLMode,
|
||||
Timezone: cfg.Database.Timezone,
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: cfg.Database.ConnMaxLifetime,
|
||||
}
|
||||
|
||||
db, err := database.NewConnection(dbConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db.DB, nil
|
||||
}
|
||||
|
||||
// NewRedisClient 创建Redis客户端
|
||||
func NewRedisClient(cfg *config.Config, logger *zap.Logger) (*redis.Client, error) {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Redis.GetRedisAddr(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
PoolSize: cfg.Redis.PoolSize,
|
||||
MinIdleConns: cfg.Redis.MinIdleConns,
|
||||
DialTimeout: cfg.Redis.DialTimeout,
|
||||
ReadTimeout: cfg.Redis.ReadTimeout,
|
||||
WriteTimeout: cfg.Redis.WriteTimeout,
|
||||
})
|
||||
|
||||
// 测试连接
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
_, err := client.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
logger.Error("Failed to connect to Redis", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("Redis connection established")
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// NewRedisCache 创建Redis缓存服务
|
||||
func NewRedisCache(client *redis.Client, logger *zap.Logger, cfg *config.Config) interfaces.CacheService {
|
||||
return cache.NewRedisCache(client, logger, "app")
|
||||
}
|
||||
|
||||
// NewEventBus 创建事件总线
|
||||
func NewEventBus(logger *zap.Logger, cfg *config.Config) interfaces.EventBus {
|
||||
return events.NewMemoryEventBus(logger, 5) // 默认5个工作协程
|
||||
}
|
||||
|
||||
// NewHealthChecker 创建健康检查器
|
||||
func NewHealthChecker(logger *zap.Logger) *health.HealthChecker {
|
||||
return health.NewHealthChecker(logger)
|
||||
}
|
||||
|
||||
// HTTP组件构造函数
|
||||
|
||||
// NewResponseBuilder 创建响应构建器
|
||||
func NewResponseBuilder() interfaces.ResponseBuilder {
|
||||
return http.NewResponseBuilder()
|
||||
}
|
||||
|
||||
// NewRequestValidator 创建请求验证器
|
||||
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
|
||||
return http.NewRequestValidator(response)
|
||||
}
|
||||
|
||||
// NewGinRouter 创建Gin路由器
|
||||
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *http.GinRouter {
|
||||
return http.NewGinRouter(cfg, logger)
|
||||
}
|
||||
|
||||
// 中间件构造函数
|
||||
|
||||
// NewRequestIDMiddleware 创建请求ID中间件
|
||||
func NewRequestIDMiddleware() *middleware.RequestIDMiddleware {
|
||||
return middleware.NewRequestIDMiddleware()
|
||||
}
|
||||
|
||||
// NewSecurityHeadersMiddleware 创建安全头部中间件
|
||||
func NewSecurityHeadersMiddleware() *middleware.SecurityHeadersMiddleware {
|
||||
return middleware.NewSecurityHeadersMiddleware()
|
||||
}
|
||||
|
||||
// NewResponseTimeMiddleware 创建响应时间中间件
|
||||
func NewResponseTimeMiddleware() *middleware.ResponseTimeMiddleware {
|
||||
return middleware.NewResponseTimeMiddleware()
|
||||
}
|
||||
|
||||
// NewCORSMiddleware 创建CORS中间件
|
||||
func NewCORSMiddleware(cfg *config.Config) *middleware.CORSMiddleware {
|
||||
return middleware.NewCORSMiddleware(cfg)
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config) *middleware.RateLimitMiddleware {
|
||||
return middleware.NewRateLimitMiddleware(cfg)
|
||||
}
|
||||
|
||||
// NewRequestLoggerMiddleware 创建请求日志中间件
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger) *middleware.RequestLoggerMiddleware {
|
||||
return middleware.NewRequestLoggerMiddleware(logger)
|
||||
}
|
||||
|
||||
// NewJWTAuthMiddleware 创建JWT认证中间件
|
||||
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *middleware.JWTAuthMiddleware {
|
||||
return middleware.NewJWTAuthMiddleware(cfg, logger)
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware 创建可选认证中间件
|
||||
func NewOptionalAuthMiddleware(jwtAuth *middleware.JWTAuthMiddleware) *middleware.OptionalAuthMiddleware {
|
||||
return middleware.NewOptionalAuthMiddleware(jwtAuth)
|
||||
}
|
||||
|
||||
// 用户域构造函数
|
||||
|
||||
// NewUserRepository 创建用户仓储
|
||||
func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *repositories.UserRepository {
|
||||
return repositories.NewUserRepository(db, cache, logger)
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务
|
||||
func NewUserService(
|
||||
repo *repositories.UserRepository,
|
||||
eventBus interfaces.EventBus,
|
||||
logger *zap.Logger,
|
||||
) *services.UserService {
|
||||
return services.NewUserService(repo, eventBus, logger)
|
||||
}
|
||||
|
||||
// NewUserHandler 创建用户处理器
|
||||
func NewUserHandler(
|
||||
userService *services.UserService,
|
||||
response interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
jwtAuth *middleware.JWTAuthMiddleware,
|
||||
) *handlers.UserHandler {
|
||||
return handlers.NewUserHandler(userService, response, validator, logger, jwtAuth)
|
||||
}
|
||||
|
||||
// NewUserRoutes 创建用户路由
|
||||
func NewUserRoutes(
|
||||
handler *handlers.UserHandler,
|
||||
jwtAuth *middleware.JWTAuthMiddleware,
|
||||
optionalAuth *middleware.OptionalAuthMiddleware,
|
||||
) *routes.UserRoutes {
|
||||
return routes.NewUserRoutes(handler, jwtAuth, optionalAuth)
|
||||
}
|
||||
|
||||
// 注册函数
|
||||
|
||||
// RegisterMiddlewares 注册中间件
|
||||
func RegisterMiddlewares(
|
||||
router *http.GinRouter,
|
||||
requestID *middleware.RequestIDMiddleware,
|
||||
security *middleware.SecurityHeadersMiddleware,
|
||||
responseTime *middleware.ResponseTimeMiddleware,
|
||||
cors *middleware.CORSMiddleware,
|
||||
rateLimit *middleware.RateLimitMiddleware,
|
||||
requestLogger *middleware.RequestLoggerMiddleware,
|
||||
) {
|
||||
// 注册全局中间件
|
||||
router.RegisterMiddleware(requestID)
|
||||
router.RegisterMiddleware(security)
|
||||
router.RegisterMiddleware(responseTime)
|
||||
router.RegisterMiddleware(cors)
|
||||
router.RegisterMiddleware(rateLimit)
|
||||
router.RegisterMiddleware(requestLogger)
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册路由
|
||||
func RegisterRoutes(
|
||||
router *http.GinRouter,
|
||||
userRoutes *routes.UserRoutes,
|
||||
) {
|
||||
// 设置默认路由
|
||||
router.SetupDefaultRoutes()
|
||||
|
||||
// 注册用户路由
|
||||
userRoutes.RegisterRoutes(router.GetEngine())
|
||||
userRoutes.RegisterPublicRoutes(router.GetEngine())
|
||||
userRoutes.RegisterAdminRoutes(router.GetEngine())
|
||||
userRoutes.RegisterHealthRoutes(router.GetEngine())
|
||||
|
||||
// 打印路由信息
|
||||
router.PrintRoutes()
|
||||
}
|
||||
|
||||
// 生命周期钩子
|
||||
|
||||
// RegisterLifecycleHooks 注册生命周期钩子
|
||||
func RegisterLifecycleHooks(
|
||||
lc fx.Lifecycle,
|
||||
logger *zap.Logger,
|
||||
cfg *config.Config,
|
||||
db *gorm.DB,
|
||||
cache interfaces.CacheService,
|
||||
eventBus interfaces.EventBus,
|
||||
healthChecker *health.HealthChecker,
|
||||
router *http.GinRouter,
|
||||
userService *services.UserService,
|
||||
) {
|
||||
lc.Append(fx.Hook{
|
||||
OnStart: func(ctx context.Context) error {
|
||||
logger.Info("Starting application services...")
|
||||
|
||||
// 注册服务到健康检查器
|
||||
healthChecker.RegisterService(userService)
|
||||
|
||||
// 初始化缓存服务
|
||||
if err := cache.Initialize(ctx); err != nil {
|
||||
logger.Error("Failed to initialize cache", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 启动事件总线
|
||||
if err := eventBus.Start(ctx); err != nil {
|
||||
logger.Error("Failed to start event bus", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 启动健康检查(如果启用)
|
||||
if cfg.Health.Enabled {
|
||||
go healthChecker.StartPeriodicCheck(ctx, cfg.Health.Interval)
|
||||
}
|
||||
|
||||
// 启动HTTP服务器
|
||||
go func() {
|
||||
addr := fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
|
||||
if err := router.Start(addr); err != nil {
|
||||
logger.Error("Failed to start HTTP server", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Info("All services started successfully")
|
||||
return nil
|
||||
},
|
||||
OnStop: func(ctx context.Context) error {
|
||||
logger.Info("Stopping application services...")
|
||||
|
||||
// 停止HTTP服务器
|
||||
if err := router.Stop(ctx); err != nil {
|
||||
logger.Error("Failed to stop HTTP server", zap.Error(err))
|
||||
}
|
||||
|
||||
// 停止事件总线
|
||||
if err := eventBus.Stop(ctx); err != nil {
|
||||
logger.Error("Failed to stop event bus", zap.Error(err))
|
||||
}
|
||||
|
||||
// 关闭缓存服务
|
||||
if err := cache.Shutdown(ctx); err != nil {
|
||||
logger.Error("Failed to shutdown cache service", zap.Error(err))
|
||||
}
|
||||
|
||||
// 关闭数据库连接
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
if err := sqlDB.Close(); err != nil {
|
||||
logger.Error("Failed to close database", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("All services stopped")
|
||||
return nil
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ServiceRegistrar 服务注册器接口
|
||||
type ServiceRegistrar interface {
|
||||
RegisterServices() fx.Option
|
||||
}
|
||||
|
||||
// DomainModule 领域模块接口
|
||||
type DomainModule interface {
|
||||
ServiceRegistrar
|
||||
GetName() string
|
||||
GetDependencies() []string
|
||||
}
|
||||
|
||||
// RegisterDomainModule 注册领域模块
|
||||
func RegisterDomainModule(module DomainModule) fx.Option {
|
||||
return fx.Options(
|
||||
fx.Provide(
|
||||
fx.Annotated{
|
||||
Name: module.GetName(),
|
||||
Target: func() DomainModule {
|
||||
return module
|
||||
},
|
||||
},
|
||||
),
|
||||
module.RegisterServices(),
|
||||
)
|
||||
}
|
||||
|
||||
// GetContainer 获取容器实例(用于测试或特殊情况)
|
||||
func GetContainer(cfg *config.Config) *Container {
|
||||
return NewContainer()
|
||||
}
|
||||
173
internal/domains/user/dto/user_dto.go
Normal file
173
internal/domains/user/dto/user_dto.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
)
|
||||
|
||||
// CreateUserRequest 创建用户请求
|
||||
type CreateUserRequest struct {
|
||||
Username string `json:"username" binding:"required,min=3,max=50" example:"john_doe"`
|
||||
Email string `json:"email" binding:"required,email" example:"john@example.com"`
|
||||
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
|
||||
FirstName string `json:"first_name" binding:"max=50" example:"John"`
|
||||
LastName string `json:"last_name" binding:"max=50" example:"Doe"`
|
||||
Phone string `json:"phone" binding:"omitempty,max=20" example:"+86-13800138000"`
|
||||
}
|
||||
|
||||
// UpdateUserRequest 更新用户请求
|
||||
type UpdateUserRequest struct {
|
||||
FirstName *string `json:"first_name,omitempty" binding:"omitempty,max=50" example:"John"`
|
||||
LastName *string `json:"last_name,omitempty" binding:"omitempty,max=50" example:"Doe"`
|
||||
Phone *string `json:"phone,omitempty" binding:"omitempty,max=20" example:"+86-13800138000"`
|
||||
Avatar *string `json:"avatar,omitempty" binding:"omitempty,url" example:"https://example.com/avatar.jpg"`
|
||||
}
|
||||
|
||||
// ChangePasswordRequest 修改密码请求
|
||||
type ChangePasswordRequest struct {
|
||||
OldPassword string `json:"old_password" binding:"required" example:"oldpassword123"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6,max=128" example:"newpassword123"`
|
||||
}
|
||||
|
||||
// UserResponse 用户响应
|
||||
type UserResponse struct {
|
||||
ID string `json:"id" example:"123e4567-e89b-12d3-a456-426614174000"`
|
||||
Username string `json:"username" example:"john_doe"`
|
||||
Email string `json:"email" example:"john@example.com"`
|
||||
FirstName string `json:"first_name" example:"John"`
|
||||
LastName string `json:"last_name" example:"Doe"`
|
||||
Phone string `json:"phone" example:"+86-13800138000"`
|
||||
Avatar string `json:"avatar" example:"https://example.com/avatar.jpg"`
|
||||
Status entities.UserStatus `json:"status" example:"active"`
|
||||
LastLoginAt *time.Time `json:"last_login_at,omitempty" example:"2024-01-01T00:00:00Z"`
|
||||
CreatedAt time.Time `json:"created_at" example:"2024-01-01T00:00:00Z"`
|
||||
UpdatedAt time.Time `json:"updated_at" example:"2024-01-01T00:00:00Z"`
|
||||
Profile *UserProfileResponse `json:"profile,omitempty"`
|
||||
}
|
||||
|
||||
// UserProfileResponse 用户档案响应
|
||||
type UserProfileResponse struct {
|
||||
Bio string `json:"bio,omitempty" example:"Software Developer"`
|
||||
Location string `json:"location,omitempty" example:"Beijing, China"`
|
||||
Website string `json:"website,omitempty" example:"https://johndoe.com"`
|
||||
Birthday *time.Time `json:"birthday,omitempty" example:"1990-01-01T00:00:00Z"`
|
||||
Gender string `json:"gender,omitempty" example:"male"`
|
||||
Timezone string `json:"timezone,omitempty" example:"Asia/Shanghai"`
|
||||
Language string `json:"language,omitempty" example:"zh-CN"`
|
||||
}
|
||||
|
||||
// UserListRequest 用户列表请求
|
||||
type UserListRequest struct {
|
||||
Page int `form:"page" binding:"omitempty,min=1" example:"1"`
|
||||
PageSize int `form:"page_size" binding:"omitempty,min=1,max=100" example:"20"`
|
||||
Sort string `form:"sort" binding:"omitempty,oneof=created_at updated_at username email" example:"created_at"`
|
||||
Order string `form:"order" binding:"omitempty,oneof=asc desc" example:"desc"`
|
||||
Status entities.UserStatus `form:"status" binding:"omitempty,oneof=active inactive suspended pending" example:"active"`
|
||||
Search string `form:"search" binding:"omitempty,max=100" example:"john"`
|
||||
Filters map[string]interface{} `form:"-"`
|
||||
}
|
||||
|
||||
// UserListResponse 用户列表响应
|
||||
type UserListResponse struct {
|
||||
Users []*UserResponse `json:"users"`
|
||||
Pagination PaginationMeta `json:"pagination"`
|
||||
}
|
||||
|
||||
// PaginationMeta 分页元数据
|
||||
type PaginationMeta struct {
|
||||
Page int `json:"page" example:"1"`
|
||||
PageSize int `json:"page_size" example:"20"`
|
||||
Total int64 `json:"total" example:"100"`
|
||||
TotalPages int `json:"total_pages" example:"5"`
|
||||
HasNext bool `json:"has_next" example:"true"`
|
||||
HasPrev bool `json:"has_prev" example:"false"`
|
||||
}
|
||||
|
||||
// LoginRequest 登录请求
|
||||
type LoginRequest struct {
|
||||
Login string `json:"login" binding:"required" example:"john_doe"`
|
||||
Password string `json:"password" binding:"required" example:"password123"`
|
||||
}
|
||||
|
||||
// LoginResponse 登录响应
|
||||
type LoginResponse struct {
|
||||
User *UserResponse `json:"user"`
|
||||
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
|
||||
TokenType string `json:"token_type" example:"Bearer"`
|
||||
ExpiresIn int64 `json:"expires_in" example:"86400"`
|
||||
}
|
||||
|
||||
// UpdateProfileRequest 更新用户档案请求
|
||||
type UpdateProfileRequest struct {
|
||||
Bio *string `json:"bio,omitempty" binding:"omitempty,max=500" example:"Software Developer"`
|
||||
Location *string `json:"location,omitempty" binding:"omitempty,max=100" example:"Beijing, China"`
|
||||
Website *string `json:"website,omitempty" binding:"omitempty,url" example:"https://johndoe.com"`
|
||||
Birthday *time.Time `json:"birthday,omitempty" example:"1990-01-01T00:00:00Z"`
|
||||
Gender *string `json:"gender,omitempty" binding:"omitempty,oneof=male female other" example:"male"`
|
||||
Timezone *string `json:"timezone,omitempty" binding:"omitempty,max=50" example:"Asia/Shanghai"`
|
||||
Language *string `json:"language,omitempty" binding:"omitempty,max=10" example:"zh-CN"`
|
||||
}
|
||||
|
||||
// UserStatsResponse 用户统计响应
|
||||
type UserStatsResponse struct {
|
||||
TotalUsers int64 `json:"total_users" example:"1000"`
|
||||
ActiveUsers int64 `json:"active_users" example:"950"`
|
||||
InactiveUsers int64 `json:"inactive_users" example:"30"`
|
||||
SuspendedUsers int64 `json:"suspended_users" example:"20"`
|
||||
NewUsersToday int64 `json:"new_users_today" example:"5"`
|
||||
NewUsersWeek int64 `json:"new_users_week" example:"25"`
|
||||
NewUsersMonth int64 `json:"new_users_month" example:"120"`
|
||||
}
|
||||
|
||||
// UserSearchRequest 用户搜索请求
|
||||
type UserSearchRequest struct {
|
||||
Query string `form:"q" binding:"required,min=1,max=100" example:"john"`
|
||||
Page int `form:"page" binding:"omitempty,min=1" example:"1"`
|
||||
PageSize int `form:"page_size" binding:"omitempty,min=1,max=50" example:"10"`
|
||||
}
|
||||
|
||||
// 转换方法
|
||||
func (r *CreateUserRequest) ToEntity() *entities.User {
|
||||
return &entities.User{
|
||||
Username: r.Username,
|
||||
Email: r.Email,
|
||||
Password: r.Password,
|
||||
FirstName: r.FirstName,
|
||||
LastName: r.LastName,
|
||||
Phone: r.Phone,
|
||||
Status: entities.UserStatusActive,
|
||||
}
|
||||
}
|
||||
|
||||
func FromEntity(user *entities.User) *UserResponse {
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &UserResponse{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Email: user.Email,
|
||||
FirstName: user.FirstName,
|
||||
LastName: user.LastName,
|
||||
Phone: user.Phone,
|
||||
Avatar: user.Avatar,
|
||||
Status: user.Status,
|
||||
LastLoginAt: user.LastLoginAt,
|
||||
CreatedAt: user.CreatedAt,
|
||||
UpdatedAt: user.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func FromEntities(users []*entities.User) []*UserResponse {
|
||||
if users == nil {
|
||||
return []*UserResponse{}
|
||||
}
|
||||
|
||||
responses := make([]*UserResponse, len(users))
|
||||
for i, user := range users {
|
||||
responses[i] = FromEntity(user)
|
||||
}
|
||||
return responses
|
||||
}
|
||||
138
internal/domains/user/entities/user.go
Normal file
138
internal/domains/user/entities/user.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package entities
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// User 用户实体
|
||||
type User struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
|
||||
Username string `gorm:"uniqueIndex;type:varchar(50);not null" json:"username"`
|
||||
Email string `gorm:"uniqueIndex;type:varchar(100);not null" json:"email"`
|
||||
Password string `gorm:"type:varchar(255);not null" json:"-"`
|
||||
FirstName string `gorm:"type:varchar(50)" json:"first_name"`
|
||||
LastName string `gorm:"type:varchar(50)" json:"last_name"`
|
||||
Phone string `gorm:"type:varchar(20)" json:"phone"`
|
||||
Avatar string `gorm:"type:varchar(255)" json:"avatar"`
|
||||
Status UserStatus `gorm:"type:varchar(20);default:'active'" json:"status"`
|
||||
LastLoginAt *time.Time `json:"last_login_at"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
|
||||
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
|
||||
|
||||
// 软删除字段
|
||||
IsDeleted bool `gorm:"default:false" json:"is_deleted"`
|
||||
|
||||
// 版本控制
|
||||
Version int `gorm:"default:1" json:"version"`
|
||||
}
|
||||
|
||||
// UserStatus 用户状态枚举
|
||||
type UserStatus string
|
||||
|
||||
const (
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusInactive UserStatus = "inactive"
|
||||
UserStatusSuspended UserStatus = "suspended"
|
||||
UserStatusPending UserStatus = "pending"
|
||||
)
|
||||
|
||||
// 实现 Entity 接口
|
||||
func (u *User) GetID() string {
|
||||
return u.ID
|
||||
}
|
||||
|
||||
func (u *User) GetCreatedAt() time.Time {
|
||||
return u.CreatedAt
|
||||
}
|
||||
|
||||
func (u *User) GetUpdatedAt() time.Time {
|
||||
return u.UpdatedAt
|
||||
}
|
||||
|
||||
// 业务方法
|
||||
func (u *User) IsActive() bool {
|
||||
return u.Status == UserStatusActive && !u.IsDeleted
|
||||
}
|
||||
|
||||
func (u *User) GetFullName() string {
|
||||
if u.FirstName == "" && u.LastName == "" {
|
||||
return u.Username
|
||||
}
|
||||
return u.FirstName + " " + u.LastName
|
||||
}
|
||||
|
||||
func (u *User) CanLogin() bool {
|
||||
return u.IsActive() && u.Status != UserStatusSuspended
|
||||
}
|
||||
|
||||
func (u *User) MarkAsDeleted() {
|
||||
u.IsDeleted = true
|
||||
u.Status = UserStatusInactive
|
||||
}
|
||||
|
||||
func (u *User) Restore() {
|
||||
u.IsDeleted = false
|
||||
u.Status = UserStatusActive
|
||||
}
|
||||
|
||||
func (u *User) UpdateLastLogin() {
|
||||
now := time.Now()
|
||||
u.LastLoginAt = &now
|
||||
}
|
||||
|
||||
// 验证方法
|
||||
func (u *User) Validate() error {
|
||||
if u.Username == "" {
|
||||
return NewValidationError("username is required")
|
||||
}
|
||||
if u.Email == "" {
|
||||
return NewValidationError("email is required")
|
||||
}
|
||||
if u.Password == "" {
|
||||
return NewValidationError("password is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (User) TableName() string {
|
||||
return "users"
|
||||
}
|
||||
|
||||
// ValidationError 验证错误
|
||||
type ValidationError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ValidationError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
func NewValidationError(message string) *ValidationError {
|
||||
return &ValidationError{Message: message}
|
||||
}
|
||||
|
||||
// UserProfile 用户档案(扩展信息)
|
||||
type UserProfile struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
|
||||
UserID string `gorm:"type:varchar(36);not null;index" json:"user_id"`
|
||||
Bio string `gorm:"type:text" json:"bio"`
|
||||
Location string `gorm:"type:varchar(100)" json:"location"`
|
||||
Website string `gorm:"type:varchar(255)" json:"website"`
|
||||
Birthday *time.Time `json:"birthday"`
|
||||
Gender string `gorm:"type:varchar(10)" json:"gender"`
|
||||
Timezone string `gorm:"type:varchar(50)" json:"timezone"`
|
||||
Language string `gorm:"type:varchar(10);default:'zh-CN'" json:"language"`
|
||||
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
|
||||
|
||||
// 关联关系
|
||||
User *User `gorm:"foreignKey:UserID;references:ID" json:"user,omitempty"`
|
||||
}
|
||||
|
||||
func (UserProfile) TableName() string {
|
||||
return "user_profiles"
|
||||
}
|
||||
299
internal/domains/user/events/user_events.go
Normal file
299
internal/domains/user/events/user_events.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserEventType 用户事件类型
|
||||
type UserEventType string
|
||||
|
||||
const (
|
||||
UserCreatedEvent UserEventType = "user.created"
|
||||
UserUpdatedEvent UserEventType = "user.updated"
|
||||
UserDeletedEvent UserEventType = "user.deleted"
|
||||
UserRestoredEvent UserEventType = "user.restored"
|
||||
UserLoggedInEvent UserEventType = "user.logged_in"
|
||||
UserLoggedOutEvent UserEventType = "user.logged_out"
|
||||
UserPasswordChangedEvent UserEventType = "user.password_changed"
|
||||
UserStatusChangedEvent UserEventType = "user.status_changed"
|
||||
UserProfileUpdatedEvent UserEventType = "user.profile_updated"
|
||||
)
|
||||
|
||||
// BaseUserEvent 用户事件基础结构
|
||||
type BaseUserEvent struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Version string `json:"version"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Source string `json:"source"`
|
||||
AggregateID string `json:"aggregate_id"`
|
||||
AggregateType string `json:"aggregate_type"`
|
||||
Metadata map[string]interface{} `json:"metadata"`
|
||||
Payload interface{} `json:"payload"`
|
||||
|
||||
// DDD特有字段
|
||||
DomainVersion string `json:"domain_version"`
|
||||
CausationID string `json:"causation_id"`
|
||||
CorrelationID string `json:"correlation_id"`
|
||||
}
|
||||
|
||||
// 实现 Event 接口
|
||||
func (e *BaseUserEvent) GetID() string {
|
||||
return e.ID
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetType() string {
|
||||
return e.Type
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetVersion() string {
|
||||
return e.Version
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetTimestamp() time.Time {
|
||||
return e.Timestamp
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetPayload() interface{} {
|
||||
return e.Payload
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetMetadata() map[string]interface{} {
|
||||
return e.Metadata
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetSource() string {
|
||||
return e.Source
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetAggregateID() string {
|
||||
return e.AggregateID
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetAggregateType() string {
|
||||
return e.AggregateType
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetDomainVersion() string {
|
||||
return e.DomainVersion
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetCausationID() string {
|
||||
return e.CausationID
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) GetCorrelationID() string {
|
||||
return e.CorrelationID
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) Marshal() ([]byte, error) {
|
||||
return json.Marshal(e)
|
||||
}
|
||||
|
||||
func (e *BaseUserEvent) Unmarshal(data []byte) error {
|
||||
return json.Unmarshal(data, e)
|
||||
}
|
||||
|
||||
// UserCreated 用户创建事件
|
||||
type UserCreated struct {
|
||||
*BaseUserEvent
|
||||
User *entities.User `json:"user"`
|
||||
}
|
||||
|
||||
func NewUserCreatedEvent(user *entities.User, correlationID string) *UserCreated {
|
||||
return &UserCreated{
|
||||
BaseUserEvent: &BaseUserEvent{
|
||||
ID: uuid.New().String(),
|
||||
Type: string(UserCreatedEvent),
|
||||
Version: "1.0",
|
||||
Timestamp: time.Now(),
|
||||
Source: "user-service",
|
||||
AggregateID: user.ID,
|
||||
AggregateType: "User",
|
||||
DomainVersion: "1.0",
|
||||
CorrelationID: correlationID,
|
||||
Metadata: map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"username": user.Username,
|
||||
"email": user.Email,
|
||||
},
|
||||
},
|
||||
User: user,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *UserCreated) GetPayload() interface{} {
|
||||
return e.User
|
||||
}
|
||||
|
||||
// UserUpdated 用户更新事件
|
||||
type UserUpdated struct {
|
||||
*BaseUserEvent
|
||||
UserID string `json:"user_id"`
|
||||
Changes map[string]interface{} `json:"changes"`
|
||||
OldValues map[string]interface{} `json:"old_values"`
|
||||
NewValues map[string]interface{} `json:"new_values"`
|
||||
}
|
||||
|
||||
func NewUserUpdatedEvent(userID string, changes, oldValues, newValues map[string]interface{}, correlationID string) *UserUpdated {
|
||||
return &UserUpdated{
|
||||
BaseUserEvent: &BaseUserEvent{
|
||||
ID: uuid.New().String(),
|
||||
Type: string(UserUpdatedEvent),
|
||||
Version: "1.0",
|
||||
Timestamp: time.Now(),
|
||||
Source: "user-service",
|
||||
AggregateID: userID,
|
||||
AggregateType: "User",
|
||||
DomainVersion: "1.0",
|
||||
CorrelationID: correlationID,
|
||||
Metadata: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"changed_fields": len(changes),
|
||||
},
|
||||
},
|
||||
UserID: userID,
|
||||
Changes: changes,
|
||||
OldValues: oldValues,
|
||||
NewValues: newValues,
|
||||
}
|
||||
}
|
||||
|
||||
// UserDeleted 用户删除事件
|
||||
type UserDeleted struct {
|
||||
*BaseUserEvent
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
SoftDelete bool `json:"soft_delete"`
|
||||
}
|
||||
|
||||
func NewUserDeletedEvent(userID, username, email string, softDelete bool, correlationID string) *UserDeleted {
|
||||
return &UserDeleted{
|
||||
BaseUserEvent: &BaseUserEvent{
|
||||
ID: uuid.New().String(),
|
||||
Type: string(UserDeletedEvent),
|
||||
Version: "1.0",
|
||||
Timestamp: time.Now(),
|
||||
Source: "user-service",
|
||||
AggregateID: userID,
|
||||
AggregateType: "User",
|
||||
DomainVersion: "1.0",
|
||||
CorrelationID: correlationID,
|
||||
Metadata: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"username": username,
|
||||
"email": email,
|
||||
"soft_delete": softDelete,
|
||||
},
|
||||
},
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
SoftDelete: softDelete,
|
||||
}
|
||||
}
|
||||
|
||||
// UserLoggedIn 用户登录事件
|
||||
type UserLoggedIn struct {
|
||||
*BaseUserEvent
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
IPAddress string `json:"ip_address"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
}
|
||||
|
||||
func NewUserLoggedInEvent(userID, username, ipAddress, userAgent, correlationID string) *UserLoggedIn {
|
||||
return &UserLoggedIn{
|
||||
BaseUserEvent: &BaseUserEvent{
|
||||
ID: uuid.New().String(),
|
||||
Type: string(UserLoggedInEvent),
|
||||
Version: "1.0",
|
||||
Timestamp: time.Now(),
|
||||
Source: "user-service",
|
||||
AggregateID: userID,
|
||||
AggregateType: "User",
|
||||
DomainVersion: "1.0",
|
||||
CorrelationID: correlationID,
|
||||
Metadata: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"username": username,
|
||||
"ip_address": ipAddress,
|
||||
"user_agent": userAgent,
|
||||
},
|
||||
},
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
IPAddress: ipAddress,
|
||||
UserAgent: userAgent,
|
||||
}
|
||||
}
|
||||
|
||||
// UserPasswordChanged 用户密码修改事件
|
||||
type UserPasswordChanged struct {
|
||||
*BaseUserEvent
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func NewUserPasswordChangedEvent(userID, username, correlationID string) *UserPasswordChanged {
|
||||
return &UserPasswordChanged{
|
||||
BaseUserEvent: &BaseUserEvent{
|
||||
ID: uuid.New().String(),
|
||||
Type: string(UserPasswordChangedEvent),
|
||||
Version: "1.0",
|
||||
Timestamp: time.Now(),
|
||||
Source: "user-service",
|
||||
AggregateID: userID,
|
||||
AggregateType: "User",
|
||||
DomainVersion: "1.0",
|
||||
CorrelationID: correlationID,
|
||||
Metadata: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"username": username,
|
||||
},
|
||||
},
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
}
|
||||
}
|
||||
|
||||
// UserStatusChanged 用户状态变更事件
|
||||
type UserStatusChanged struct {
|
||||
*BaseUserEvent
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
OldStatus entities.UserStatus `json:"old_status"`
|
||||
NewStatus entities.UserStatus `json:"new_status"`
|
||||
}
|
||||
|
||||
func NewUserStatusChangedEvent(userID, username string, oldStatus, newStatus entities.UserStatus, correlationID string) *UserStatusChanged {
|
||||
return &UserStatusChanged{
|
||||
BaseUserEvent: &BaseUserEvent{
|
||||
ID: uuid.New().String(),
|
||||
Type: string(UserStatusChangedEvent),
|
||||
Version: "1.0",
|
||||
Timestamp: time.Now(),
|
||||
Source: "user-service",
|
||||
AggregateID: userID,
|
||||
AggregateType: "User",
|
||||
DomainVersion: "1.0",
|
||||
CorrelationID: correlationID,
|
||||
Metadata: map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"username": username,
|
||||
"old_status": oldStatus,
|
||||
"new_status": newStatus,
|
||||
},
|
||||
},
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
OldStatus: oldStatus,
|
||||
NewStatus: newStatus,
|
||||
}
|
||||
}
|
||||
455
internal/domains/user/handlers/user_handler.go
Normal file
455
internal/domains/user/handlers/user_handler.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/domains/user/dto"
|
||||
"tyapi-server/internal/domains/user/services"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
)
|
||||
|
||||
// UserHandler 用户HTTP处理器
|
||||
type UserHandler struct {
|
||||
userService *services.UserService
|
||||
response interfaces.ResponseBuilder
|
||||
validator interfaces.RequestValidator
|
||||
logger *zap.Logger
|
||||
jwtAuth *middleware.JWTAuthMiddleware
|
||||
}
|
||||
|
||||
// NewUserHandler 创建用户处理器
|
||||
func NewUserHandler(
|
||||
userService *services.UserService,
|
||||
response interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
jwtAuth *middleware.JWTAuthMiddleware,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
userService: userService,
|
||||
response: response,
|
||||
validator: validator,
|
||||
logger: logger,
|
||||
jwtAuth: jwtAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// GetPath 返回处理器路径
|
||||
func (h *UserHandler) GetPath() string {
|
||||
return "/users"
|
||||
}
|
||||
|
||||
// GetMethod 返回HTTP方法
|
||||
func (h *UserHandler) GetMethod() string {
|
||||
return "GET" // 主要用于列表,具体方法在路由注册时指定
|
||||
}
|
||||
|
||||
// GetMiddlewares 返回中间件
|
||||
func (h *UserHandler) GetMiddlewares() []gin.HandlerFunc {
|
||||
return []gin.HandlerFunc{
|
||||
// 这里可以添加特定的中间件
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 主处理函数(用于列表)
|
||||
func (h *UserHandler) Handle(c *gin.Context) {
|
||||
h.List(c)
|
||||
}
|
||||
|
||||
// RequiresAuth 是否需要认证
|
||||
func (h *UserHandler) RequiresAuth() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// GetPermissions 获取所需权限
|
||||
func (h *UserHandler) GetPermissions() []string {
|
||||
return []string{"user:read"}
|
||||
}
|
||||
|
||||
// REST操作实现
|
||||
|
||||
// Create 创建用户
|
||||
func (h *UserHandler) Create(c *gin.Context) {
|
||||
var req dto.CreateUserRequest
|
||||
|
||||
// 验证请求体
|
||||
if err := h.validator.BindAndValidate(c, &req); err != nil {
|
||||
return // 响应已在验证器中处理
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user, err := h.userService.Create(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to create user", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := dto.FromEntity(user)
|
||||
h.response.Created(c, response, "User created successfully")
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (h *UserHandler) GetByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
h.response.BadRequest(c, "User ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户
|
||||
user, err := h.userService.GetByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user", zap.Error(err))
|
||||
h.response.NotFound(c, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := dto.FromEntity(user)
|
||||
h.response.Success(c, response)
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (h *UserHandler) Update(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
h.response.BadRequest(c, "User ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.UpdateUserRequest
|
||||
|
||||
// 验证请求体
|
||||
if err := h.validator.BindAndValidate(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新用户
|
||||
user, err := h.userService.Update(c.Request.Context(), id, &req)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to update user", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := dto.FromEntity(user)
|
||||
h.response.Success(c, response, "User updated successfully")
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (h *UserHandler) Delete(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
h.response.BadRequest(c, "User ID is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 删除用户
|
||||
if err := h.userService.Delete(c.Request.Context(), id); err != nil {
|
||||
h.logger.Error("Failed to delete user", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
h.response.Success(c, nil, "User deleted successfully")
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (h *UserHandler) List(c *gin.Context) {
|
||||
var req dto.UserListRequest
|
||||
|
||||
// 验证查询参数
|
||||
if err := h.validator.ValidateQuery(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
|
||||
// 构建查询选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
Sort: req.Sort,
|
||||
Order: req.Order,
|
||||
Search: req.Search,
|
||||
Filters: req.Filters,
|
||||
}
|
||||
|
||||
// 获取用户列表
|
||||
users, err := h.userService.List(c.Request.Context(), options)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user list", zap.Error(err))
|
||||
h.response.InternalError(c, "Failed to get user list")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
countOptions := interfaces.CountOptions{
|
||||
Search: req.Search,
|
||||
Filters: req.Filters,
|
||||
}
|
||||
total, err := h.userService.Count(c.Request.Context(), countOptions)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to count users", zap.Error(err))
|
||||
h.response.InternalError(c, "Failed to count users")
|
||||
return
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
userResponses := dto.FromEntities(users)
|
||||
pagination := buildPagination(req.Page, req.PageSize, total)
|
||||
|
||||
h.response.Paginated(c, userResponses, pagination)
|
||||
}
|
||||
|
||||
// Login 用户登录
|
||||
func (h *UserHandler) Login(c *gin.Context) {
|
||||
var req dto.LoginRequest
|
||||
|
||||
// 验证请求体
|
||||
if err := h.validator.BindAndValidate(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 用户登录
|
||||
user, err := h.userService.Login(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
h.logger.Error("Login failed", zap.Error(err))
|
||||
h.response.Unauthorized(c, "Invalid credentials")
|
||||
return
|
||||
}
|
||||
|
||||
// 生成JWT token
|
||||
accessToken, err := h.jwtAuth.GenerateToken(user.ID, user.Username, user.Email)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to generate token", zap.Error(err))
|
||||
h.response.InternalError(c, "Failed to generate access token")
|
||||
return
|
||||
}
|
||||
|
||||
// 构建登录响应
|
||||
loginResponse := &dto.LoginResponse{
|
||||
User: dto.FromEntity(user),
|
||||
AccessToken: accessToken,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 86400, // 24小时,从配置获取
|
||||
}
|
||||
|
||||
h.response.Success(c, loginResponse, "Login successful")
|
||||
}
|
||||
|
||||
// Logout 用户登出
|
||||
func (h *UserHandler) Logout(c *gin.Context) {
|
||||
// 简单实现,客户端删除token即可
|
||||
// 如果需要服务端黑名单,可以在这里实现
|
||||
h.response.Success(c, nil, "Logout successful")
|
||||
}
|
||||
|
||||
// GetProfile 获取当前用户信息
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := h.userService.GetByID(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user profile", zap.Error(err))
|
||||
h.response.NotFound(c, "User not found")
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := dto.FromEntity(user)
|
||||
h.response.Success(c, response)
|
||||
}
|
||||
|
||||
// UpdateProfile 更新当前用户信息
|
||||
func (h *UserHandler) UpdateProfile(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.UpdateUserRequest
|
||||
|
||||
// 验证请求体
|
||||
if err := h.validator.BindAndValidate(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新用户
|
||||
user, err := h.userService.Update(c.Request.Context(), userID, &req)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to update profile", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 返回响应
|
||||
response := dto.FromEntity(user)
|
||||
h.response.Success(c, response, "Profile updated successfully")
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.ChangePasswordRequest
|
||||
|
||||
// 验证请求体
|
||||
if err := h.validator.BindAndValidate(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 修改密码
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), userID, &req); err != nil {
|
||||
h.logger.Error("Failed to change password", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, nil, "Password changed successfully")
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (h *UserHandler) Search(c *gin.Context) {
|
||||
var req dto.UserSearchRequest
|
||||
|
||||
// 验证查询参数
|
||||
if err := h.validator.ValidateQuery(c, &req); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 10
|
||||
}
|
||||
|
||||
// 构建查询选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: req.Page,
|
||||
PageSize: req.PageSize,
|
||||
Search: req.Query,
|
||||
}
|
||||
|
||||
// 搜索用户
|
||||
users, err := h.userService.Search(c.Request.Context(), req.Query, options)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to search users", zap.Error(err))
|
||||
h.response.InternalError(c, "Failed to search users")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取搜索结果总数
|
||||
countOptions := interfaces.CountOptions{
|
||||
Search: req.Query,
|
||||
}
|
||||
total, err := h.userService.Count(c.Request.Context(), countOptions)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to count search results", zap.Error(err))
|
||||
h.response.InternalError(c, "Failed to count search results")
|
||||
return
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
userResponses := dto.FromEntities(users)
|
||||
pagination := buildPagination(req.Page, req.PageSize, total)
|
||||
|
||||
h.response.Paginated(c, userResponses, pagination)
|
||||
}
|
||||
|
||||
// GetStats 获取用户统计
|
||||
func (h *UserHandler) GetStats(c *gin.Context) {
|
||||
stats, err := h.userService.GetStats(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to get user stats", zap.Error(err))
|
||||
h.response.InternalError(c, "Failed to get user statistics")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, stats)
|
||||
}
|
||||
|
||||
// 私有方法
|
||||
|
||||
// getCurrentUserID 获取当前用户ID
|
||||
func (h *UserHandler) getCurrentUserID(c *gin.Context) string {
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// parsePageSize 解析页面大小
|
||||
func (h *UserHandler) parsePageSize(str string, defaultValue int) int {
|
||||
if str == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
if size, err := strconv.Atoi(str); err == nil && size > 0 && size <= 100 {
|
||||
return size
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// parsePage 解析页码
|
||||
func (h *UserHandler) parsePage(str string, defaultValue int) int {
|
||||
if str == "" {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
if page, err := strconv.Atoi(str); err == nil && page > 0 {
|
||||
return page
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// buildPagination 构建分页元数据
|
||||
func buildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
|
||||
totalPages := int(float64(total) / float64(pageSize))
|
||||
if float64(total)/float64(pageSize) > float64(totalPages) {
|
||||
totalPages++
|
||||
}
|
||||
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
return interfaces.PaginationMeta{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
HasNext: page < totalPages,
|
||||
HasPrev: page > 1,
|
||||
}
|
||||
}
|
||||
339
internal/domains/user/repositories/user_repository.go
Normal file
339
internal/domains/user/repositories/user_repository.go
Normal file
@@ -0,0 +1,339 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储实现
|
||||
type UserRepository struct {
|
||||
db *gorm.DB
|
||||
cache interfaces.CacheService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewUserRepository 创建用户仓储
|
||||
func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *UserRepository {
|
||||
return &UserRepository{
|
||||
db: db,
|
||||
cache: cache,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (r *UserRepository) Create(ctx context.Context, entity *entities.User) error {
|
||||
if err := r.db.WithContext(ctx).Create(entity).Error; err != nil {
|
||||
r.logger.Error("Failed to create user", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.invalidateUserCaches(ctx, entity.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entities.User, error) {
|
||||
// 先尝试从缓存获取
|
||||
cacheKey := r.GetCacheKey(id)
|
||||
var user entities.User
|
||||
|
||||
if err := r.cache.Get(ctx, cacheKey, &user); err == nil {
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// 从数据库获取
|
||||
if err := r.db.WithContext(ctx).Where("id = ? AND is_deleted = false", id).First(&user).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 缓存结果
|
||||
r.cache.Set(ctx, cacheKey, &user, 1*time.Hour)
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (r *UserRepository) Update(ctx context.Context, entity *entities.User) error {
|
||||
if err := r.db.WithContext(ctx).Save(entity).Error; err != nil {
|
||||
r.logger.Error("Failed to update user", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.invalidateUserCaches(ctx, entity.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (r *UserRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("Failed to delete user", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.invalidateUserCaches(ctx, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建用户
|
||||
func (r *UserRepository) CreateBatch(ctx context.Context, entities []*entities.User) error {
|
||||
if err := r.db.WithContext(ctx).CreateInBatches(entities, 100).Error; err != nil {
|
||||
r.logger.Error("Failed to create users in batch", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除列表缓存
|
||||
r.cache.DeletePattern(ctx, "users:list:*")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取用户
|
||||
func (r *UserRepository) GetByIDs(ctx context.Context, ids []string) ([]*entities.User, error) {
|
||||
var users []entities.User
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("id IN ? AND is_deleted = false", ids).
|
||||
Find(&users).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.User, len(users))
|
||||
for i := range users {
|
||||
result[i] = &users[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新用户
|
||||
func (r *UserRepository) UpdateBatch(ctx context.Context, entities []*entities.User) error {
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for _, entity := range entities {
|
||||
if err := tx.Save(entity).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除用户
|
||||
func (r *UserRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("id IN ?", ids).
|
||||
Delete(&entities.User{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
for _, id := range ids {
|
||||
r.invalidateUserCaches(ctx, id)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (r *UserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]*entities.User, error) {
|
||||
// 尝试从缓存获取
|
||||
cacheKey := fmt.Sprintf("users:list:%d:%d:%s", options.Page, options.PageSize, options.Sort)
|
||||
var users []*entities.User
|
||||
|
||||
if err := r.cache.Get(ctx, cacheKey, &users); err == nil {
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// 从数据库查询
|
||||
query := r.db.WithContext(ctx).Where("is_deleted = false")
|
||||
|
||||
// 应用过滤条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Order
|
||||
if order == "" {
|
||||
order = "asc"
|
||||
}
|
||||
query = query.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
query = query.Order("created_at desc")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
var userEntities []entities.User
|
||||
if err := query.Find(&userEntities).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
users = make([]*entities.User, len(userEntities))
|
||||
for i := range userEntities {
|
||||
users[i] = &userEntities[i]
|
||||
}
|
||||
|
||||
// 缓存结果
|
||||
r.cache.Set(ctx, cacheKey, users, 30*time.Minute)
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// Count 统计用户数量
|
||||
func (r *UserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
query := r.db.WithContext(ctx).Model(&entities.User{}).Where("is_deleted = false")
|
||||
|
||||
// 应用过滤条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
var count int64
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Exists 检查用户是否存在
|
||||
func (r *UserRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.User{}).
|
||||
Where("id = ? AND is_deleted = false", id).
|
||||
Count(&count).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除用户
|
||||
func (r *UserRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.User{}).
|
||||
Where("id = ?", id).
|
||||
Update("is_deleted", true).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.invalidateUserCaches(ctx, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore 恢复用户
|
||||
func (r *UserRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.User{}).
|
||||
Where("id = ?", id).
|
||||
Update("is_deleted", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.invalidateUserCaches(ctx, id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *UserRepository) WithTx(tx interface{}) interfaces.Repository[*entities.User] {
|
||||
gormTx, ok := tx.(*gorm.DB)
|
||||
if !ok {
|
||||
return r
|
||||
}
|
||||
|
||||
return &UserRepository{
|
||||
db: gormTx,
|
||||
cache: r.cache,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateCache 清除缓存
|
||||
func (r *UserRepository) InvalidateCache(ctx context.Context, keys ...string) error {
|
||||
return r.cache.Delete(ctx, keys...)
|
||||
}
|
||||
|
||||
// WarmupCache 预热缓存
|
||||
func (r *UserRepository) WarmupCache(ctx context.Context) error {
|
||||
// 预热热门用户数据
|
||||
// 这里可以实现具体的预热逻辑
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCacheKey 获取缓存键
|
||||
func (r *UserRepository) GetCacheKey(id string) string {
|
||||
return fmt.Sprintf("user:%s", id)
|
||||
}
|
||||
|
||||
// FindByUsername 根据用户名查找用户
|
||||
func (r *UserRepository) FindByUsername(ctx context.Context, username string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("username = ? AND is_deleted = false", username).
|
||||
First(&user).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// FindByEmail 根据邮箱查找用户
|
||||
func (r *UserRepository) FindByEmail(ctx context.Context, email string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("email = ? AND is_deleted = false", email).
|
||||
First(&user).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// invalidateUserCaches 清除用户相关缓存
|
||||
func (r *UserRepository) invalidateUserCaches(ctx context.Context, userID string) {
|
||||
keys := []string{
|
||||
r.GetCacheKey(userID),
|
||||
}
|
||||
|
||||
r.cache.Delete(ctx, keys...)
|
||||
r.cache.DeletePattern(ctx, "users:list:*")
|
||||
}
|
||||
133
internal/domains/user/routes/user_routes.go
Normal file
133
internal/domains/user/routes/user_routes.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/domains/user/handlers"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// UserRoutes 用户路由注册器
|
||||
type UserRoutes struct {
|
||||
handler *handlers.UserHandler
|
||||
jwtAuth *middleware.JWTAuthMiddleware
|
||||
optionalAuth *middleware.OptionalAuthMiddleware
|
||||
}
|
||||
|
||||
// NewUserRoutes 创建用户路由注册器
|
||||
func NewUserRoutes(
|
||||
handler *handlers.UserHandler,
|
||||
jwtAuth *middleware.JWTAuthMiddleware,
|
||||
optionalAuth *middleware.OptionalAuthMiddleware,
|
||||
) *UserRoutes {
|
||||
return &UserRoutes{
|
||||
handler: handler,
|
||||
jwtAuth: jwtAuth,
|
||||
optionalAuth: optionalAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes 注册用户路由
|
||||
func (r *UserRoutes) RegisterRoutes(router *gin.Engine) {
|
||||
// API版本组
|
||||
v1 := router.Group("/api/v1")
|
||||
|
||||
// 公开路由(不需要认证)
|
||||
public := v1.Group("/auth")
|
||||
{
|
||||
public.POST("/login", r.handler.Login)
|
||||
public.POST("/register", r.handler.Create)
|
||||
}
|
||||
|
||||
// 需要认证的路由
|
||||
protected := v1.Group("/users")
|
||||
protected.Use(r.jwtAuth.Handle())
|
||||
{
|
||||
// 用户管理(管理员)
|
||||
protected.GET("", r.handler.List)
|
||||
protected.POST("", r.handler.Create)
|
||||
protected.GET("/:id", r.handler.GetByID)
|
||||
protected.PUT("/:id", r.handler.Update)
|
||||
protected.DELETE("/:id", r.handler.Delete)
|
||||
|
||||
// 用户搜索
|
||||
protected.GET("/search", r.handler.Search)
|
||||
|
||||
// 用户统计
|
||||
protected.GET("/stats", r.handler.GetStats)
|
||||
}
|
||||
|
||||
// 用户个人操作路由
|
||||
profile := v1.Group("/profile")
|
||||
profile.Use(r.jwtAuth.Handle())
|
||||
{
|
||||
profile.GET("", r.handler.GetProfile)
|
||||
profile.PUT("", r.handler.UpdateProfile)
|
||||
profile.POST("/change-password", r.handler.ChangePassword)
|
||||
profile.POST("/logout", r.handler.Logout)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterPublicRoutes 注册公开路由
|
||||
func (r *UserRoutes) RegisterPublicRoutes(router *gin.Engine) {
|
||||
v1 := router.Group("/api/v1")
|
||||
|
||||
// 公开的用户相关路由
|
||||
public := v1.Group("/public")
|
||||
{
|
||||
// 可选认证的路由(用户可能登录也可能未登录)
|
||||
public.Use(r.optionalAuth.Handle())
|
||||
|
||||
// 这里可以添加一些公开的用户信息查询接口
|
||||
// 比如根据用户名查看公开信息(如果用户设置为公开)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterAdminRoutes 注册管理员路由
|
||||
func (r *UserRoutes) RegisterAdminRoutes(router *gin.Engine) {
|
||||
admin := router.Group("/admin/v1")
|
||||
admin.Use(r.jwtAuth.Handle())
|
||||
// 这里可以添加管理员权限检查中间件
|
||||
|
||||
// 管理员用户管理
|
||||
users := admin.Group("/users")
|
||||
{
|
||||
users.GET("", r.handler.List)
|
||||
users.GET("/:id", r.handler.GetByID)
|
||||
users.PUT("/:id", r.handler.Update)
|
||||
users.DELETE("/:id", r.handler.Delete)
|
||||
users.GET("/stats", r.handler.GetStats)
|
||||
users.GET("/search", r.handler.Search)
|
||||
|
||||
// 批量操作
|
||||
users.POST("/batch-delete", r.handleBatchDelete)
|
||||
users.POST("/batch-update", r.handleBatchUpdate)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量删除处理器
|
||||
func (r *UserRoutes) handleBatchDelete(c *gin.Context) {
|
||||
// 实现批量删除逻辑
|
||||
// 这里可以接收用户ID列表并调用服务进行批量删除
|
||||
c.JSON(200, gin.H{"message": "Batch delete not implemented yet"})
|
||||
}
|
||||
|
||||
// 批量更新处理器
|
||||
func (r *UserRoutes) handleBatchUpdate(c *gin.Context) {
|
||||
// 实现批量更新逻辑
|
||||
c.JSON(200, gin.H{"message": "Batch update not implemented yet"})
|
||||
}
|
||||
|
||||
// RegisterHealthRoutes 注册健康检查路由
|
||||
func (r *UserRoutes) RegisterHealthRoutes(router *gin.Engine) {
|
||||
health := router.Group("/health")
|
||||
{
|
||||
health.GET("/users", func(c *gin.Context) {
|
||||
// 用户服务健康检查
|
||||
c.JSON(200, gin.H{
|
||||
"service": "users",
|
||||
"status": "healthy",
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
469
internal/domains/user/services/user_service.go
Normal file
469
internal/domains/user/services/user_service.go
Normal file
@@ -0,0 +1,469 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"tyapi-server/internal/domains/user/dto"
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/domains/user/events"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// UserService 用户服务实现
|
||||
type UserService struct {
|
||||
repo *repositories.UserRepository
|
||||
eventBus interfaces.EventBus
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务
|
||||
func NewUserService(
|
||||
repo *repositories.UserRepository,
|
||||
eventBus interfaces.EventBus,
|
||||
logger *zap.Logger,
|
||||
) *UserService {
|
||||
return &UserService{
|
||||
repo: repo,
|
||||
eventBus: eventBus,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (s *UserService) Name() string {
|
||||
return "user-service"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (s *UserService) Initialize(ctx context.Context) error {
|
||||
s.logger.Info("User service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (s *UserService) HealthCheck(ctx context.Context) error {
|
||||
// 简单检查:尝试查询用户数量
|
||||
_, err := s.repo.Count(ctx, interfaces.CountOptions{})
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (s *UserService) Shutdown(ctx context.Context) error {
|
||||
s.logger.Info("User service shutdown")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (s *UserService) Create(ctx context.Context, createDTO interface{}) (*entities.User, error) {
|
||||
req, ok := createDTO.(*dto.CreateUserRequest)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid DTO type for user creation")
|
||||
}
|
||||
|
||||
// 验证业务规则
|
||||
if err := s.ValidateCreate(ctx, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查用户名和邮箱是否已存在
|
||||
if err := s.checkDuplicates(ctx, req.Username, req.Email); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建用户实体
|
||||
user := req.ToEntity()
|
||||
user.ID = uuid.New().String()
|
||||
|
||||
// 加密密码
|
||||
hashedPassword, err := s.hashPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to hash password: %w", err)
|
||||
}
|
||||
user.Password = hashedPassword
|
||||
|
||||
// 保存用户
|
||||
if err := s.repo.Create(ctx, user); err != nil {
|
||||
s.logger.Error("Failed to create user", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// 发布用户创建事件
|
||||
event := events.NewUserCreatedEvent(user, s.getCorrelationID(ctx))
|
||||
if err := s.eventBus.Publish(ctx, event); err != nil {
|
||||
s.logger.Warn("Failed to publish user created event", zap.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Info("User created successfully",
|
||||
zap.String("user_id", user.ID),
|
||||
zap.String("username", user.Username))
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (s *UserService) GetByID(ctx context.Context, id string) (*entities.User, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("user ID is required")
|
||||
}
|
||||
|
||||
user, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (s *UserService) Update(ctx context.Context, id string, updateDTO interface{}) (*entities.User, error) {
|
||||
req, ok := updateDTO.(*dto.UpdateUserRequest)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid DTO type for user update")
|
||||
}
|
||||
|
||||
// 验证业务规则
|
||||
if err := s.ValidateUpdate(ctx, id, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取现有用户
|
||||
user, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
|
||||
// 记录变更前的值
|
||||
oldValues := s.captureUserValues(user)
|
||||
|
||||
// 应用更新
|
||||
s.applyUserUpdates(user, req)
|
||||
|
||||
// 保存更新
|
||||
if err := s.repo.Update(ctx, user); err != nil {
|
||||
s.logger.Error("Failed to update user", zap.Error(err))
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
// 发布用户更新事件
|
||||
newValues := s.captureUserValues(user)
|
||||
changes := s.findChanges(oldValues, newValues)
|
||||
if len(changes) > 0 {
|
||||
event := events.NewUserUpdatedEvent(user.ID, changes, oldValues, newValues, s.getCorrelationID(ctx))
|
||||
if err := s.eventBus.Publish(ctx, event); err != nil {
|
||||
s.logger.Warn("Failed to publish user updated event", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Info("User updated successfully",
|
||||
zap.String("user_id", user.ID),
|
||||
zap.Int("changes", len(changes)))
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (s *UserService) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("user ID is required")
|
||||
}
|
||||
|
||||
// 获取用户信息用于事件
|
||||
user, err := s.repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
|
||||
// 软删除用户
|
||||
if err := s.repo.SoftDelete(ctx, id); err != nil {
|
||||
s.logger.Error("Failed to delete user", zap.Error(err))
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
|
||||
// 发布用户删除事件
|
||||
event := events.NewUserDeletedEvent(user.ID, user.Username, user.Email, true, s.getCorrelationID(ctx))
|
||||
if err := s.eventBus.Publish(ctx, event); err != nil {
|
||||
s.logger.Warn("Failed to publish user deleted event", zap.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Info("User deleted successfully", zap.String("user_id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (s *UserService) List(ctx context.Context, options interfaces.ListOptions) ([]*entities.User, error) {
|
||||
return s.repo.List(ctx, options)
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (s *UserService) Search(ctx context.Context, query string, options interfaces.ListOptions) ([]*entities.User, error) {
|
||||
// 设置搜索关键字
|
||||
searchOptions := options
|
||||
searchOptions.Search = query
|
||||
|
||||
return s.repo.List(ctx, searchOptions)
|
||||
}
|
||||
|
||||
// Count 统计用户数量
|
||||
func (s *UserService) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
return s.repo.Count(ctx, options)
|
||||
}
|
||||
|
||||
// Validate 验证用户实体
|
||||
func (s *UserService) Validate(ctx context.Context, entity *entities.User) error {
|
||||
return entity.Validate()
|
||||
}
|
||||
|
||||
// ValidateCreate 验证创建请求
|
||||
func (s *UserService) ValidateCreate(ctx context.Context, createDTO interface{}) error {
|
||||
req, ok := createDTO.(*dto.CreateUserRequest)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid DTO type")
|
||||
}
|
||||
|
||||
// 基础验证已经由binding标签处理,这里添加业务规则验证
|
||||
if req.Username == "admin" || req.Username == "root" {
|
||||
return fmt.Errorf("username '%s' is reserved", req.Username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateUpdate 验证更新请求
|
||||
func (s *UserService) ValidateUpdate(ctx context.Context, id string, updateDTO interface{}) error {
|
||||
_, ok := updateDTO.(*dto.UpdateUserRequest)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid DTO type")
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
return fmt.Errorf("user ID is required")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 业务方法
|
||||
|
||||
// Login 用户登录
|
||||
func (s *UserService) Login(ctx context.Context, loginReq *dto.LoginRequest) (*entities.User, error) {
|
||||
// 根据用户名或邮箱查找用户
|
||||
var user *entities.User
|
||||
var err error
|
||||
|
||||
if s.isEmail(loginReq.Login) {
|
||||
user, err = s.repo.FindByEmail(ctx, loginReq.Login)
|
||||
} else {
|
||||
user, err = s.repo.FindByUsername(ctx, loginReq.Login)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// 验证密码
|
||||
if !s.checkPassword(loginReq.Password, user.Password) {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.CanLogin() {
|
||||
return nil, fmt.Errorf("account is disabled or suspended")
|
||||
}
|
||||
|
||||
// 更新最后登录时间
|
||||
user.UpdateLastLogin()
|
||||
if err := s.repo.Update(ctx, user); err != nil {
|
||||
s.logger.Warn("Failed to update last login time", zap.Error(err))
|
||||
}
|
||||
|
||||
// 发布登录事件
|
||||
event := events.NewUserLoggedInEvent(
|
||||
user.ID, user.Username,
|
||||
s.getClientIP(ctx), s.getUserAgent(ctx),
|
||||
s.getCorrelationID(ctx))
|
||||
if err := s.eventBus.Publish(ctx, event); err != nil {
|
||||
s.logger.Warn("Failed to publish user logged in event", zap.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Info("User logged in successfully",
|
||||
zap.String("user_id", user.ID),
|
||||
zap.String("username", user.Username))
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error {
|
||||
// 获取用户
|
||||
user, err := s.repo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if !s.checkPassword(req.OldPassword, user.Password) {
|
||||
return fmt.Errorf("current password is incorrect")
|
||||
}
|
||||
|
||||
// 加密新密码
|
||||
hashedPassword, err := s.hashPassword(req.NewPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to hash new password: %w", err)
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
user.Password = hashedPassword
|
||||
if err := s.repo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("failed to update password: %w", err)
|
||||
}
|
||||
|
||||
// 发布密码修改事件
|
||||
event := events.NewUserPasswordChangedEvent(user.ID, user.Username, s.getCorrelationID(ctx))
|
||||
if err := s.eventBus.Publish(ctx, event); err != nil {
|
||||
s.logger.Warn("Failed to publish password changed event", zap.Error(err))
|
||||
}
|
||||
|
||||
s.logger.Info("Password changed successfully", zap.String("user_id", userID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats 获取用户统计
|
||||
func (s *UserService) GetStats(ctx context.Context) (*dto.UserStatsResponse, error) {
|
||||
total, err := s.repo.Count(ctx, interfaces.CountOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 这里可以并行查询不同状态的用户数量
|
||||
// 简化实现,返回基础统计
|
||||
return &dto.UserStatsResponse{
|
||||
TotalUsers: total,
|
||||
ActiveUsers: total, // 简化
|
||||
InactiveUsers: 0,
|
||||
SuspendedUsers: 0,
|
||||
NewUsersToday: 0,
|
||||
NewUsersWeek: 0,
|
||||
NewUsersMonth: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 私有方法
|
||||
|
||||
// checkDuplicates 检查重复的用户名和邮箱
|
||||
func (s *UserService) checkDuplicates(ctx context.Context, username, email string) error {
|
||||
// 检查用户名
|
||||
if existingUser, err := s.repo.FindByUsername(ctx, username); err == nil && existingUser != nil {
|
||||
return fmt.Errorf("username already exists")
|
||||
}
|
||||
|
||||
// 检查邮箱
|
||||
if existingUser, err := s.repo.FindByEmail(ctx, email); err == nil && existingUser != nil {
|
||||
return fmt.Errorf("email already exists")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// hashPassword 加密密码
|
||||
func (s *UserService) hashPassword(password string) (string, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(hash), nil
|
||||
}
|
||||
|
||||
// checkPassword 验证密码
|
||||
func (s *UserService) checkPassword(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// isEmail 检查是否为邮箱格式
|
||||
func (s *UserService) isEmail(str string) bool {
|
||||
return len(str) > 0 && len(str) < 255 &&
|
||||
len(str) > 5 &&
|
||||
str[len(str)-4:] != ".." &&
|
||||
(len(str) > 6 && str[len(str)-4:] == ".com") ||
|
||||
(len(str) > 5 && str[len(str)-3:] == ".cn") ||
|
||||
(len(str) > 6 && str[len(str)-4:] == ".org") ||
|
||||
(len(str) > 6 && str[len(str)-4:] == ".net")
|
||||
// 简化的邮箱检查,实际应该使用正则表达式
|
||||
}
|
||||
|
||||
// applyUserUpdates 应用用户更新
|
||||
func (s *UserService) applyUserUpdates(user *entities.User, req *dto.UpdateUserRequest) {
|
||||
if req.FirstName != nil {
|
||||
user.FirstName = *req.FirstName
|
||||
}
|
||||
if req.LastName != nil {
|
||||
user.LastName = *req.LastName
|
||||
}
|
||||
if req.Phone != nil {
|
||||
user.Phone = *req.Phone
|
||||
}
|
||||
if req.Avatar != nil {
|
||||
user.Avatar = *req.Avatar
|
||||
}
|
||||
user.UpdatedAt = time.Now()
|
||||
}
|
||||
|
||||
// captureUserValues 捕获用户值用于变更比较
|
||||
func (s *UserService) captureUserValues(user *entities.User) map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"first_name": user.FirstName,
|
||||
"last_name": user.LastName,
|
||||
"phone": user.Phone,
|
||||
"avatar": user.Avatar,
|
||||
}
|
||||
}
|
||||
|
||||
// findChanges 找出变更的字段
|
||||
func (s *UserService) findChanges(oldValues, newValues map[string]interface{}) map[string]interface{} {
|
||||
changes := make(map[string]interface{})
|
||||
|
||||
for key, newValue := range newValues {
|
||||
if oldValue, exists := oldValues[key]; !exists || oldValue != newValue {
|
||||
changes[key] = newValue
|
||||
}
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
|
||||
// getCorrelationID 获取关联ID
|
||||
func (s *UserService) getCorrelationID(ctx context.Context) string {
|
||||
if id := ctx.Value("correlation_id"); id != nil {
|
||||
if correlationID, ok := id.(string); ok {
|
||||
return correlationID
|
||||
}
|
||||
}
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// getClientIP 获取客户端IP
|
||||
func (s *UserService) getClientIP(ctx context.Context) string {
|
||||
if ip := ctx.Value("client_ip"); ip != nil {
|
||||
if clientIP, ok := ip.(string); ok {
|
||||
return clientIP
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// getUserAgent 获取用户代理
|
||||
func (s *UserService) getUserAgent(ctx context.Context) string {
|
||||
if ua := ctx.Value("user_agent"); ua != nil {
|
||||
if userAgent, ok := ua.(string); ok {
|
||||
return userAgent
|
||||
}
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
284
internal/shared/cache/redis_cache.go
vendored
Normal file
284
internal/shared/cache/redis_cache.go
vendored
Normal file
@@ -0,0 +1,284 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// RedisCache Redis缓存实现
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
logger *zap.Logger
|
||||
prefix string
|
||||
|
||||
// 统计信息
|
||||
hits int64
|
||||
misses int64
|
||||
}
|
||||
|
||||
// NewRedisCache 创建Redis缓存实例
|
||||
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
logger: logger,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (r *RedisCache) Name() string {
|
||||
return "redis-cache"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (r *RedisCache) Initialize(ctx context.Context) error {
|
||||
// 测试连接
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to connect to Redis", zap.Error(err))
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Redis cache service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (r *RedisCache) HealthCheck(ctx context.Context) error {
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (r *RedisCache) Shutdown(ctx context.Context) error {
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
val, err := r.client.Get(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
r.misses++
|
||||
return fmt.Errorf("cache miss: key %s not found", key)
|
||||
}
|
||||
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.hits++
|
||||
return json.Unmarshal([]byte(val), dest)
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value: %w", err)
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
case string:
|
||||
expiration, _ = time.ParseDuration(v)
|
||||
default:
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
|
||||
err = r.client.Set(ctx, fullKey, data, expiration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
err := r.client.Del(ctx, fullKeys...).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
count, err := r.client.Exists(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取
|
||||
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
values, err := r.client.MGet(ctx, fullKeys...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for i, val := range values {
|
||||
if val != nil {
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(val.(string)), &data); err == nil {
|
||||
result[keys[i]] = data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置
|
||||
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
default:
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
for key, value := range data {
|
||||
fullKey := r.getFullKey(key)
|
||||
jsonData, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pipe.Set(ctx, fullKey, jsonData, expiration)
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeletePattern 按模式删除
|
||||
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
return r.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys 获取匹配的键
|
||||
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 移除前缀
|
||||
result := make([]string, len(keys))
|
||||
prefixLen := len(r.prefix) + 1 // +1 for ":"
|
||||
for i, key := range keys {
|
||||
if len(key) > prefixLen {
|
||||
result[i] = key[prefixLen:]
|
||||
} else {
|
||||
result[i] = key
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Stats 获取缓存统计
|
||||
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
|
||||
dbSize, _ := r.client.DBSize(ctx).Result()
|
||||
|
||||
return interfaces.CacheStats{
|
||||
Hits: r.hits,
|
||||
Misses: r.misses,
|
||||
Keys: dbSize,
|
||||
Memory: 0, // 暂时设为0,后续可解析Redis info
|
||||
Connections: 0, // 暂时设为0,后续可解析Redis info
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getFullKey 获取完整键名
|
||||
func (r *RedisCache) getFullKey(key string) string {
|
||||
if r.prefix == "" {
|
||||
return key
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", r.prefix, key)
|
||||
}
|
||||
|
||||
// Flush 清空所有缓存
|
||||
func (r *RedisCache) Flush(ctx context.Context) error {
|
||||
if r.prefix == "" {
|
||||
return r.client.FlushDB(ctx).Err()
|
||||
}
|
||||
|
||||
// 只删除带前缀的键
|
||||
return r.DeletePattern(ctx, "*")
|
||||
}
|
||||
|
||||
// GetClient 获取原始Redis客户端
|
||||
func (r *RedisCache) GetClient() *redis.Client {
|
||||
return r.client
|
||||
}
|
||||
195
internal/shared/database/database.go
Normal file
195
internal/shared/database/database.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Config 数据库配置
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
Name string
|
||||
SSLMode string
|
||||
Timezone string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
}
|
||||
|
||||
// DB 数据库包装器
|
||||
type DB struct {
|
||||
*gorm.DB
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewConnection 创建新的数据库连接
|
||||
func NewConnection(config Config) (*DB, error) {
|
||||
// 构建DSN
|
||||
dsn := buildDSN(config)
|
||||
|
||||
// 配置GORM
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
SingularTable: true, // 使用单数表名
|
||||
},
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取底层sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
||||
}
|
||||
|
||||
return &DB{
|
||||
DB: db,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildDSN 构建数据库连接字符串
|
||||
func buildDSN(config Config) string {
|
||||
return fmt.Sprintf(
|
||||
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s",
|
||||
config.Host,
|
||||
config.User,
|
||||
config.Password,
|
||||
config.Name,
|
||||
config.Port,
|
||||
config.SSLMode,
|
||||
config.Timezone,
|
||||
)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
// Ping 检查数据库连接
|
||||
func (db *DB) Ping() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Ping()
|
||||
}
|
||||
|
||||
// GetStats 获取连接池统计信息
|
||||
func (db *DB) GetStats() (map[string]interface{}, error) {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := sqlDB.Stats()
|
||||
return map[string]interface{}{
|
||||
"max_open_connections": stats.MaxOpenConnections,
|
||||
"open_connections": stats.OpenConnections,
|
||||
"in_use": stats.InUse,
|
||||
"idle": stats.Idle,
|
||||
"wait_count": stats.WaitCount,
|
||||
"wait_duration": stats.WaitDuration,
|
||||
"max_idle_closed": stats.MaxIdleClosed,
|
||||
"max_idle_time_closed": stats.MaxIdleTimeClosed,
|
||||
"max_lifetime_closed": stats.MaxLifetimeClosed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BeginTx 开始事务
|
||||
func (db *DB) BeginTx() *gorm.DB {
|
||||
return db.DB.Begin()
|
||||
}
|
||||
|
||||
// Migrate 执行数据库迁移
|
||||
func (db *DB) Migrate(models ...interface{}) error {
|
||||
return db.DB.AutoMigrate(models...)
|
||||
}
|
||||
|
||||
// IsHealthy 检查数据库健康状态
|
||||
func (db *DB) IsHealthy() bool {
|
||||
return db.Ping() == nil
|
||||
}
|
||||
|
||||
// WithContext 返回带上下文的数据库实例
|
||||
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
|
||||
if c, ok := ctx.(context.Context); ok {
|
||||
return db.DB.WithContext(c)
|
||||
}
|
||||
return db.DB
|
||||
}
|
||||
|
||||
// 事务包装器
|
||||
type TxWrapper struct {
|
||||
tx *gorm.DB
|
||||
}
|
||||
|
||||
// NewTxWrapper 创建事务包装器
|
||||
func (db *DB) NewTxWrapper() *TxWrapper {
|
||||
return &TxWrapper{
|
||||
tx: db.BeginTx(),
|
||||
}
|
||||
}
|
||||
|
||||
// Commit 提交事务
|
||||
func (tx *TxWrapper) Commit() error {
|
||||
return tx.tx.Commit().Error
|
||||
}
|
||||
|
||||
// Rollback 回滚事务
|
||||
func (tx *TxWrapper) Rollback() error {
|
||||
return tx.tx.Rollback().Error
|
||||
}
|
||||
|
||||
// GetDB 获取事务数据库实例
|
||||
func (tx *TxWrapper) GetDB() *gorm.DB {
|
||||
return tx.tx
|
||||
}
|
||||
|
||||
// WithTx 在事务中执行函数
|
||||
func (db *DB) WithTx(fn func(*gorm.DB) error) error {
|
||||
tx := db.BeginTx()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
313
internal/shared/events/event_bus.go
Normal file
313
internal/shared/events/event_bus.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// MemoryEventBus 内存事件总线实现
|
||||
type MemoryEventBus struct {
|
||||
subscribers map[string][]interfaces.EventHandler
|
||||
mutex sync.RWMutex
|
||||
logger *zap.Logger
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
eventQueue chan eventTask
|
||||
workerCount int
|
||||
}
|
||||
|
||||
// eventTask 事件任务
|
||||
type eventTask struct {
|
||||
event interfaces.Event
|
||||
handler interfaces.EventHandler
|
||||
retries int
|
||||
}
|
||||
|
||||
// NewMemoryEventBus 创建内存事件总线
|
||||
func NewMemoryEventBus(logger *zap.Logger, workerCount int) *MemoryEventBus {
|
||||
if workerCount <= 0 {
|
||||
workerCount = 5 // 默认5个工作协程
|
||||
}
|
||||
|
||||
return &MemoryEventBus{
|
||||
subscribers: make(map[string][]interfaces.EventHandler),
|
||||
logger: logger,
|
||||
eventQueue: make(chan eventTask, 1000), // 缓冲1000个事件
|
||||
workerCount: workerCount,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (bus *MemoryEventBus) Name() string {
|
||||
return "memory-event-bus"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (bus *MemoryEventBus) Initialize(ctx context.Context) error {
|
||||
bus.logger.Info("Memory event bus service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (bus *MemoryEventBus) HealthCheck(ctx context.Context) error {
|
||||
if !bus.running {
|
||||
return fmt.Errorf("event bus is not running")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (bus *MemoryEventBus) Shutdown(ctx context.Context) error {
|
||||
bus.Stop(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动事件总线
|
||||
func (bus *MemoryEventBus) Start(ctx context.Context) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
if bus.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
bus.running = true
|
||||
|
||||
// 启动工作协程
|
||||
for i := 0; i < bus.workerCount; i++ {
|
||||
go bus.worker(i)
|
||||
}
|
||||
|
||||
bus.logger.Info("Event bus started", zap.Int("workers", bus.workerCount))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止事件总线
|
||||
func (bus *MemoryEventBus) Stop(ctx context.Context) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
if !bus.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
bus.running = false
|
||||
close(bus.stopCh)
|
||||
|
||||
// 等待所有工作协程结束或超时
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
time.Sleep(5 * time.Second) // 给工作协程5秒时间结束
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
bus.logger.Info("Event bus stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish 发布事件(同步)
|
||||
func (bus *MemoryEventBus) Publish(ctx context.Context, event interfaces.Event) error {
|
||||
bus.mutex.RLock()
|
||||
handlers := bus.subscribers[event.GetType()]
|
||||
bus.mutex.RUnlock()
|
||||
|
||||
if len(handlers) == 0 {
|
||||
bus.logger.Debug("No handlers for event type", zap.String("type", event.GetType()))
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, handler := range handlers {
|
||||
if handler.IsAsync() {
|
||||
// 异步处理
|
||||
select {
|
||||
case bus.eventQueue <- eventTask{event: event, handler: handler, retries: 0}:
|
||||
default:
|
||||
bus.logger.Warn("Event queue is full, dropping event",
|
||||
zap.String("type", event.GetType()),
|
||||
zap.String("handler", handler.GetName()))
|
||||
}
|
||||
} else {
|
||||
// 同步处理
|
||||
if err := bus.handleEventWithRetry(ctx, event, handler); err != nil {
|
||||
bus.logger.Error("Failed to handle event synchronously",
|
||||
zap.String("type", event.GetType()),
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishBatch 批量发布事件
|
||||
func (bus *MemoryEventBus) PublishBatch(ctx context.Context, events []interfaces.Event) error {
|
||||
for _, event := range events {
|
||||
if err := bus.Publish(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe 订阅事件
|
||||
func (bus *MemoryEventBus) Subscribe(eventType string, handler interfaces.EventHandler) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
handlers := bus.subscribers[eventType]
|
||||
|
||||
// 检查是否已经订阅
|
||||
for _, h := range handlers {
|
||||
if h.GetName() == handler.GetName() {
|
||||
return fmt.Errorf("handler %s already subscribed to event type %s", handler.GetName(), eventType)
|
||||
}
|
||||
}
|
||||
|
||||
bus.subscribers[eventType] = append(handlers, handler)
|
||||
|
||||
bus.logger.Info("Handler subscribed to event",
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.String("event_type", eventType))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe 取消订阅
|
||||
func (bus *MemoryEventBus) Unsubscribe(eventType string, handler interfaces.EventHandler) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
handlers := bus.subscribers[eventType]
|
||||
for i, h := range handlers {
|
||||
if h.GetName() == handler.GetName() {
|
||||
// 删除处理器
|
||||
bus.subscribers[eventType] = append(handlers[:i], handlers[i+1:]...)
|
||||
|
||||
bus.logger.Info("Handler unsubscribed from event",
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.String("event_type", eventType))
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("handler %s not found for event type %s", handler.GetName(), eventType)
|
||||
}
|
||||
|
||||
// GetSubscribers 获取订阅者
|
||||
func (bus *MemoryEventBus) GetSubscribers(eventType string) []interfaces.EventHandler {
|
||||
bus.mutex.RLock()
|
||||
defer bus.mutex.RUnlock()
|
||||
|
||||
handlers := bus.subscribers[eventType]
|
||||
result := make([]interfaces.EventHandler, len(handlers))
|
||||
copy(result, handlers)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
func (bus *MemoryEventBus) worker(id int) {
|
||||
bus.logger.Debug("Event worker started", zap.Int("worker_id", id))
|
||||
|
||||
for {
|
||||
select {
|
||||
case task := <-bus.eventQueue:
|
||||
bus.processEventTask(task)
|
||||
case <-bus.stopCh:
|
||||
bus.logger.Debug("Event worker stopped", zap.Int("worker_id", id))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processEventTask 处理事件任务
|
||||
func (bus *MemoryEventBus) processEventTask(task eventTask) {
|
||||
ctx := context.Background()
|
||||
|
||||
err := bus.handleEventWithRetry(ctx, task.event, task.handler)
|
||||
if err != nil {
|
||||
retryConfig := task.handler.GetRetryConfig()
|
||||
|
||||
if task.retries < retryConfig.MaxRetries {
|
||||
// 重试
|
||||
delay := time.Duration(float64(retryConfig.RetryDelay) *
|
||||
(1 + retryConfig.BackoffFactor*float64(task.retries)))
|
||||
|
||||
if delay > retryConfig.MaxDelay {
|
||||
delay = retryConfig.MaxDelay
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(delay)
|
||||
task.retries++
|
||||
|
||||
select {
|
||||
case bus.eventQueue <- task:
|
||||
default:
|
||||
bus.logger.Error("Failed to requeue event for retry",
|
||||
zap.String("type", task.event.GetType()),
|
||||
zap.String("handler", task.handler.GetName()),
|
||||
zap.Int("retries", task.retries))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
bus.logger.Error("Event processing failed after max retries",
|
||||
zap.String("type", task.event.GetType()),
|
||||
zap.String("handler", task.handler.GetName()),
|
||||
zap.Int("retries", task.retries),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleEventWithRetry 处理事件并支持重试
|
||||
func (bus *MemoryEventBus) handleEventWithRetry(ctx context.Context, event interfaces.Event, handler interfaces.EventHandler) error {
|
||||
start := time.Now()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(start)
|
||||
bus.logger.Debug("Event handled",
|
||||
zap.String("type", event.GetType()),
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.Duration("duration", duration))
|
||||
}()
|
||||
|
||||
return handler.Handle(ctx, event)
|
||||
}
|
||||
|
||||
// GetStats 获取事件总线统计信息
|
||||
func (bus *MemoryEventBus) GetStats() map[string]interface{} {
|
||||
bus.mutex.RLock()
|
||||
defer bus.mutex.RUnlock()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"running": bus.running,
|
||||
"worker_count": bus.workerCount,
|
||||
"queue_length": len(bus.eventQueue),
|
||||
"queue_capacity": cap(bus.eventQueue),
|
||||
"event_types": len(bus.subscribers),
|
||||
}
|
||||
|
||||
// 各事件类型的订阅者数量
|
||||
eventTypes := make(map[string]int)
|
||||
for eventType, handlers := range bus.subscribers {
|
||||
eventTypes[eventType] = len(handlers)
|
||||
}
|
||||
stats["subscribers"] = eventTypes
|
||||
|
||||
return stats
|
||||
}
|
||||
282
internal/shared/health/health_checker.go
Normal file
282
internal/shared/health/health_checker.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HealthChecker 健康检查器实现
|
||||
type HealthChecker struct {
|
||||
services map[string]interfaces.Service
|
||||
cache map[string]*interfaces.HealthStatus
|
||||
cacheTTL time.Duration
|
||||
mutex sync.RWMutex
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewHealthChecker 创建健康检查器
|
||||
func NewHealthChecker(logger *zap.Logger) *HealthChecker {
|
||||
return &HealthChecker{
|
||||
services: make(map[string]interfaces.Service),
|
||||
cache: make(map[string]*interfaces.HealthStatus),
|
||||
cacheTTL: 30 * time.Second, // 缓存30秒
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterService 注册服务
|
||||
func (h *HealthChecker) RegisterService(service interfaces.Service) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.services[service.Name()] = service
|
||||
h.logger.Info("Registered service for health check", zap.String("service", service.Name()))
|
||||
}
|
||||
|
||||
// CheckHealth 检查单个服务健康状态
|
||||
func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *interfaces.HealthStatus {
|
||||
h.mutex.RLock()
|
||||
service, exists := h.services[serviceName]
|
||||
if !exists {
|
||||
h.mutex.RUnlock()
|
||||
return &interfaces.HealthStatus{
|
||||
Status: "DOWN",
|
||||
Message: "Service not found",
|
||||
Details: map[string]interface{}{"error": "service not registered"},
|
||||
CheckedAt: time.Now().Unix(),
|
||||
ResponseTime: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// 检查缓存
|
||||
if cached, exists := h.cache[serviceName]; exists {
|
||||
if time.Since(time.Unix(cached.CheckedAt, 0)) < h.cacheTTL {
|
||||
h.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
}
|
||||
h.mutex.RUnlock()
|
||||
|
||||
// 执行健康检查
|
||||
start := time.Now()
|
||||
status := &interfaces.HealthStatus{
|
||||
CheckedAt: start.Unix(),
|
||||
}
|
||||
|
||||
// 设置超时上下文
|
||||
checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := service.HealthCheck(checkCtx)
|
||||
responseTime := time.Since(start).Milliseconds()
|
||||
status.ResponseTime = responseTime
|
||||
|
||||
if err != nil {
|
||||
status.Status = "DOWN"
|
||||
status.Message = "Health check failed"
|
||||
status.Details = map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"service_name": serviceName,
|
||||
"check_time": start.Format(time.RFC3339),
|
||||
}
|
||||
h.logger.Warn("Service health check failed",
|
||||
zap.String("service", serviceName),
|
||||
zap.Error(err),
|
||||
zap.Int64("response_time_ms", responseTime))
|
||||
} else {
|
||||
status.Status = "UP"
|
||||
status.Message = "Service is healthy"
|
||||
status.Details = map[string]interface{}{
|
||||
"service_name": serviceName,
|
||||
"check_time": start.Format(time.RFC3339),
|
||||
}
|
||||
h.logger.Debug("Service health check passed",
|
||||
zap.String("service", serviceName),
|
||||
zap.Int64("response_time_ms", responseTime))
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
h.mutex.Lock()
|
||||
h.cache[serviceName] = status
|
||||
h.mutex.Unlock()
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// CheckAllHealth 检查所有服务的健康状态
|
||||
func (h *HealthChecker) CheckAllHealth(ctx context.Context) map[string]*interfaces.HealthStatus {
|
||||
h.mutex.RLock()
|
||||
serviceNames := make([]string, 0, len(h.services))
|
||||
for name := range h.services {
|
||||
serviceNames = append(serviceNames, name)
|
||||
}
|
||||
h.mutex.RUnlock()
|
||||
|
||||
results := make(map[string]*interfaces.HealthStatus)
|
||||
var wg sync.WaitGroup
|
||||
var mutex sync.Mutex
|
||||
|
||||
// 并发检查所有服务
|
||||
for _, serviceName := range serviceNames {
|
||||
wg.Add(1)
|
||||
go func(name string) {
|
||||
defer wg.Done()
|
||||
status := h.CheckHealth(ctx, name)
|
||||
|
||||
mutex.Lock()
|
||||
results[name] = status
|
||||
mutex.Unlock()
|
||||
}(serviceName)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
// GetOverallStatus 获取整体健康状态
|
||||
func (h *HealthChecker) GetOverallStatus(ctx context.Context) *interfaces.HealthStatus {
|
||||
allStatus := h.CheckAllHealth(ctx)
|
||||
|
||||
overall := &interfaces.HealthStatus{
|
||||
CheckedAt: time.Now().Unix(),
|
||||
ResponseTime: 0,
|
||||
Details: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
var totalResponseTime int64
|
||||
healthyCount := 0
|
||||
totalCount := len(allStatus)
|
||||
|
||||
for serviceName, status := range allStatus {
|
||||
overall.Details[serviceName] = map[string]interface{}{
|
||||
"status": status.Status,
|
||||
"message": status.Message,
|
||||
"response_time": status.ResponseTime,
|
||||
}
|
||||
|
||||
totalResponseTime += status.ResponseTime
|
||||
if status.Status == "UP" {
|
||||
healthyCount++
|
||||
}
|
||||
}
|
||||
|
||||
if totalCount > 0 {
|
||||
overall.ResponseTime = totalResponseTime / int64(totalCount)
|
||||
}
|
||||
|
||||
// 确定整体状态
|
||||
if healthyCount == totalCount {
|
||||
overall.Status = "UP"
|
||||
overall.Message = "All services are healthy"
|
||||
} else if healthyCount == 0 {
|
||||
overall.Status = "DOWN"
|
||||
overall.Message = "All services are down"
|
||||
} else {
|
||||
overall.Status = "DEGRADED"
|
||||
overall.Message = fmt.Sprintf("%d of %d services are healthy", healthyCount, totalCount)
|
||||
}
|
||||
|
||||
return overall
|
||||
}
|
||||
|
||||
// GetServiceNames 获取所有注册的服务名称
|
||||
func (h *HealthChecker) GetServiceNames() []string {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(h.services))
|
||||
for name := range h.services {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// RemoveService 移除服务
|
||||
func (h *HealthChecker) RemoveService(serviceName string) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
delete(h.services, serviceName)
|
||||
delete(h.cache, serviceName)
|
||||
|
||||
h.logger.Info("Removed service from health check", zap.String("service", serviceName))
|
||||
}
|
||||
|
||||
// ClearCache 清除缓存
|
||||
func (h *HealthChecker) ClearCache() {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.cache = make(map[string]*interfaces.HealthStatus)
|
||||
h.logger.Debug("Health check cache cleared")
|
||||
}
|
||||
|
||||
// GetCacheStats 获取缓存统计
|
||||
func (h *HealthChecker) GetCacheStats() map[string]interface{} {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_services": len(h.services),
|
||||
"cached_results": len(h.cache),
|
||||
"cache_ttl_seconds": h.cacheTTL.Seconds(),
|
||||
}
|
||||
|
||||
// 计算缓存命中率
|
||||
if len(h.services) > 0 {
|
||||
hitRate := float64(len(h.cache)) / float64(len(h.services)) * 100
|
||||
stats["cache_hit_rate"] = fmt.Sprintf("%.2f%%", hitRate)
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// SetCacheTTL 设置缓存TTL
|
||||
func (h *HealthChecker) SetCacheTTL(ttl time.Duration) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.cacheTTL = ttl
|
||||
h.logger.Info("Updated health check cache TTL", zap.Duration("ttl", ttl))
|
||||
}
|
||||
|
||||
// StartPeriodicCheck 启动定期健康检查
|
||||
func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
h.logger.Info("Started periodic health check", zap.Duration("interval", interval))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.logger.Info("Stopped periodic health check")
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.performPeriodicCheck(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performPeriodicCheck 执行定期检查
|
||||
func (h *HealthChecker) performPeriodicCheck(ctx context.Context) {
|
||||
overall := h.GetOverallStatus(ctx)
|
||||
|
||||
h.logger.Info("Periodic health check completed",
|
||||
zap.String("overall_status", overall.Status),
|
||||
zap.String("message", overall.Message),
|
||||
zap.Int64("response_time_ms", overall.ResponseTime))
|
||||
|
||||
// 如果有服务下线,记录警告
|
||||
if overall.Status != "UP" {
|
||||
h.logger.Warn("Some services are not healthy",
|
||||
zap.String("status", overall.Status),
|
||||
zap.Any("details", overall.Details))
|
||||
}
|
||||
}
|
||||
260
internal/shared/http/response.go
Normal file
260
internal/shared/http/response.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ResponseBuilder 响应构建器实现
|
||||
type ResponseBuilder struct{}
|
||||
|
||||
// NewResponseBuilder 创建响应构建器
|
||||
func NewResponseBuilder() interfaces.ResponseBuilder {
|
||||
return &ResponseBuilder{}
|
||||
}
|
||||
|
||||
// Success 成功响应
|
||||
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Success"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Created 创建成功响应
|
||||
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Created successfully"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
|
||||
// 根据错误类型确定状态码
|
||||
statusCode := http.StatusInternalServerError
|
||||
message := "Internal server error"
|
||||
errorDetail := err.Error()
|
||||
|
||||
// 这里可以根据不同的错误类型设置不同的状态码
|
||||
// 例如:ValidationError -> 400, NotFoundError -> 404, etc.
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
Errors: errorDetail,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// BadRequest 400错误响应
|
||||
func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
response.Errors = errors[0]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Unauthorized 401错误响应
|
||||
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
|
||||
msg := "Unauthorized"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, response)
|
||||
}
|
||||
|
||||
// Forbidden 403错误响应
|
||||
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
|
||||
msg := "Forbidden"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusForbidden, response)
|
||||
}
|
||||
|
||||
// NotFound 404错误响应
|
||||
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
|
||||
msg := "Resource not found"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Conflict 409错误响应
|
||||
func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusConflict, response)
|
||||
}
|
||||
|
||||
// InternalError 500错误响应
|
||||
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
|
||||
msg := "Internal server error"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Paginated 分页响应
|
||||
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: "Success",
|
||||
Data: data,
|
||||
Pagination: &pagination,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// getRequestID 从上下文获取请求ID
|
||||
func (r *ResponseBuilder) getRequestID(c *gin.Context) string {
|
||||
if requestID, exists := c.Get("request_id"); exists {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildPagination 构建分页元数据
|
||||
func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
|
||||
totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
return interfaces.PaginationMeta{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
HasNext: page < totalPages,
|
||||
HasPrev: page > 1,
|
||||
}
|
||||
}
|
||||
|
||||
// CustomResponse 自定义响应
|
||||
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: statusCode >= 200 && statusCode < 300,
|
||||
Message: http.StatusText(statusCode),
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// ValidationError 验证错误响应
|
||||
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: "Validation failed",
|
||||
Errors: errors,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnprocessableEntity, response)
|
||||
}
|
||||
|
||||
// TooManyRequests 限流错误响应
|
||||
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
|
||||
msg := "Too many requests"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Meta: map[string]interface{}{
|
||||
"retry_after": "60s",
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, response)
|
||||
}
|
||||
258
internal/shared/http/router.go
Normal file
258
internal/shared/http/router.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GinRouter Gin路由器实现
|
||||
type GinRouter struct {
|
||||
engine *gin.Engine
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
middlewares []interfaces.Middleware
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// NewGinRouter 创建Gin路由器
|
||||
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *GinRouter {
|
||||
// 设置Gin模式
|
||||
if cfg.App.IsProduction() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
} else {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
}
|
||||
|
||||
// 创建Gin引擎
|
||||
engine := gin.New()
|
||||
|
||||
return &GinRouter{
|
||||
engine: engine,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
middlewares: make([]interfaces.Middleware, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册处理器
|
||||
func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
|
||||
// 应用处理器中间件
|
||||
middlewares := handler.GetMiddlewares()
|
||||
|
||||
// 注册路由
|
||||
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
|
||||
|
||||
r.logger.Info("Registered HTTP handler",
|
||||
zap.String("method", handler.GetMethod()),
|
||||
zap.String("path", handler.GetPath()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterMiddleware 注册中间件
|
||||
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
|
||||
r.middlewares = append(r.middlewares, middleware)
|
||||
|
||||
r.logger.Info("Registered middleware",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterGroup 注册路由组
|
||||
func (r *GinRouter) RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes {
|
||||
return r.engine.Group(prefix, middlewares...)
|
||||
}
|
||||
|
||||
// GetRoutes 获取路由信息
|
||||
func (r *GinRouter) GetRoutes() gin.RoutesInfo {
|
||||
return r.engine.Routes()
|
||||
}
|
||||
|
||||
// Start 启动路由器
|
||||
func (r *GinRouter) Start(addr string) error {
|
||||
// 应用中间件(按优先级排序)
|
||||
r.applyMiddlewares()
|
||||
|
||||
// 创建HTTP服务器
|
||||
r.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: r.engine,
|
||||
ReadTimeout: r.config.Server.ReadTimeout,
|
||||
WriteTimeout: r.config.Server.WriteTimeout,
|
||||
IdleTimeout: r.config.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
|
||||
|
||||
// 启动服务器
|
||||
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止路由器
|
||||
func (r *GinRouter) Stop(ctx context.Context) error {
|
||||
if r.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Info("Stopping HTTP server...")
|
||||
|
||||
// 优雅关闭服务器
|
||||
if err := r.server.Shutdown(ctx); err != nil {
|
||||
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("HTTP server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEngine 获取Gin引擎
|
||||
func (r *GinRouter) GetEngine() *gin.Engine {
|
||||
return r.engine
|
||||
}
|
||||
|
||||
// applyMiddlewares 应用中间件
|
||||
func (r *GinRouter) applyMiddlewares() {
|
||||
// 按优先级排序中间件
|
||||
sort.Slice(r.middlewares, func(i, j int) bool {
|
||||
return r.middlewares[i].GetPriority() > r.middlewares[j].GetPriority()
|
||||
})
|
||||
|
||||
// 应用全局中间件
|
||||
for _, middleware := range r.middlewares {
|
||||
if middleware.IsGlobal() {
|
||||
r.engine.Use(middleware.Handle())
|
||||
r.logger.Debug("Applied global middleware",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupDefaultRoutes 设置默认路由
|
||||
func (r *GinRouter) SetupDefaultRoutes() {
|
||||
// 健康检查
|
||||
r.engine.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"service": r.config.App.Name,
|
||||
"version": r.config.App.Version,
|
||||
})
|
||||
})
|
||||
|
||||
// API信息
|
||||
r.engine.GET("/info", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": r.config.App.Name,
|
||||
"version": r.config.App.Version,
|
||||
"environment": r.config.App.Env,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// 404处理
|
||||
r.engine.NoRoute(func(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Route not found",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// 405处理
|
||||
r.engine.NoMethod(func(c *gin.Context) {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"success": false,
|
||||
"message": "Method not allowed",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// PrintRoutes 打印路由信息
|
||||
func (r *GinRouter) PrintRoutes() {
|
||||
routes := r.GetRoutes()
|
||||
|
||||
r.logger.Info("Registered routes:")
|
||||
for _, route := range routes {
|
||||
r.logger.Info("Route",
|
||||
zap.String("method", route.Method),
|
||||
zap.String("path", route.Path),
|
||||
zap.String("handler", route.Handler))
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取路由器统计信息
|
||||
func (r *GinRouter) GetStats() map[string]interface{} {
|
||||
routes := r.GetRoutes()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_routes": len(routes),
|
||||
"total_middlewares": len(r.middlewares),
|
||||
"server_config": map[string]interface{}{
|
||||
"read_timeout": r.config.Server.ReadTimeout,
|
||||
"write_timeout": r.config.Server.WriteTimeout,
|
||||
"idle_timeout": r.config.Server.IdleTimeout,
|
||||
},
|
||||
}
|
||||
|
||||
// 按方法统计路由数量
|
||||
methodStats := make(map[string]int)
|
||||
for _, route := range routes {
|
||||
methodStats[route.Method]++
|
||||
}
|
||||
stats["routes_by_method"] = methodStats
|
||||
|
||||
// 中间件统计
|
||||
middlewareStats := make([]map[string]interface{}, 0, len(r.middlewares))
|
||||
for _, middleware := range r.middlewares {
|
||||
middlewareStats = append(middlewareStats, map[string]interface{}{
|
||||
"name": middleware.GetName(),
|
||||
"priority": middleware.GetPriority(),
|
||||
"global": middleware.IsGlobal(),
|
||||
})
|
||||
}
|
||||
stats["middlewares"] = middlewareStats
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// EnableMetrics 启用指标收集
|
||||
func (r *GinRouter) EnableMetrics(collector interfaces.MetricsCollector) {
|
||||
r.engine.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
duration := time.Since(start).Seconds()
|
||||
collector.RecordHTTPRequest(c.Request.Method, c.FullPath(), c.Writer.Status(), duration)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableProfiling 启用性能分析
|
||||
func (r *GinRouter) EnableProfiling() {
|
||||
if r.config.Development.EnableProfiler {
|
||||
// 这里可以集成pprof
|
||||
r.logger.Info("Profiling enabled")
|
||||
}
|
||||
}
|
||||
273
internal/shared/http/validator.go
Normal file
273
internal/shared/http/validator.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// RequestValidator 请求验证器实现
|
||||
type RequestValidator struct {
|
||||
validator *validator.Validate
|
||||
response interfaces.ResponseBuilder
|
||||
}
|
||||
|
||||
// NewRequestValidator 创建请求验证器
|
||||
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
|
||||
v := validator.New()
|
||||
|
||||
// 注册自定义验证器
|
||||
registerCustomValidators(v)
|
||||
|
||||
return &RequestValidator{
|
||||
validator: v,
|
||||
response: response,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate 验证请求体
|
||||
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateQuery 验证查询参数
|
||||
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindQuery(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid query parameters", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateParam 验证路径参数
|
||||
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindUri(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid path parameters", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindAndValidate 绑定并验证请求
|
||||
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
|
||||
// 绑定请求体
|
||||
if err := c.ShouldBindJSON(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid request body", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证数据
|
||||
return v.Validate(c, dto)
|
||||
}
|
||||
|
||||
// formatValidationErrors 格式化验证错误
|
||||
func (v *RequestValidator) formatValidationErrors(err error) map[string][]string {
|
||||
errors := make(map[string][]string)
|
||||
|
||||
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
||||
for _, fieldError := range validationErrors {
|
||||
fieldName := v.getFieldName(fieldError)
|
||||
errorMessage := v.getErrorMessage(fieldError)
|
||||
|
||||
if _, exists := errors[fieldName]; !exists {
|
||||
errors[fieldName] = []string{}
|
||||
}
|
||||
errors[fieldName] = append(errors[fieldName], errorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// getFieldName 获取字段名(JSON标签优先)
|
||||
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
|
||||
// 可以通过反射获取JSON标签,这里简化处理
|
||||
fieldName := fieldError.Field()
|
||||
|
||||
// 转换为snake_case(可选)
|
||||
return v.toSnakeCase(fieldName)
|
||||
}
|
||||
|
||||
// getErrorMessage 获取错误消息
|
||||
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
|
||||
field := fieldError.Field()
|
||||
tag := fieldError.Tag()
|
||||
param := fieldError.Param()
|
||||
|
||||
switch tag {
|
||||
case "required":
|
||||
return fmt.Sprintf("%s is required", field)
|
||||
case "email":
|
||||
return fmt.Sprintf("%s must be a valid email address", field)
|
||||
case "min":
|
||||
return fmt.Sprintf("%s must be at least %s characters", field, param)
|
||||
case "max":
|
||||
return fmt.Sprintf("%s must be at most %s characters", field, param)
|
||||
case "len":
|
||||
return fmt.Sprintf("%s must be exactly %s characters", field, param)
|
||||
case "gt":
|
||||
return fmt.Sprintf("%s must be greater than %s", field, param)
|
||||
case "gte":
|
||||
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
|
||||
case "lt":
|
||||
return fmt.Sprintf("%s must be less than %s", field, param)
|
||||
case "lte":
|
||||
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
|
||||
case "oneof":
|
||||
return fmt.Sprintf("%s must be one of [%s]", field, param)
|
||||
case "url":
|
||||
return fmt.Sprintf("%s must be a valid URL", field)
|
||||
case "alpha":
|
||||
return fmt.Sprintf("%s must contain only alphabetic characters", field)
|
||||
case "alphanum":
|
||||
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
|
||||
case "numeric":
|
||||
return fmt.Sprintf("%s must be numeric", field)
|
||||
case "phone":
|
||||
return fmt.Sprintf("%s must be a valid phone number", field)
|
||||
case "username":
|
||||
return fmt.Sprintf("%s must be a valid username", field)
|
||||
default:
|
||||
return fmt.Sprintf("%s is invalid", field)
|
||||
}
|
||||
}
|
||||
|
||||
// toSnakeCase 转换为snake_case
|
||||
func (v *RequestValidator) toSnakeCase(str string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range str {
|
||||
if i > 0 && (r >= 'A' && r <= 'Z') {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// registerCustomValidators 注册自定义验证器
|
||||
func registerCustomValidators(v *validator.Validate) {
|
||||
// 注册手机号验证器
|
||||
v.RegisterValidation("phone", validatePhone)
|
||||
|
||||
// 注册用户名验证器
|
||||
v.RegisterValidation("username", validateUsername)
|
||||
|
||||
// 注册密码强度验证器
|
||||
v.RegisterValidation("strong_password", validateStrongPassword)
|
||||
}
|
||||
|
||||
// validatePhone 验证手机号
|
||||
func validatePhone(fl validator.FieldLevel) bool {
|
||||
phone := fl.Field().String()
|
||||
if phone == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 简单的手机号验证(可根据需要完善)
|
||||
if len(phone) < 10 || len(phone) > 15 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否以+开头或全是数字
|
||||
if strings.HasPrefix(phone, "+") {
|
||||
phone = phone[1:]
|
||||
}
|
||||
|
||||
for _, r := range phone {
|
||||
if r < '0' || r > '9' {
|
||||
if r != '-' && r != ' ' && r != '(' && r != ')' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateUsername 验证用户名
|
||||
func validateUsername(fl validator.FieldLevel) bool {
|
||||
username := fl.Field().String()
|
||||
if username == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 用户名规则:3-30个字符,只能包含字母、数字、下划线,不能以数字开头
|
||||
if len(username) < 3 || len(username) > 30 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不能以数字开头
|
||||
if username[0] >= '0' && username[0] <= '9' {
|
||||
return false
|
||||
}
|
||||
|
||||
// 只能包含字母、数字、下划线
|
||||
for _, r := range username {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateStrongPassword 验证密码强度
|
||||
func validateStrongPassword(fl validator.FieldLevel) bool {
|
||||
password := fl.Field().String()
|
||||
if password == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 密码强度规则:至少8个字符,包含大小写字母、数字
|
||||
if len(password) < 8 {
|
||||
return false
|
||||
}
|
||||
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
|
||||
for _, r := range password {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
hasUpper = true
|
||||
case r >= 'a' && r <= 'z':
|
||||
hasLower = true
|
||||
case r >= '0' && r <= '9':
|
||||
hasDigit = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasUpper && hasLower && hasDigit
|
||||
}
|
||||
|
||||
// ValidateStruct 直接验证结构体(不通过HTTP上下文)
|
||||
func (v *RequestValidator) ValidateStruct(dto interface{}) error {
|
||||
return v.validator.Struct(dto)
|
||||
}
|
||||
|
||||
// GetValidator 获取原始验证器(用于特殊情况)
|
||||
func (v *RequestValidator) GetValidator() *validator.Validate {
|
||||
return v.validator
|
||||
}
|
||||
92
internal/shared/interfaces/event.go
Normal file
92
internal/shared/interfaces/event.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Event 事件接口
|
||||
type Event interface {
|
||||
// 事件基础信息
|
||||
GetID() string
|
||||
GetType() string
|
||||
GetVersion() string
|
||||
GetTimestamp() time.Time
|
||||
|
||||
// 事件数据
|
||||
GetPayload() interface{}
|
||||
GetMetadata() map[string]interface{}
|
||||
|
||||
// 事件来源
|
||||
GetSource() string
|
||||
GetAggregateID() string
|
||||
GetAggregateType() string
|
||||
|
||||
// 序列化
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) error
|
||||
}
|
||||
|
||||
// EventHandler 事件处理器接口
|
||||
type EventHandler interface {
|
||||
// 处理器标识
|
||||
GetName() string
|
||||
GetEventTypes() []string
|
||||
|
||||
// 事件处理
|
||||
Handle(ctx context.Context, event Event) error
|
||||
|
||||
// 处理器配置
|
||||
IsAsync() bool
|
||||
GetRetryConfig() RetryConfig
|
||||
}
|
||||
|
||||
// DomainEvent 领域事件基础接口
|
||||
type DomainEvent interface {
|
||||
Event
|
||||
|
||||
// 领域特定信息
|
||||
GetDomainVersion() string
|
||||
GetCausationID() string
|
||||
GetCorrelationID() string
|
||||
}
|
||||
|
||||
// RetryConfig 重试配置
|
||||
type RetryConfig struct {
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryDelay time.Duration `json:"retry_delay"`
|
||||
BackoffFactor float64 `json:"backoff_factor"`
|
||||
MaxDelay time.Duration `json:"max_delay"`
|
||||
}
|
||||
|
||||
// EventStore 事件存储接口
|
||||
type EventStore interface {
|
||||
// 事件存储
|
||||
SaveEvent(ctx context.Context, event Event) error
|
||||
SaveEvents(ctx context.Context, events []Event) error
|
||||
|
||||
// 事件查询
|
||||
GetEvents(ctx context.Context, aggregateID string, fromVersion int) ([]Event, error)
|
||||
GetEventsByType(ctx context.Context, eventType string, limit int) ([]Event, error)
|
||||
GetEventsSince(ctx context.Context, timestamp time.Time, limit int) ([]Event, error)
|
||||
|
||||
// 快照支持
|
||||
SaveSnapshot(ctx context.Context, aggregateID string, snapshot interface{}) error
|
||||
GetSnapshot(ctx context.Context, aggregateID string) (interface{}, error)
|
||||
}
|
||||
|
||||
// EventBus 事件总线接口
|
||||
type EventBus interface {
|
||||
// 事件发布
|
||||
Publish(ctx context.Context, event Event) error
|
||||
PublishBatch(ctx context.Context, events []Event) error
|
||||
|
||||
// 事件订阅
|
||||
Subscribe(eventType string, handler EventHandler) error
|
||||
Unsubscribe(eventType string, handler EventHandler) error
|
||||
|
||||
// 订阅管理
|
||||
GetSubscribers(eventType string) []EventHandler
|
||||
Start(ctx context.Context) error
|
||||
Stop(ctx context.Context) error
|
||||
}
|
||||
152
internal/shared/interfaces/http.go
Normal file
152
internal/shared/interfaces/http.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HTTPHandler HTTP处理器接口
|
||||
type HTTPHandler interface {
|
||||
// 处理器信息
|
||||
GetPath() string
|
||||
GetMethod() string
|
||||
GetMiddlewares() []gin.HandlerFunc
|
||||
|
||||
// 处理函数
|
||||
Handle(c *gin.Context)
|
||||
|
||||
// 权限验证
|
||||
RequiresAuth() bool
|
||||
GetPermissions() []string
|
||||
}
|
||||
|
||||
// RESTHandler REST风格处理器接口
|
||||
type RESTHandler interface {
|
||||
HTTPHandler
|
||||
|
||||
// CRUD操作
|
||||
Create(c *gin.Context)
|
||||
GetByID(c *gin.Context)
|
||||
Update(c *gin.Context)
|
||||
Delete(c *gin.Context)
|
||||
List(c *gin.Context)
|
||||
}
|
||||
|
||||
// Middleware 中间件接口
|
||||
type Middleware interface {
|
||||
// 中间件名称
|
||||
GetName() string
|
||||
// 中间件优先级
|
||||
GetPriority() int
|
||||
// 中间件处理函数
|
||||
Handle() gin.HandlerFunc
|
||||
// 是否全局中间件
|
||||
IsGlobal() bool
|
||||
}
|
||||
|
||||
// Router 路由器接口
|
||||
type Router interface {
|
||||
// 路由注册
|
||||
RegisterHandler(handler HTTPHandler) error
|
||||
RegisterMiddleware(middleware Middleware) error
|
||||
RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes
|
||||
|
||||
// 路由管理
|
||||
GetRoutes() gin.RoutesInfo
|
||||
Start(addr string) error
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// 引擎获取
|
||||
GetEngine() *gin.Engine
|
||||
}
|
||||
|
||||
// ResponseBuilder 响应构建器接口
|
||||
type ResponseBuilder interface {
|
||||
// 成功响应
|
||||
Success(c *gin.Context, data interface{}, message ...string)
|
||||
Created(c *gin.Context, data interface{}, message ...string)
|
||||
|
||||
// 错误响应
|
||||
Error(c *gin.Context, err error)
|
||||
BadRequest(c *gin.Context, message string, errors ...interface{})
|
||||
Unauthorized(c *gin.Context, message ...string)
|
||||
Forbidden(c *gin.Context, message ...string)
|
||||
NotFound(c *gin.Context, message ...string)
|
||||
Conflict(c *gin.Context, message string)
|
||||
InternalError(c *gin.Context, message ...string)
|
||||
|
||||
// 分页响应
|
||||
Paginated(c *gin.Context, data interface{}, pagination PaginationMeta)
|
||||
}
|
||||
|
||||
// RequestValidator 请求验证器接口
|
||||
type RequestValidator interface {
|
||||
// 验证请求
|
||||
Validate(c *gin.Context, dto interface{}) error
|
||||
ValidateQuery(c *gin.Context, dto interface{}) error
|
||||
ValidateParam(c *gin.Context, dto interface{}) error
|
||||
|
||||
// 绑定和验证
|
||||
BindAndValidate(c *gin.Context, dto interface{}) error
|
||||
}
|
||||
|
||||
// PaginationMeta 分页元数据
|
||||
type PaginationMeta struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Total int64 `json:"total"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
HasNext bool `json:"has_next"`
|
||||
HasPrev bool `json:"has_prev"`
|
||||
}
|
||||
|
||||
// APIResponse 标准API响应结构
|
||||
type APIResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Errors interface{} `json:"errors,omitempty"`
|
||||
Pagination *PaginationMeta `json:"pagination,omitempty"`
|
||||
Meta map[string]interface{} `json:"meta,omitempty"`
|
||||
RequestID string `json:"request_id"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// HealthChecker 健康检查器接口
|
||||
type HealthChecker interface {
|
||||
// 健康检查
|
||||
CheckHealth(ctx context.Context) HealthStatus
|
||||
GetName() string
|
||||
GetDependencies() []string
|
||||
}
|
||||
|
||||
// HealthStatus 健康状态
|
||||
type HealthStatus struct {
|
||||
Status string `json:"status"` // UP, DOWN, DEGRADED
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
CheckedAt int64 `json:"checked_at"`
|
||||
ResponseTime int64 `json:"response_time_ms"`
|
||||
}
|
||||
|
||||
// MetricsCollector 指标收集器接口
|
||||
type MetricsCollector interface {
|
||||
// HTTP指标
|
||||
RecordHTTPRequest(method, path string, status int, duration float64)
|
||||
RecordHTTPDuration(method, path string, duration float64)
|
||||
|
||||
// 业务指标
|
||||
IncrementCounter(name string, labels map[string]string)
|
||||
RecordGauge(name string, value float64, labels map[string]string)
|
||||
RecordHistogram(name string, value float64, labels map[string]string)
|
||||
|
||||
// 自定义指标
|
||||
RegisterCounter(name, help string, labels []string) error
|
||||
RegisterGauge(name, help string, labels []string) error
|
||||
RegisterHistogram(name, help string, labels []string, buckets []float64) error
|
||||
|
||||
// 指标导出
|
||||
GetHandler() http.Handler
|
||||
}
|
||||
74
internal/shared/interfaces/repository.go
Normal file
74
internal/shared/interfaces/repository.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Entity 通用实体接口
|
||||
type Entity interface {
|
||||
GetID() string
|
||||
GetCreatedAt() time.Time
|
||||
GetUpdatedAt() time.Time
|
||||
}
|
||||
|
||||
// BaseRepository 基础仓储接口
|
||||
type BaseRepository interface {
|
||||
// 基础操作
|
||||
Delete(ctx context.Context, id string) error
|
||||
Count(ctx context.Context, options CountOptions) (int64, error)
|
||||
Exists(ctx context.Context, id string) (bool, error)
|
||||
|
||||
// 软删除支持
|
||||
SoftDelete(ctx context.Context, id string) error
|
||||
Restore(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// Repository 通用仓储接口,支持泛型
|
||||
type Repository[T any] interface {
|
||||
BaseRepository
|
||||
|
||||
// 基础CRUD操作
|
||||
Create(ctx context.Context, entity T) error
|
||||
GetByID(ctx context.Context, id string) (T, error)
|
||||
Update(ctx context.Context, entity T) error
|
||||
|
||||
// 批量操作
|
||||
CreateBatch(ctx context.Context, entities []T) error
|
||||
GetByIDs(ctx context.Context, ids []string) ([]T, error)
|
||||
UpdateBatch(ctx context.Context, entities []T) error
|
||||
DeleteBatch(ctx context.Context, ids []string) error
|
||||
|
||||
// 查询操作
|
||||
List(ctx context.Context, options ListOptions) ([]T, error)
|
||||
|
||||
// 事务支持
|
||||
WithTx(tx interface{}) Repository[T]
|
||||
}
|
||||
|
||||
// ListOptions 列表查询选项
|
||||
type ListOptions struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Sort string `json:"sort"`
|
||||
Order string `json:"order"`
|
||||
Filters map[string]interface{} `json:"filters"`
|
||||
Search string `json:"search"`
|
||||
Include []string `json:"include"`
|
||||
}
|
||||
|
||||
// CountOptions 计数查询选项
|
||||
type CountOptions struct {
|
||||
Filters map[string]interface{} `json:"filters"`
|
||||
Search string `json:"search"`
|
||||
}
|
||||
|
||||
// CachedRepository 支持缓存的仓储接口
|
||||
type CachedRepository[T Entity] interface {
|
||||
Repository[T]
|
||||
|
||||
// 缓存操作
|
||||
InvalidateCache(ctx context.Context, keys ...string) error
|
||||
WarmupCache(ctx context.Context) error
|
||||
GetCacheKey(id string) string
|
||||
}
|
||||
101
internal/shared/interfaces/service.go
Normal file
101
internal/shared/interfaces/service.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Service 通用服务接口
|
||||
type Service interface {
|
||||
// 服务名称
|
||||
Name() string
|
||||
// 服务初始化
|
||||
Initialize(ctx context.Context) error
|
||||
// 服务健康检查
|
||||
HealthCheck(ctx context.Context) error
|
||||
// 服务关闭
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DomainService 领域服务接口,支持泛型
|
||||
type DomainService[T Entity] interface {
|
||||
Service
|
||||
|
||||
// 基础业务操作
|
||||
Create(ctx context.Context, dto interface{}) (*T, error)
|
||||
GetByID(ctx context.Context, id string) (*T, error)
|
||||
Update(ctx context.Context, id string, dto interface{}) (*T, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// 列表和查询
|
||||
List(ctx context.Context, options ListOptions) ([]*T, error)
|
||||
Search(ctx context.Context, query string, options ListOptions) ([]*T, error)
|
||||
Count(ctx context.Context, options CountOptions) (int64, error)
|
||||
|
||||
// 业务规则验证
|
||||
Validate(ctx context.Context, entity *T) error
|
||||
ValidateCreate(ctx context.Context, dto interface{}) error
|
||||
ValidateUpdate(ctx context.Context, id string, dto interface{}) error
|
||||
}
|
||||
|
||||
// EventService 事件服务接口
|
||||
type EventService interface {
|
||||
Service
|
||||
|
||||
// 事件发布
|
||||
Publish(ctx context.Context, event Event) error
|
||||
PublishBatch(ctx context.Context, events []Event) error
|
||||
|
||||
// 事件订阅
|
||||
Subscribe(eventType string, handler EventHandler) error
|
||||
Unsubscribe(eventType string, handler EventHandler) error
|
||||
|
||||
// 异步处理
|
||||
PublishAsync(ctx context.Context, event Event) error
|
||||
}
|
||||
|
||||
// CacheService 缓存服务接口
|
||||
type CacheService interface {
|
||||
Service
|
||||
|
||||
// 基础缓存操作
|
||||
Get(ctx context.Context, key string, dest interface{}) error
|
||||
Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error
|
||||
Delete(ctx context.Context, keys ...string) error
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// 批量操作
|
||||
GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error)
|
||||
SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error
|
||||
|
||||
// 模式操作
|
||||
DeletePattern(ctx context.Context, pattern string) error
|
||||
Keys(ctx context.Context, pattern string) ([]string, error)
|
||||
|
||||
// 缓存统计
|
||||
Stats(ctx context.Context) (CacheStats, error)
|
||||
}
|
||||
|
||||
// CacheStats 缓存统计信息
|
||||
type CacheStats struct {
|
||||
Hits int64 `json:"hits"`
|
||||
Misses int64 `json:"misses"`
|
||||
Keys int64 `json:"keys"`
|
||||
Memory int64 `json:"memory"`
|
||||
Connections int64 `json:"connections"`
|
||||
}
|
||||
|
||||
// TransactionService 事务服务接口
|
||||
type TransactionService interface {
|
||||
Service
|
||||
|
||||
// 事务操作
|
||||
Begin(ctx context.Context) (Transaction, error)
|
||||
RunInTransaction(ctx context.Context, fn func(Transaction) error) error
|
||||
}
|
||||
|
||||
// Transaction 事务接口
|
||||
type Transaction interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
GetDB() interface{}
|
||||
}
|
||||
241
internal/shared/logger/logger.go
Normal file
241
internal/shared/logger/logger.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// Logger 日志接口
|
||||
type Logger interface {
|
||||
Debug(msg string, fields ...zapcore.Field)
|
||||
Info(msg string, fields ...zapcore.Field)
|
||||
Warn(msg string, fields ...zapcore.Field)
|
||||
Error(msg string, fields ...zapcore.Field)
|
||||
Fatal(msg string, fields ...zapcore.Field)
|
||||
Panic(msg string, fields ...zapcore.Field)
|
||||
|
||||
With(fields ...zapcore.Field) Logger
|
||||
WithContext(ctx context.Context) Logger
|
||||
Sync() error
|
||||
}
|
||||
|
||||
// ZapLogger Zap日志实现
|
||||
type ZapLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// Config 日志配置
|
||||
type Config struct {
|
||||
Level string
|
||||
Format string
|
||||
Output string
|
||||
FilePath string
|
||||
MaxSize int
|
||||
MaxBackups int
|
||||
MaxAge int
|
||||
Compress bool
|
||||
}
|
||||
|
||||
// NewLogger 创建新的日志实例
|
||||
func NewLogger(config Config) (Logger, error) {
|
||||
// 设置日志级别
|
||||
level, err := zapcore.ParseLevel(config.Level)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的日志级别: %w", err)
|
||||
}
|
||||
|
||||
// 配置编码器
|
||||
var encoder zapcore.Encoder
|
||||
encoderConfig := getEncoderConfig()
|
||||
|
||||
switch config.Format {
|
||||
case "json":
|
||||
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
||||
case "console":
|
||||
encoder = zapcore.NewConsoleEncoder(encoderConfig)
|
||||
default:
|
||||
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
||||
}
|
||||
|
||||
// 配置输出
|
||||
var writeSyncer zapcore.WriteSyncer
|
||||
switch config.Output {
|
||||
case "stdout":
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
case "stderr":
|
||||
writeSyncer = zapcore.AddSync(os.Stderr)
|
||||
case "file":
|
||||
if config.FilePath == "" {
|
||||
config.FilePath = "logs/app.log"
|
||||
}
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll("logs", 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建日志目录失败: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开日志文件失败: %w", err)
|
||||
}
|
||||
writeSyncer = zapcore.AddSync(file)
|
||||
default:
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
}
|
||||
|
||||
// 创建核心
|
||||
core := zapcore.NewCore(encoder, writeSyncer, level)
|
||||
|
||||
// 创建logger
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
return &ZapLogger{
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getEncoderConfig 获取编码器配置
|
||||
func getEncoderConfig() zapcore.EncoderConfig {
|
||||
return zapcore.EncoderConfig{
|
||||
TimeKey: "timestamp",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
FunctionKey: zapcore.OmitKey,
|
||||
MessageKey: "message",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
}
|
||||
|
||||
// Debug 调试日志
|
||||
func (l *ZapLogger) Debug(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Debug(msg, fields...)
|
||||
}
|
||||
|
||||
// Info 信息日志
|
||||
func (l *ZapLogger) Info(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Info(msg, fields...)
|
||||
}
|
||||
|
||||
// Warn 警告日志
|
||||
func (l *ZapLogger) Warn(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Warn(msg, fields...)
|
||||
}
|
||||
|
||||
// Error 错误日志
|
||||
func (l *ZapLogger) Error(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Error(msg, fields...)
|
||||
}
|
||||
|
||||
// Fatal 致命错误日志
|
||||
func (l *ZapLogger) Fatal(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Fatal(msg, fields...)
|
||||
}
|
||||
|
||||
// Panic 恐慌日志
|
||||
func (l *ZapLogger) Panic(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Panic(msg, fields...)
|
||||
}
|
||||
|
||||
// With 添加字段
|
||||
func (l *ZapLogger) With(fields ...zapcore.Field) Logger {
|
||||
return &ZapLogger{
|
||||
logger: l.logger.With(fields...),
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext 从上下文添加字段
|
||||
func (l *ZapLogger) WithContext(ctx context.Context) Logger {
|
||||
// 从上下文中提取常用字段
|
||||
fields := []zapcore.Field{}
|
||||
|
||||
if traceID := getTraceIDFromContext(ctx); traceID != "" {
|
||||
fields = append(fields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
if userID := getUserIDFromContext(ctx); userID != "" {
|
||||
fields = append(fields, zap.String("user_id", userID))
|
||||
}
|
||||
|
||||
if requestID := getRequestIDFromContext(ctx); requestID != "" {
|
||||
fields = append(fields, zap.String("request_id", requestID))
|
||||
}
|
||||
|
||||
return l.With(fields...)
|
||||
}
|
||||
|
||||
// Sync 同步日志
|
||||
func (l *ZapLogger) Sync() error {
|
||||
return l.logger.Sync()
|
||||
}
|
||||
|
||||
// getTraceIDFromContext 从上下文获取追踪ID
|
||||
func getTraceIDFromContext(ctx context.Context) string {
|
||||
if traceID := ctx.Value("trace_id"); traceID != nil {
|
||||
if id, ok := traceID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从上下文获取用户ID
|
||||
func getUserIDFromContext(ctx context.Context) string {
|
||||
if userID := ctx.Value("user_id"); userID != nil {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRequestIDFromContext 从上下文获取请求ID
|
||||
func getRequestIDFromContext(ctx context.Context) string {
|
||||
if requestID := ctx.Value("request_id"); requestID != nil {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Field 创建日志字段的便捷函数
|
||||
func String(key, val string) zapcore.Field {
|
||||
return zap.String(key, val)
|
||||
}
|
||||
|
||||
func Int(key string, val int) zapcore.Field {
|
||||
return zap.Int(key, val)
|
||||
}
|
||||
|
||||
func Int64(key string, val int64) zapcore.Field {
|
||||
return zap.Int64(key, val)
|
||||
}
|
||||
|
||||
func Float64(key string, val float64) zapcore.Field {
|
||||
return zap.Float64(key, val)
|
||||
}
|
||||
|
||||
func Bool(key string, val bool) zapcore.Field {
|
||||
return zap.Bool(key, val)
|
||||
}
|
||||
|
||||
func Error(err error) zapcore.Field {
|
||||
return zap.Error(err)
|
||||
}
|
||||
|
||||
func Any(key string, val interface{}) zapcore.Field {
|
||||
return zap.Any(key, val)
|
||||
}
|
||||
|
||||
func Duration(key string, val interface{}) zapcore.Field {
|
||||
return zap.Any(key, val)
|
||||
}
|
||||
261
internal/shared/middleware/auth.go
Normal file
261
internal/shared/middleware/auth.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT认证中间件
|
||||
type JWTAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewJWTAuthMiddleware 创建JWT认证中间件
|
||||
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
|
||||
return &JWTAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *JWTAuthMiddleware) GetName() string {
|
||||
return "jwt_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *JWTAuthMiddleware) GetPriority() int {
|
||||
return 60 // 中等优先级,在日志之后,业务处理之前
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
m.respondUnauthorized(c, "Missing authorization header")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
m.respondUnauthorized(c, "Invalid authorization header format")
|
||||
return
|
||||
}
|
||||
|
||||
// 提取token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
if tokenString == "" {
|
||||
m.respondUnauthorized(c, "Missing token")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := m.validateToken(tokenString)
|
||||
if err != nil {
|
||||
m.logger.Warn("Invalid token",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", c.GetString("request_id")))
|
||||
m.respondUnauthorized(c, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息添加到上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *JWTAuthMiddleware) IsGlobal() bool {
|
||||
return false // 不是全局中间件,需要手动应用到需要认证的路由
|
||||
}
|
||||
|
||||
// JWTClaims JWT声明结构
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// validateToken 验证JWT token
|
||||
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// respondUnauthorized 返回未授权响应
|
||||
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "Unauthorized",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "tyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"tyapi-client"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新token
|
||||
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &jwt.RegisteredClaims{
|
||||
Issuer: "tyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"tyapi-refresh"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 验证刷新token
|
||||
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
// 检查是否为刷新token
|
||||
if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims.Subject, nil
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
|
||||
type OptionalAuthMiddleware struct {
|
||||
jwtAuth *JWTAuthMiddleware
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware 创建可选认证中间件
|
||||
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
|
||||
return &OptionalAuthMiddleware{
|
||||
jwtAuth: jwtAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *OptionalAuthMiddleware) GetName() string {
|
||||
return "optional_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *OptionalAuthMiddleware) GetPriority() int {
|
||||
return 60 // 与JWT认证中间件相同
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// 没有认证头部,设置匿名用户标识
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 提取并验证token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
claims, err := m.jwtAuth.validateToken(tokenString)
|
||||
if err != nil {
|
||||
// token无效,但不返回错误,设置为未认证
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// token有效,设置用户信息
|
||||
c.Set("is_authenticated", true)
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *OptionalAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
104
internal/shared/middleware/cors.go
Normal file
104
internal/shared/middleware/cors.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware CORS中间件
|
||||
type CORSMiddleware struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewCORSMiddleware 创建CORS中间件
|
||||
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
|
||||
return &CORSMiddleware{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *CORSMiddleware) GetName() string {
|
||||
return "cors"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *CORSMiddleware) GetPriority() int {
|
||||
return 100 // 高优先级,最先执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.config.Development.EnableCors {
|
||||
// 如果没有启用CORS,返回空处理函数
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
config := cors.Config{
|
||||
AllowAllOrigins: false,
|
||||
AllowOrigins: m.getAllowedOrigins(),
|
||||
AllowMethods: m.getAllowedMethods(),
|
||||
AllowHeaders: m.getAllowedHeaders(),
|
||||
ExposeHeaders: []string{
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 86400, // 24小时
|
||||
}
|
||||
|
||||
return cors.New(config)
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *CORSMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getAllowedOrigins 获取允许的来源
|
||||
func (m *CORSMiddleware) getAllowedOrigins() []string {
|
||||
if m.config.Development.CorsOrigins == "" {
|
||||
return []string{"http://localhost:3000", "http://localhost:8080"}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的origins字符串
|
||||
return []string{m.config.Development.CorsOrigins}
|
||||
}
|
||||
|
||||
// getAllowedMethods 获取允许的方法
|
||||
func (m *CORSMiddleware) getAllowedMethods() []string {
|
||||
if m.config.Development.CorsMethods == "" {
|
||||
return []string{
|
||||
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的methods字符串
|
||||
return []string{m.config.Development.CorsMethods}
|
||||
}
|
||||
|
||||
// getAllowedHeaders 获取允许的头部
|
||||
func (m *CORSMiddleware) getAllowedHeaders() []string {
|
||||
if m.config.Development.CorsHeaders == "" {
|
||||
return []string{
|
||||
"Origin",
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"Accept",
|
||||
"Accept-Encoding",
|
||||
"Accept-Language",
|
||||
"X-Request-ID",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的headers字符串
|
||||
return []string{m.config.Development.CorsHeaders}
|
||||
}
|
||||
166
internal/shared/middleware/ratelimit.go
Normal file
166
internal/shared/middleware/ratelimit.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RateLimitMiddleware) GetName() string {
|
||||
return "ratelimit"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RateLimitMiddleware) GetPriority() int {
|
||||
return 90 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取客户端标识(IP地址)
|
||||
clientID := m.getClientID(c)
|
||||
|
||||
// 获取或创建限流器
|
||||
limiter := m.getLimiter(clientID)
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
c.Header("Retry-After", "60")
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"success": false,
|
||||
"message": "Rate limit exceeded",
|
||||
"error": "Too many requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RateLimitMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getClientID 获取客户端标识
|
||||
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
|
||||
// 优先使用X-Forwarded-For头部
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
return xff
|
||||
}
|
||||
|
||||
// 使用X-Real-IP头部
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// 使用RemoteAddr
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// getLimiter 获取或创建限流器
|
||||
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
|
||||
m.mutex.RLock()
|
||||
limiter, exists := m.limiters[clientID]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists := m.limiters[clientID]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 创建新的限流器
|
||||
// rate.Every计算每个请求之间的间隔
|
||||
rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
|
||||
limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst)
|
||||
|
||||
m.limiters[clientID] = limiter
|
||||
|
||||
// 启动清理协程(仅第一次创建时)
|
||||
if len(m.limiters) == 1 {
|
||||
go m.cleanupRoutine()
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupRoutine 定期清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanupRoutine() {
|
||||
ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanup() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for clientID, limiter := range m.limiters {
|
||||
// 如果限流器在过去1小时内没有被使用,则删除它
|
||||
if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour {
|
||||
delete(m.limiters, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取限流统计
|
||||
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"active_limiters": len(m.limiters),
|
||||
"rate_limit": map[string]interface{}{
|
||||
"requests": m.config.RateLimit.Requests,
|
||||
"window": m.config.RateLimit.Window,
|
||||
"burst": m.config.RateLimit.Burst,
|
||||
},
|
||||
}
|
||||
}
|
||||
241
internal/shared/middleware/request_logger.go
Normal file
241
internal/shared/middleware/request_logger.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RequestLoggerMiddleware 请求日志中间件
|
||||
type RequestLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRequestLoggerMiddleware 创建请求日志中间件
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
|
||||
return &RequestLoggerMiddleware{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestLoggerMiddleware) GetName() string {
|
||||
return "request_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestLoggerMiddleware) GetPriority() int {
|
||||
return 80 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 使用zap logger记录请求信息
|
||||
m.logger.Info("HTTP Request",
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
)
|
||||
|
||||
// 返回空字符串,因为我们已经用zap记录了
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestLoggerMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestIDMiddleware 请求ID中间件
|
||||
type RequestIDMiddleware struct{}
|
||||
|
||||
// NewRequestIDMiddleware 创建请求ID中间件
|
||||
func NewRequestIDMiddleware() *RequestIDMiddleware {
|
||||
return &RequestIDMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestIDMiddleware) GetName() string {
|
||||
return "request_id"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestIDMiddleware) GetPriority() int {
|
||||
return 95 // 最高优先级,第一个执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取或生成请求ID
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 设置请求ID到上下文和响应头
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
// 添加到响应头,方便客户端追踪
|
||||
c.Writer.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware 安全头部中间件
|
||||
type SecurityHeadersMiddleware struct{}
|
||||
|
||||
// NewSecurityHeadersMiddleware 创建安全头部中间件
|
||||
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
|
||||
return &SecurityHeadersMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *SecurityHeadersMiddleware) GetName() string {
|
||||
return "security_headers"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *SecurityHeadersMiddleware) GetPriority() int {
|
||||
return 85 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置安全头部
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Content-Security-Policy", "default-src 'self'")
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ResponseTimeMiddleware 响应时间中间件
|
||||
type ResponseTimeMiddleware struct{}
|
||||
|
||||
// NewResponseTimeMiddleware 创建响应时间中间件
|
||||
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
|
||||
return &ResponseTimeMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ResponseTimeMiddleware) GetName() string {
|
||||
return "response_time"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ResponseTimeMiddleware) GetPriority() int {
|
||||
return 75 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
// 计算响应时间并添加到头部
|
||||
duration := time.Since(start)
|
||||
c.Header("X-Response-Time", duration.String())
|
||||
|
||||
// 记录到上下文中,供其他中间件使用
|
||||
c.Set("response_time", duration)
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ResponseTimeMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestBodyLoggerMiddleware 请求体日志中间件(用于调试)
|
||||
type RequestBodyLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
enable bool
|
||||
}
|
||||
|
||||
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
|
||||
return &RequestBodyLoggerMiddleware{
|
||||
logger: logger,
|
||||
enable: enable,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestBodyLoggerMiddleware) GetName() string {
|
||||
return "request_body_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
|
||||
return 70 // 较低优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.enable {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 只记录POST, PUT, PATCH请求的body
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err == nil {
|
||||
// 重新设置body供后续处理使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
|
||||
m.logger.Debug("Request Body",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("body", string(bodyBytes)),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
|
||||
return false // 可选中间件,不是全局的
|
||||
}
|
||||
Reference in New Issue
Block a user