Files
tyapi-server/internal/container/container.go

707 lines
21 KiB
Go

package container
import (
	"context"
	"errors"
	"fmt"
	nethttp "net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/redis/go-redis/v9"
	"go.uber.org/fx"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"
	"gorm.io/gorm"

	"tyapi-server/internal/config"
	"tyapi-server/internal/domains/user/handlers"
	"tyapi-server/internal/domains/user/repositories"
	"tyapi-server/internal/domains/user/routes"
	"tyapi-server/internal/domains/user/services"
	"tyapi-server/internal/shared/cache"
	"tyapi-server/internal/shared/database"
	"tyapi-server/internal/shared/events"
	"tyapi-server/internal/shared/health"
	"tyapi-server/internal/shared/hooks"
	sharedhttp "tyapi-server/internal/shared/http"
	"tyapi-server/internal/shared/interfaces"
	"tyapi-server/internal/shared/metrics"
	"tyapi-server/internal/shared/middleware"
	"tyapi-server/internal/shared/resilience"
	"tyapi-server/internal/shared/saga"
	"tyapi-server/internal/shared/sms"
	"tyapi-server/internal/shared/tracing"
)
// Container owns the fx application that wires together every component of the server.
type Container struct {
	// App is the underlying fx application.
	App *fx.App
}

// NewContainer builds the application's dependency graph with Uber fx.
//
// Constructors are grouped into option sets (configuration, infrastructure,
// advanced features, HTTP basics, middleware, user domain) and handed to fx.New
// in that order. The three registration functions at the end are invoked in the
// order listed: lifecycle hooks, then middleware, then routes.
func NewContainer() *Container {
	configuration := fx.Provide(config.LoadConfig)

	infrastructure := fx.Provide(
		NewLogger,
		// The traced database/cache constructors are provided instead of the plain ones.
		NewTracedDatabase,
		NewRedisClient,
		fx.Annotate(NewTracedRedisCache, fx.As(new(interfaces.CacheService))),
		NewEventBus,
		NewHealthChecker,
		NewSMSService,
	)

	advanced := fx.Provide(
		NewTracer,
		NewPrometheusMetrics,
		NewBusinessMetrics,
		NewCircuitBreakerWrapper,
		NewRetryerWrapper,
		NewSagaManager,
		NewHookSystem,
	)

	httpBasics := fx.Provide(
		NewResponseBuilder,
		NewRequestValidator,
		NewGinRouter,
	)

	middlewares := fx.Provide(
		NewRequestIDMiddleware,
		NewSecurityHeadersMiddleware,
		NewResponseTimeMiddleware,
		NewCORSMiddleware,
		NewRateLimitMiddleware,
		NewRequestLoggerMiddleware,
		NewJWTAuthMiddleware,
		NewOptionalAuthMiddleware,
		NewTracingMiddleware,
		NewMetricsMiddleware,
		NewTraceIDMiddleware,
		NewErrorTrackingMiddleware,
		NewRequestBodyLoggerMiddleware,
	)

	userDomain := fx.Provide(
		NewUserRepository,
		NewSMSCodeRepository,
		NewSMSCodeService,
		NewUserService,
		// The tracing wrapper, not the raw service, is what satisfies interfaces.UserService.
		fx.Annotate(NewTracedUserService, fx.As(new(interfaces.UserService))),
		NewUserHandler,
	)

	lifecycle := fx.Invoke(
		RegisterLifecycleHooks,
		RegisterMiddlewares,
		RegisterRoutes,
	)

	return &Container{
		App: fx.New(configuration, infrastructure, advanced, httpBasics, middlewares, userDomain, lifecycle),
	}
}
// Start runs all fx OnStart hooks, aborting if they take longer than 30 seconds.
func (c *Container) Start() error {
	return c.withLifecycleTimeout(c.App.Start)
}

// Stop runs all fx OnStop hooks, aborting if they take longer than 30 seconds.
func (c *Container) Stop() error {
	return c.withLifecycleTimeout(c.App.Stop)
}

// withLifecycleTimeout calls phase with a context that expires after 30 seconds
// and releases that context once phase returns.
func (c *Container) withLifecycleTimeout(phase func(context.Context) error) error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	return phase(ctx)
}
// ================ 基础设施构造函数 ================
// NewLogger builds the application's zap logger from cfg.Logger.
//
// In development (cfg.App.IsDevelopment()) zap's development preset is used
// with console encoding and, if cfg.Logger.UseColor is set, colored lowercase
// levels. Otherwise zap's production preset is used with cfg.Logger.Format as
// the encoding, defaulting to "json" when empty.
//
// Log records go to cfg.Logger.Output, or "stdout" when it is empty. zap's
// internal errors go to "stderr".
//
// An unparsable cfg.Logger.Level falls back to info. Instead of being silently
// ignored, the problem is reported as a warning on the returned logger.
// Returns the error from zap.Config.Build, wrapped, if the logger cannot be built.
func NewLogger(cfg *config.Config) (*zap.Logger, error) {
	level, levelErr := zap.ParseAtomicLevel(cfg.Logger.Level)
	if levelErr != nil {
		level = zap.NewAtomicLevelAt(zap.InfoLevel)
	}

	// Named zapCfg (not "config") so the imported config package is not shadowed.
	var zapCfg zap.Config
	if cfg.App.IsDevelopment() {
		zapCfg = zap.NewDevelopmentConfig()
		zapCfg.Encoding = "console"
		if cfg.Logger.UseColor {
			zapCfg.EncoderConfig.EncodeLevel = zapcore.LowercaseColorLevelEncoder
		}
	} else {
		zapCfg = zap.NewProductionConfig()
		zapCfg.Encoding = cfg.Logger.Format
		if zapCfg.Encoding == "" {
			zapCfg.Encoding = "json"
		}
	}
	zapCfg.Level = level

	output := cfg.Logger.Output
	if output == "" {
		output = "stdout"
	}
	zapCfg.OutputPaths = []string{output}
	zapCfg.ErrorOutputPaths = []string{"stderr"}

	logger, err := zapCfg.Build()
	if err != nil {
		return nil, fmt.Errorf("build zap logger: %w", err)
	}
	if levelErr != nil {
		// Surface the misconfiguration now that a logger exists.
		logger.Warn("invalid log level, falling back to info",
			zap.String("level", cfg.Logger.Level), zap.Error(levelErr))
	}
	return logger, nil
}
// NewDatabase opens a GORM connection using the settings in cfg.Database.
//
// Every field of cfg.Database is copied into a database.Config, which is passed
// to database.NewConnection. The *gorm.DB held by the resulting connection is
// returned. The logger parameter is accepted but not used by this function.
func NewDatabase(cfg *config.Config, logger *zap.Logger) (*gorm.DB, error) {
	dbSettings := cfg.Database
	conn, err := database.NewConnection(database.Config{
		Host:            dbSettings.Host,
		Port:            dbSettings.Port,
		User:            dbSettings.User,
		Password:        dbSettings.Password,
		Name:            dbSettings.Name,
		SSLMode:         dbSettings.SSLMode,
		Timezone:        dbSettings.Timezone,
		MaxOpenConns:    dbSettings.MaxOpenConns,
		MaxIdleConns:    dbSettings.MaxIdleConns,
		ConnMaxLifetime: dbSettings.ConnMaxLifetime,
	})
	if err != nil {
		return nil, err
	}
	return conn.DB, nil
}
// NewTracedDatabase opens the database via NewDatabase and registers the GORM
// tracing plugin built from tracer on it.
//
// If the plugin cannot be registered, the freshly opened connection pool is
// closed before the (wrapped) error is returned. Nothing else ever receives this
// *gorm.DB, so without the close the pool would leak.
func NewTracedDatabase(cfg *config.Config, tracer *tracing.Tracer, logger *zap.Logger) (*gorm.DB, error) {
	db, err := NewDatabase(cfg, logger)
	if err != nil {
		return nil, err
	}

	tracingPlugin := tracing.NewGormTracingPlugin(tracer, logger)
	if err := db.Use(tracingPlugin); err != nil {
		logger.Error("注册GORM追踪插件失败", zap.Error(err))
		if sqlDB, dbErr := db.DB(); dbErr == nil {
			// Best effort: the plugin error is the one worth reporting to the caller.
			_ = sqlDB.Close()
		}
		return nil, fmt.Errorf("register GORM tracing plugin: %w", err)
	}

	logger.Info("GORM自动追踪已启用")
	return db, nil
}
// NewRedisClient creates a go-redis client from cfg.Redis and verifies
// connectivity with a PING bounded by a 5-second timeout.
//
// If the PING fails, the client (and its connection pool) is closed before
// returning, because the caller never receives it. The returned error is
// wrapped with the Redis address.
func NewRedisClient(cfg *config.Config, logger *zap.Logger) (*redis.Client, error) {
	addr := cfg.Redis.GetRedisAddr()
	client := redis.NewClient(&redis.Options{
		Addr:         addr,
		Password:     cfg.Redis.Password,
		DB:           cfg.Redis.DB,
		PoolSize:     cfg.Redis.PoolSize,
		MinIdleConns: cfg.Redis.MinIdleConns,
		DialTimeout:  cfg.Redis.DialTimeout,
		ReadTimeout:  cfg.Redis.ReadTimeout,
		WriteTimeout: cfg.Redis.WriteTimeout,
	})

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := client.Ping(ctx).Err(); err != nil {
		logger.Error("Redis连接失败", zap.Error(err))
		// Best effort: the ping error is the one worth reporting to the caller.
		_ = client.Close()
		return nil, fmt.Errorf("ping redis at %s: %w", addr, err)
	}

	logger.Info("Redis连接已建立")
	return client, nil
}
// NewRedisCache wraps client in the shared Redis cache implementation using the
// "app" key prefix. cfg is accepted but not used.
func NewRedisCache(client *redis.Client, logger *zap.Logger, cfg *config.Config) interfaces.CacheService {
	const keyPrefix = "app"
	return cache.NewRedisCache(client, logger, keyPrefix)
}

// NewTracedRedisCache is like NewRedisCache but returns the tracing-instrumented
// cache from the tracing package, using tracer. cfg is accepted but not used.
func NewTracedRedisCache(client *redis.Client, tracer *tracing.Tracer, logger *zap.Logger, cfg *config.Config) interfaces.CacheService {
	const keyPrefix = "app"
	return tracing.NewTracedRedisCache(client, tracer, logger, keyPrefix)
}
// NewEventBus returns the in-memory event bus. cfg is accepted but not used.
func NewEventBus(logger *zap.Logger, cfg *config.Config) interfaces.EventBus {
	// NOTE(review): the meaning of 5 (worker count? buffer size?) is defined by
	// events.NewMemoryEventBus — confirm there before tuning it.
	bus := events.NewMemoryEventBus(logger, 5)
	return bus
}

// NewHealthChecker returns a health checker that logs through logger.
func NewHealthChecker(logger *zap.Logger) *health.HealthChecker {
	checker := health.NewHealthChecker(logger)
	return checker
}
// NewSMSService selects the SMS backend. It returns the Aliyun implementation
// (built from cfg.SMS) unless cfg.SMS.MockEnabled is set, in which case it
// returns the mock implementation. The chosen backend is logged.
func NewSMSService(cfg *config.Config, logger *zap.Logger) (sms.Service, error) {
	if !cfg.SMS.MockEnabled {
		logger.Info("使用阿里云短信服务")
		return sms.NewAliSMSService(cfg.SMS, logger)
	}
	logger.Info("使用模拟短信服务 (mock_enabled=true)")
	return sms.NewMockSMSService(logger), nil
}
// ================ HTTP组件构造函数 ================
// NewResponseBuilder returns the shared HTTP response builder.
func NewResponseBuilder() interfaces.ResponseBuilder {
	builder := sharedhttp.NewResponseBuilder()
	return builder
}

// NewRequestValidator returns the Chinese-localized request validator
// (sharedhttp.NewRequestValidatorZh), which reports through response.
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
	validator := sharedhttp.NewRequestValidatorZh(response)
	return validator
}

// NewGinRouter returns the shared Gin router configured from cfg.
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *sharedhttp.GinRouter {
	router := sharedhttp.NewGinRouter(cfg, logger)
	return router
}
// ================ 中间件构造函数 ================
// NewRequestIDMiddleware returns the shared request-ID middleware.
func NewRequestIDMiddleware() *middleware.RequestIDMiddleware {
	return middleware.NewRequestIDMiddleware()
}

// NewSecurityHeadersMiddleware returns the shared security-headers middleware.
func NewSecurityHeadersMiddleware() *middleware.SecurityHeadersMiddleware {
	return middleware.NewSecurityHeadersMiddleware()
}

// NewResponseTimeMiddleware returns the shared response-time middleware.
func NewResponseTimeMiddleware() *middleware.ResponseTimeMiddleware {
	return middleware.NewResponseTimeMiddleware()
}

// NewCORSMiddleware returns the CORS middleware configured from cfg.
func NewCORSMiddleware(cfg *config.Config) *middleware.CORSMiddleware {
	return middleware.NewCORSMiddleware(cfg)
}

// NewRateLimitMiddleware returns the rate-limit middleware configured from cfg,
// which builds its responses with response.
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *middleware.RateLimitMiddleware {
	return middleware.NewRateLimitMiddleware(cfg, response)
}

// NewRequestLoggerMiddleware returns the request logger middleware. It passes
// cfg.App.IsDevelopment() as the middleware's development flag.
func NewRequestLoggerMiddleware(logger *zap.Logger, cfg *config.Config, tracer *tracing.Tracer) *middleware.RequestLoggerMiddleware {
	return middleware.NewRequestLoggerMiddleware(logger, cfg.App.IsDevelopment(), tracer)
}

// NewJWTAuthMiddleware returns the JWT authentication middleware configured from cfg.
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *middleware.JWTAuthMiddleware {
	return middleware.NewJWTAuthMiddleware(cfg, logger)
}

// NewOptionalAuthMiddleware returns the optional-auth middleware built on top of jwtAuth.
func NewOptionalAuthMiddleware(jwtAuth *middleware.JWTAuthMiddleware) *middleware.OptionalAuthMiddleware {
	return middleware.NewOptionalAuthMiddleware(jwtAuth)
}

// NewTraceIDMiddleware returns the trace-ID middleware backed by tracer.
func NewTraceIDMiddleware(tracer *tracing.Tracer) *middleware.TraceIDMiddleware {
	return middleware.NewTraceIDMiddleware(tracer)
}

// NewErrorTrackingMiddleware returns the error-tracking middleware using logger and tracer.
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *middleware.ErrorTrackingMiddleware {
	return middleware.NewErrorTrackingMiddleware(logger, tracer)
}

// NewRequestBodyLoggerMiddleware returns the request-body logger middleware. It
// passes cfg.App.IsDevelopment() as the middleware's development flag.
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, cfg *config.Config, tracer *tracing.Tracer) *middleware.RequestBodyLoggerMiddleware {
	return middleware.NewRequestBodyLoggerMiddleware(logger, cfg.App.IsDevelopment(), tracer)
}
// ================ 高级特性构造函数 ================
// NewTracer builds the distributed tracer. Service identity comes from cfg.App
// (name, version, environment); exporter endpoint, sampling rate and the
// enabled flag come from cfg.Monitoring.
func NewTracer(cfg *config.Config, logger *zap.Logger) *tracing.Tracer {
	app, mon := cfg.App, cfg.Monitoring
	return tracing.NewTracer(tracing.TracerConfig{
		ServiceName:    app.Name,
		ServiceVersion: app.Version,
		Environment:    app.Env,
		Endpoint:       mon.TracingEndpoint,
		SampleRate:     mon.SampleRate,
		Enabled:        mon.TracingEnabled,
	}, logger)
}
// NewPrometheusMetrics returns the Prometheus-backed metrics collector when
// cfg.Monitoring.MetricsEnabled is set, and a NoopMetricsCollector otherwise.
func NewPrometheusMetrics(logger *zap.Logger, cfg *config.Config) interfaces.MetricsCollector {
	if cfg.Monitoring.MetricsEnabled {
		return metrics.NewPrometheusMetrics(logger)
	}
	return &NoopMetricsCollector{}
}
// NewBusinessMetrics returns the business-level metrics recorder, which reports
// through prometheusMetrics.
func NewBusinessMetrics(prometheusMetrics interfaces.MetricsCollector, logger *zap.Logger) *metrics.BusinessMetrics {
	recorder := metrics.NewBusinessMetrics(prometheusMetrics, logger)
	return recorder
}

// NewCircuitBreakerWrapper returns the resilience (circuit breaker) wrapper.
// cfg is accepted but not used.
func NewCircuitBreakerWrapper(logger *zap.Logger, cfg *config.Config) *resilience.Wrapper {
	breaker := resilience.NewWrapper(logger)
	return breaker
}

// NewRetryerWrapper returns the resilience retry wrapper.
func NewRetryerWrapper(logger *zap.Logger) *resilience.RetryerWrapper {
	retryer := resilience.NewRetryerWrapper(logger)
	return retryer
}
// NewSagaManager returns a saga manager with fixed defaults: a 30s timeout,
// at most 3 retries, and sequential (non-parallel) execution. cfg is accepted
// but not used.
func NewSagaManager(cfg *config.Config, logger *zap.Logger) *saga.SagaManager {
	return saga.NewSagaManager(saga.SagaConfig{
		DefaultTimeout:    30 * time.Second,
		DefaultMaxRetries: 3,
		Parallel:          false,
	}, logger)
}

// NewHookSystem returns a hook system with a 30s default timeout, duration
// tracking enabled, and the hooks.ContinueOnError error strategy.
func NewHookSystem(logger *zap.Logger) *hooks.HookSystem {
	return hooks.NewHookSystem(hooks.HookConfig{
		DefaultTimeout: 30 * time.Second,
		TrackDuration:  true,
		ErrorStrategy:  hooks.ContinueOnError,
	}, logger)
}
// NewTracingMiddleware wraps tracer in the container's TracingMiddleware.
func NewTracingMiddleware(tracer *tracing.Tracer) *TracingMiddleware {
	mw := &TracingMiddleware{}
	mw.tracer = tracer
	return mw
}

// NewMetricsMiddleware wraps metricsCollector in the container's MetricsMiddleware.
func NewMetricsMiddleware(metricsCollector interfaces.MetricsCollector) *MetricsMiddleware {
	mw := &MetricsMiddleware{}
	mw.metrics = metricsCollector
	return mw
}
// ================ 用户域构造函数 ================
// NewUserRepository returns the user repository backed by db and cache.
func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *repositories.UserRepository {
	repo := repositories.NewUserRepository(db, cache, logger)
	return repo
}

// NewSMSCodeRepository returns the SMS verification-code repository backed by db and cache.
func NewSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *repositories.SMSCodeRepository {
	repo := repositories.NewSMSCodeRepository(db, cache, logger)
	return repo
}
// NewSMSCodeService returns the SMS verification-code service. It is built
// from repo, the SMS sender smsClient, cache, and the cfg.SMS settings.
func NewSMSCodeService(
	repo *repositories.SMSCodeRepository,
	smsClient sms.Service,
	cache interfaces.CacheService,
	cfg *config.Config,
	logger *zap.Logger,
) *services.SMSCodeService {
	smsSettings := cfg.SMS
	svc := services.NewSMSCodeService(repo, smsClient, cache, smsSettings, logger)
	return svc
}

// NewUserService returns the concrete (untraced) user service.
func NewUserService(
	repo *repositories.UserRepository,
	smsCodeService *services.SMSCodeService,
	eventBus interfaces.EventBus,
	logger *zap.Logger,
) *services.UserService {
	svc := services.NewUserService(repo, smsCodeService, eventBus, logger)
	return svc
}
// NewTracedUserService decorates baseService with tracing. It uses a service
// wrapper built from tracer and logger, and returns the decorator as
// interfaces.UserService.
func NewTracedUserService(
	baseService *services.UserService,
	tracer *tracing.Tracer,
	logger *zap.Logger,
) interfaces.UserService {
	return tracing.NewTracedUserService(baseService, tracing.NewServiceWrapper(tracer, logger))
}
// NewUserHandler returns the HTTP handler for the user domain, wired with the
// (traced) user service, the SMS code service, response/validation helpers and
// the JWT middleware.
func NewUserHandler(
	userService interfaces.UserService,
	smsCodeService *services.SMSCodeService,
	response interfaces.ResponseBuilder,
	validator interfaces.RequestValidator,
	logger *zap.Logger,
	jwtAuth *middleware.JWTAuthMiddleware,
) *handlers.UserHandler {
	h := handlers.NewUserHandler(userService, smsCodeService, response, validator, logger, jwtAuth)
	return h
}
// ================ 中间件定义 ================
// TracingMiddleware adapts the tracer's Gin middleware to the router's
// middleware contract (name, priority, global flag, handler).
type TracingMiddleware struct {
	tracer *tracing.Tracer
}

// GetName reports the middleware name, "tracing".
func (m *TracingMiddleware) GetName() string {
	return "tracing"
}

// GetPriority reports priority 1.
func (m *TracingMiddleware) GetPriority() int {
	return 1
}

// IsGlobal reports true: the middleware applies to every route.
func (m *TracingMiddleware) IsGlobal() bool {
	return true
}

// Handle returns the tracer's own Gin middleware.
func (m *TracingMiddleware) Handle() gin.HandlerFunc {
	handler := m.tracer.TraceMiddleware()
	return handler
}
// MetricsMiddleware records a request count and duration sample for every HTTP
// request via the configured MetricsCollector.
type MetricsMiddleware struct {
	metrics interfaces.MetricsCollector
}

// GetName reports the middleware name, "metrics".
func (m *MetricsMiddleware) GetName() string {
	return "metrics"
}

// GetPriority reports priority 2.
func (m *MetricsMiddleware) GetPriority() int {
	return 2
}

// IsGlobal reports true: the middleware applies to every route.
func (m *MetricsMiddleware) IsGlobal() bool {
	return true
}

// Handle returns a Gin handler that times the rest of the chain (c.Next). It
// then reports method, matched route pattern (c.FullPath()), status and elapsed
// seconds.
func (m *MetricsMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		began := time.Now()
		c.Next()
		elapsed := time.Since(began).Seconds()

		method, route := c.Request.Method, c.FullPath()
		m.metrics.RecordHTTPRequest(method, route, c.Writer.Status(), elapsed)
		m.metrics.RecordHTTPDuration(method, route, elapsed)
	}
}
// NoopMetricsCollector is a metrics collector that discards everything. It is
// used when metrics are disabled: every Record*/Increment* call is a no-op,
// every Register* call succeeds with a nil error, and GetHandler returns nil
// (so no /metrics endpoint is exposed).
type NoopMetricsCollector struct{}

// RecordHTTPRequest discards the sample.
func (*NoopMetricsCollector) RecordHTTPRequest(string, string, int, float64) {}

// RecordHTTPDuration discards the sample.
func (*NoopMetricsCollector) RecordHTTPDuration(string, string, float64) {}

// IncrementCounter discards the increment.
func (*NoopMetricsCollector) IncrementCounter(string, map[string]string) {}

// RecordGauge discards the value.
func (*NoopMetricsCollector) RecordGauge(string, float64, map[string]string) {}

// RecordHistogram discards the observation.
func (*NoopMetricsCollector) RecordHistogram(string, float64, map[string]string) {}

// RegisterCounter always succeeds.
func (*NoopMetricsCollector) RegisterCounter(string, string, []string) error {
	return nil
}

// RegisterGauge always succeeds.
func (*NoopMetricsCollector) RegisterGauge(string, string, []string) error {
	return nil
}

// RegisterHistogram always succeeds.
func (*NoopMetricsCollector) RegisterHistogram(string, string, []string, []float64) error {
	return nil
}

// GetHandler returns nil: there is nothing to scrape.
func (*NoopMetricsCollector) GetHandler() nethttp.Handler {
	return nil
}
// ================ 注册函数 ================
// RegisterMiddlewares hands every global middleware to the router. The
// registration sequence is: request ID, security headers, response time, CORS,
// rate limit, request logger, tracing, metrics, trace ID, error tracking,
// request body logger.
func RegisterMiddlewares(
	router *sharedhttp.GinRouter,
	requestID *middleware.RequestIDMiddleware,
	security *middleware.SecurityHeadersMiddleware,
	responseTime *middleware.ResponseTimeMiddleware,
	cors *middleware.CORSMiddleware,
	rateLimit *middleware.RateLimitMiddleware,
	requestLogger *middleware.RequestLoggerMiddleware,
	tracingMiddleware *TracingMiddleware,
	metricsMiddleware *MetricsMiddleware,
	traceIDMiddleware *middleware.TraceIDMiddleware,
	errorTrackingMiddleware *middleware.ErrorTrackingMiddleware,
	requestBodyLogger *middleware.RequestBodyLoggerMiddleware,
) {
	register := router.RegisterMiddleware

	// Request bookkeeping and response hardening.
	register(requestID)
	register(security)
	register(responseTime)
	register(cors)
	register(rateLimit)
	register(requestLogger)

	// Observability.
	register(tracingMiddleware)
	register(metricsMiddleware)
	register(traceIDMiddleware)
	register(errorTrackingMiddleware)
	register(requestBodyLogger)
}
// RegisterRoutes installs the router's default routes and, when the metrics
// collector exposes a handler, a GET /metrics endpoint. It then mounts the
// user-domain routes and prints the resulting route table.
func RegisterRoutes(
	router *sharedhttp.GinRouter,
	userHandler *handlers.UserHandler,
	jwtAuth *middleware.JWTAuthMiddleware,
	metricsCollector interfaces.MetricsCollector,
) {
	router.SetupDefaultRoutes()

	engine := router.GetEngine()
	if scrapeHandler := metricsCollector.GetHandler(); scrapeHandler != nil {
		engine.GET("/metrics", gin.WrapH(scrapeHandler))
	}
	routes.UserRoutes(engine, userHandler, jwtAuth)

	router.PrintRoutes()
}
// RegisterLifecycleHooks attaches the application's start/stop logic to the fx lifecycle.
//
// OnStart runs, in order:
//   - initializes the tracer, hook system and saga manager;
//   - registers the user service and saga manager with the health checker;
//   - initializes the cache, then starts the event bus and hook system;
//   - registers the application hooks;
//   - starts periodic health checks when cfg.Health.Enabled is set;
//   - serves HTTP on cfg.Server.Host:cfg.Server.Port in a background goroutine.
//
// The first initialization error aborts startup and is returned to fx.
//
// OnStop cancels the background context given to the health checker. It then
// stops the router, hook system, saga manager, event bus, tracer and cache (the
// results of those calls are ignored) and closes the underlying *sql.DB,
// logging any close error.
func RegisterLifecycleHooks(
	lc fx.Lifecycle,
	logger *zap.Logger,
	cfg *config.Config,
	db *gorm.DB,
	cache interfaces.CacheService,
	eventBus interfaces.EventBus,
	healthChecker *health.HealthChecker,
	router *sharedhttp.GinRouter,
	userService *services.UserService,
	tracer *tracing.Tracer,
	prometheusMetrics interfaces.MetricsCollector,
	businessMetrics *metrics.BusinessMetrics,
	circuitBreakerWrapper *resilience.Wrapper,
	retryerWrapper *resilience.RetryerWrapper,
	sagaManager *saga.SagaManager,
	hookSystem *hooks.HookSystem,
) {
	// The ctx passed to OnStart is only meant for startup. Container.Start, for
	// example, wraps App.Start in a timeout context and cancels it as soon as
	// App.Start returns. Long-running background work therefore gets its own
	// context, which lives until OnStop.
	bgCtx, cancelBg := context.WithCancel(context.Background())

	lc.Append(fx.Hook{
		OnStart: func(ctx context.Context) error {
			logger.Info("正在启动应用服务...")

			// Initialize advanced features.
			if err := tracer.Initialize(ctx); err != nil {
				logger.Error("初始化追踪器失败", zap.Error(err))
				return err
			}
			if err := hookSystem.Initialize(ctx); err != nil {
				logger.Error("初始化钩子系统失败", zap.Error(err))
				return err
			}
			if err := sagaManager.Initialize(ctx); err != nil {
				logger.Error("初始化事务管理器失败", zap.Error(err))
				return err
			}

			// Register services with the health checker.
			healthChecker.RegisterService(userService)
			healthChecker.RegisterService(sagaManager)

			// Start core services.
			if err := cache.Initialize(ctx); err != nil {
				logger.Error("初始化缓存失败", zap.Error(err))
				return err
			}
			if err := eventBus.Start(ctx); err != nil {
				logger.Error("启动事件总线失败", zap.Error(err))
				return err
			}
			if err := hookSystem.Start(ctx); err != nil {
				logger.Error("启动钩子系统失败", zap.Error(err))
				return err
			}

			// Register application hooks.
			registerApplicationHooks(hookSystem, businessMetrics, logger)

			// Periodic health checks must outlive startup, so they use bgCtx, not ctx.
			if cfg.Health.Enabled {
				go healthChecker.StartPeriodicCheck(bgCtx, cfg.Health.Interval)
			}

			// Serve HTTP in the background.
			go func() {
				addr := fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
				// http.ErrServerClosed marks a graceful shutdown, not a startup failure.
				if err := router.Start(addr); err != nil && !errors.Is(err, nethttp.ErrServerClosed) {
					logger.Error("启动HTTP服务器失败", zap.Error(err))
				}
			}()

			logger.Info("所有服务已成功启动")
			return nil
		},
		OnStop: func(ctx context.Context) error {
			logger.Info("正在停止应用服务...")

			// Stop background loops started in OnStart.
			cancelBg()

			// Shut services down in order.
			router.Stop(ctx)
			hookSystem.Shutdown(ctx)
			sagaManager.Shutdown(ctx)
			eventBus.Stop(ctx)
			tracer.Shutdown(ctx)
			cache.Shutdown(ctx)

			if sqlDB, err := db.DB(); err == nil {
				if closeErr := sqlDB.Close(); closeErr != nil {
					logger.Error("关闭数据库连接失败", zap.Error(closeErr))
				}
			}

			logger.Info("所有服务已停止")
			return nil
		},
	})
}
// registerApplicationHooks subscribes two synchronous, 5s-timeout,
// priority-1 hooks that record business metrics:
//   - "user.created" calls RecordUserCreated("register");
//   - "user.logged_in" calls RecordUserLogin("web", "success").
//
// Each hook also emits a debug log line.
func registerApplicationHooks(hookSystem *hooks.HookSystem, businessMetrics *metrics.BusinessMetrics, logger *zap.Logger) {
	bindings := []struct {
		event string
		hook  *hooks.Hook
	}{
		{
			event: "user.created",
			hook: &hooks.Hook{
				Name:     "metrics.user_created",
				Priority: 1,
				Async:    false,
				Timeout:  5 * time.Second,
				Func: func(ctx context.Context, data interface{}) error {
					businessMetrics.RecordUserCreated("register")
					logger.Debug("记录用户创建指标")
					return nil
				},
			},
		},
		{
			event: "user.logged_in",
			hook: &hooks.Hook{
				Name:     "metrics.user_login",
				Priority: 1,
				Async:    false,
				Timeout:  5 * time.Second,
				Func: func(ctx context.Context, data interface{}) error {
					businessMetrics.RecordUserLogin("web", "success")
					logger.Debug("记录用户登录指标")
					return nil
				},
			},
		},
	}

	for _, b := range bindings {
		hookSystem.Register(b.event, b.hook)
	}

	logger.Info("应用钩子已成功注册")
}
// ServiceRegistrar is implemented by anything that contributes fx providers to the container.
type ServiceRegistrar interface {
	// RegisterServices returns the fx options (providers, invokes, ...) to install.
	RegisterServices() fx.Option
}

// DomainModule is a named, pluggable domain that registers its own services and
// declares the names of the modules it depends on.
type DomainModule interface {
	ServiceRegistrar

	// GetDependencies lists the names of the modules this module depends on.
	GetDependencies() []string
	// GetName is the module's unique name.
	GetName() string
}
// RegisterDomainModule turns module into fx options. The module itself is
// provided as a DomainModule value named after module.GetName(), followed by
// the options returned from module.RegisterServices().
func RegisterDomainModule(module DomainModule) fx.Option {
	supplyModule := func() DomainModule { return module }
	namedProvider := fx.Annotated{
		Name:   module.GetName(),
		Target: supplyModule,
	}
	return fx.Options(fx.Provide(namedProvider), module.RegisterServices())
}
// GetContainer returns a freshly built Container (intended for tests or other
// special cases).
//
// NOTE: the configuration argument is currently ignored. NewContainer loads its
// own configuration through config.LoadConfig, so the value passed here does
// not affect the resulting dependency graph. The parameter is kept only so
// existing callers still compile.
func GetContainer(_ *config.Config) *Container {
	return NewContainer()
}