Initial commit: Basic project structure and dependencies

This commit is contained in:
2025-06-30 19:21:56 +08:00
commit 03e615a8fd
50 changed files with 11664 additions and 0 deletions

235
internal/app/app.go Normal file
View File

@@ -0,0 +1,235 @@
package app
import (
"fmt"
"os"
"os/signal"
"syscall"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/config"
"tyapi-server/internal/container"
"tyapi-server/internal/domains/user/entities"
)
// Application 应用程序结构
type Application struct {
container *container.Container
config *config.Config
logger *zap.Logger
}
// NewApplication 创建新的应用程序实例
func NewApplication() (*Application, error) {
// 加载配置
cfg, err := config.LoadConfig()
if err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
// 创建日志器
logger, err := createLogger(cfg)
if err != nil {
return nil, fmt.Errorf("failed to create logger: %w", err)
}
// 创建容器
cont := container.NewContainer()
return &Application{
container: cont,
config: cfg,
logger: logger,
}, nil
}
// Run 运行应用程序
func (a *Application) Run() error {
// 打印启动信息
a.printBanner()
// 启动容器
a.logger.Info("Starting application container...")
if err := a.container.Start(); err != nil {
a.logger.Error("Failed to start container", zap.Error(err))
return err
}
// 设置优雅关闭
a.setupGracefulShutdown()
a.logger.Info("Application started successfully",
zap.String("version", a.config.App.Version),
zap.String("environment", a.config.App.Env),
zap.String("port", a.config.Server.Port))
// 等待信号
return a.waitForShutdown()
}
// RunMigrations 运行数据库迁移
func (a *Application) RunMigrations() error {
a.logger.Info("Running database migrations...")
// 创建数据库连接
db, err := a.createDatabaseConnection()
if err != nil {
return fmt.Errorf("failed to create database connection: %w", err)
}
// 自动迁移
if err := a.autoMigrate(db); err != nil {
return fmt.Errorf("failed to run migrations: %w", err)
}
a.logger.Info("Database migrations completed successfully")
return nil
}
// printBanner 打印启动横幅
func (a *Application) printBanner() {
banner := fmt.Sprintf(`
╔══════════════════════════════════════════════════════════════╗
║ %s ║
║ Version: %s ║
║ Environment: %s ║
║ Port: %s ║
╚══════════════════════════════════════════════════════════════╝
`,
a.config.App.Name,
a.config.App.Version,
a.config.App.Env,
a.config.Server.Port,
)
fmt.Println(banner)
}
// setupGracefulShutdown 设置优雅关闭
func (a *Application) setupGracefulShutdown() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
go func() {
<-c
a.logger.Info("Received shutdown signal, starting graceful shutdown...")
// 停止容器
if err := a.container.Stop(); err != nil {
a.logger.Error("Error during container shutdown", zap.Error(err))
}
a.logger.Info("Application shutdown completed")
os.Exit(0)
}()
}
// waitForShutdown 等待关闭信号
func (a *Application) waitForShutdown() error {
// 创建一个通道来等待关闭
done := make(chan bool, 1)
// 启动一个协程来监听信号
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)
<-c
done <- true
}()
// 等待关闭信号
<-done
return nil
}
// createDatabaseConnection 创建数据库连接
func (a *Application) createDatabaseConnection() (*gorm.DB, error) {
return container.NewDatabase(a.config, a.logger)
}
// autoMigrate 自动迁移
func (a *Application) autoMigrate(db *gorm.DB) error {
// 迁移用户相关表
return db.AutoMigrate(
&entities.User{},
// 后续可以添加其他实体
)
}
// createLogger 创建日志器
func createLogger(cfg *config.Config) (*zap.Logger, error) {
level, err := zap.ParseAtomicLevel(cfg.Logger.Level)
if err != nil {
level = zap.NewAtomicLevelAt(zap.InfoLevel)
}
config := zap.Config{
Level: level,
Development: cfg.App.IsDevelopment(),
Encoding: cfg.Logger.Format,
EncoderConfig: zap.NewProductionEncoderConfig(),
OutputPaths: []string{cfg.Logger.Output},
ErrorOutputPaths: []string{"stderr"},
}
if cfg.Logger.Format == "" {
config.Encoding = "json"
}
if cfg.Logger.Output == "" {
config.OutputPaths = []string{"stdout"}
}
return config.Build()
}
// GetConfig 获取配置
func (a *Application) GetConfig() *config.Config {
return a.config
}
// GetLogger 获取日志器
func (a *Application) GetLogger() *zap.Logger {
return a.logger
}
// HealthCheck 应用程序健康检查
func (a *Application) HealthCheck() error {
// 这里可以添加应用程序级别的健康检查逻辑
return nil
}
// GetVersion 获取版本信息
func (a *Application) GetVersion() map[string]string {
return map[string]string{
"name": a.config.App.Name,
"version": a.config.App.Version,
"environment": a.config.App.Env,
"go_version": "1.23.4+",
}
}
// RunCommand 运行特定命令
func (a *Application) RunCommand(command string, args ...string) error {
switch command {
case "migrate":
return a.RunMigrations()
case "version":
version := a.GetVersion()
fmt.Printf("Name: %s\n", version["name"])
fmt.Printf("Version: %s\n", version["version"])
fmt.Printf("Environment: %s\n", version["environment"])
fmt.Printf("Go Version: %s\n", version["go_version"])
return nil
case "health":
if err := a.HealthCheck(); err != nil {
fmt.Printf("Health check failed: %v\n", err)
return err
}
fmt.Println("Application is healthy")
return nil
default:
return fmt.Errorf("unknown command: %s", command)
}
}

166
internal/config/config.go Normal file
View File

@@ -0,0 +1,166 @@
package config
import (
"time"
)
// Config 应用程序总配置
type Config struct {
Server ServerConfig `mapstructure:"server"`
Database DatabaseConfig `mapstructure:"database"`
Redis RedisConfig `mapstructure:"redis"`
Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"`
JWT JWTConfig `mapstructure:"jwt"`
RateLimit RateLimitConfig `mapstructure:"ratelimit"`
Monitoring MonitoringConfig `mapstructure:"monitoring"`
Health HealthConfig `mapstructure:"health"`
Resilience ResilienceConfig `mapstructure:"resilience"`
Development DevelopmentConfig `mapstructure:"development"`
App AppConfig `mapstructure:"app"`
}
// ServerConfig HTTP服务器配置
type ServerConfig struct {
Port string `mapstructure:"port"`
Mode string `mapstructure:"mode"`
Host string `mapstructure:"host"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
IdleTimeout time.Duration `mapstructure:"idle_timeout"`
}
// DatabaseConfig 数据库配置
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port string `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
Name string `mapstructure:"name"`
SSLMode string `mapstructure:"sslmode"`
Timezone string `mapstructure:"timezone"`
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
}
// RedisConfig Redis配置
type RedisConfig struct {
Host string `mapstructure:"host"`
Port string `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
PoolSize int `mapstructure:"pool_size"`
MinIdleConns int `mapstructure:"min_idle_conns"`
MaxRetries int `mapstructure:"max_retries"`
DialTimeout time.Duration `mapstructure:"dial_timeout"`
ReadTimeout time.Duration `mapstructure:"read_timeout"`
WriteTimeout time.Duration `mapstructure:"write_timeout"`
}
// CacheConfig 缓存配置
type CacheConfig struct {
DefaultTTL time.Duration `mapstructure:"default_ttl"`
CleanupInterval time.Duration `mapstructure:"cleanup_interval"`
MaxSize int `mapstructure:"max_size"`
}
// LoggerConfig 日志配置
type LoggerConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
Output string `mapstructure:"output"`
FilePath string `mapstructure:"file_path"`
MaxSize int `mapstructure:"max_size"`
MaxBackups int `mapstructure:"max_backups"`
MaxAge int `mapstructure:"max_age"`
Compress bool `mapstructure:"compress"`
}
// JWTConfig JWT配置
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpiresIn time.Duration `mapstructure:"expires_in"`
RefreshExpiresIn time.Duration `mapstructure:"refresh_expires_in"`
}
// RateLimitConfig 限流配置
type RateLimitConfig struct {
Requests int `mapstructure:"requests"`
Window time.Duration `mapstructure:"window"`
Burst int `mapstructure:"burst"`
}
// MonitoringConfig 监控配置
type MonitoringConfig struct {
MetricsEnabled bool `mapstructure:"metrics_enabled"`
MetricsPort string `mapstructure:"metrics_port"`
TracingEnabled bool `mapstructure:"tracing_enabled"`
TracingEndpoint string `mapstructure:"tracing_endpoint"`
SampleRate float64 `mapstructure:"sample_rate"`
}
// HealthConfig 健康检查配置
type HealthConfig struct {
Enabled bool `mapstructure:"enabled"`
Interval time.Duration `mapstructure:"interval"`
Timeout time.Duration `mapstructure:"timeout"`
}
// ResilienceConfig 容错配置
type ResilienceConfig struct {
CircuitBreakerEnabled bool `mapstructure:"circuit_breaker_enabled"`
CircuitBreakerThreshold int `mapstructure:"circuit_breaker_threshold"`
CircuitBreakerTimeout time.Duration `mapstructure:"circuit_breaker_timeout"`
RetryMaxAttempts int `mapstructure:"retry_max_attempts"`
RetryInitialDelay time.Duration `mapstructure:"retry_initial_delay"`
RetryMaxDelay time.Duration `mapstructure:"retry_max_delay"`
}
// DevelopmentConfig 开发配置
type DevelopmentConfig struct {
Debug bool `mapstructure:"debug"`
EnableProfiler bool `mapstructure:"enable_profiler"`
EnableCors bool `mapstructure:"enable_cors"`
CorsOrigins string `mapstructure:"cors_allowed_origins"`
CorsMethods string `mapstructure:"cors_allowed_methods"`
CorsHeaders string `mapstructure:"cors_allowed_headers"`
}
// AppConfig 应用程序配置
type AppConfig struct {
Name string `mapstructure:"name"`
Version string `mapstructure:"version"`
Env string `mapstructure:"env"`
}
// GetDSN 获取数据库DSN连接字符串
func (d DatabaseConfig) GetDSN() string {
return "host=" + d.Host +
" user=" + d.User +
" password=" + d.Password +
" dbname=" + d.Name +
" port=" + d.Port +
" sslmode=" + d.SSLMode +
" TimeZone=" + d.Timezone
}
// GetRedisAddr 获取Redis地址
func (r RedisConfig) GetRedisAddr() string {
return r.Host + ":" + r.Port
}
// IsProduction 检查是否为生产环境
func (a AppConfig) IsProduction() bool {
return a.Env == "production"
}
// IsDevelopment 检查是否为开发环境
func (a AppConfig) IsDevelopment() bool {
return a.Env == "development"
}
// IsStaging 检查是否为测试环境
func (a AppConfig) IsStaging() bool {
return a.Env == "staging"
}

311
internal/config/loader.go Normal file
View File

@@ -0,0 +1,311 @@
package config
import (
"fmt"
"os"
"strings"
"time"
"github.com/spf13/viper"
)
// LoadConfig 加载应用程序配置
func LoadConfig() (*Config, error) {
// 设置配置文件名和路径
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
viper.AddConfigPath("./configs")
viper.AddConfigPath("$HOME/.tyapi")
// 设置环境变量前缀
viper.SetEnvPrefix("")
viper.AutomaticEnv()
// 配置环境变量键名映射
setupEnvKeyMapping()
// 设置默认值
setDefaults()
// 尝试读取配置文件(可选)
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
// 配置文件不存在时使用环境变量和默认值
}
var config Config
if err := viper.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("解析配置失败: %w", err)
}
// 验证配置
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err)
}
return &config, nil
}
// setupEnvKeyMapping 设置环境变量到配置键的映射
func setupEnvKeyMapping() {
// 服务器配置
viper.BindEnv("server.port", "SERVER_PORT")
viper.BindEnv("server.mode", "SERVER_MODE")
viper.BindEnv("server.host", "SERVER_HOST")
viper.BindEnv("server.read_timeout", "SERVER_READ_TIMEOUT")
viper.BindEnv("server.write_timeout", "SERVER_WRITE_TIMEOUT")
viper.BindEnv("server.idle_timeout", "SERVER_IDLE_TIMEOUT")
// 数据库配置
viper.BindEnv("database.host", "DB_HOST")
viper.BindEnv("database.port", "DB_PORT")
viper.BindEnv("database.user", "DB_USER")
viper.BindEnv("database.password", "DB_PASSWORD")
viper.BindEnv("database.name", "DB_NAME")
viper.BindEnv("database.sslmode", "DB_SSLMODE")
viper.BindEnv("database.timezone", "DB_TIMEZONE")
viper.BindEnv("database.max_open_conns", "DB_MAX_OPEN_CONNS")
viper.BindEnv("database.max_idle_conns", "DB_MAX_IDLE_CONNS")
viper.BindEnv("database.conn_max_lifetime", "DB_CONN_MAX_LIFETIME")
// Redis配置
viper.BindEnv("redis.host", "REDIS_HOST")
viper.BindEnv("redis.port", "REDIS_PORT")
viper.BindEnv("redis.password", "REDIS_PASSWORD")
viper.BindEnv("redis.db", "REDIS_DB")
viper.BindEnv("redis.pool_size", "REDIS_POOL_SIZE")
viper.BindEnv("redis.min_idle_conns", "REDIS_MIN_IDLE_CONNS")
viper.BindEnv("redis.max_retries", "REDIS_MAX_RETRIES")
viper.BindEnv("redis.dial_timeout", "REDIS_DIAL_TIMEOUT")
viper.BindEnv("redis.read_timeout", "REDIS_READ_TIMEOUT")
viper.BindEnv("redis.write_timeout", "REDIS_WRITE_TIMEOUT")
// 缓存配置
viper.BindEnv("cache.default_ttl", "CACHE_DEFAULT_TTL")
viper.BindEnv("cache.cleanup_interval", "CACHE_CLEANUP_INTERVAL")
viper.BindEnv("cache.max_size", "CACHE_MAX_SIZE")
// 日志配置
viper.BindEnv("logger.level", "LOG_LEVEL")
viper.BindEnv("logger.format", "LOG_FORMAT")
viper.BindEnv("logger.output", "LOG_OUTPUT")
viper.BindEnv("logger.file_path", "LOG_FILE_PATH")
viper.BindEnv("logger.max_size", "LOG_MAX_SIZE")
viper.BindEnv("logger.max_backups", "LOG_MAX_BACKUPS")
viper.BindEnv("logger.max_age", "LOG_MAX_AGE")
viper.BindEnv("logger.compress", "LOG_COMPRESS")
// JWT配置
viper.BindEnv("jwt.secret", "JWT_SECRET")
viper.BindEnv("jwt.expires_in", "JWT_EXPIRES_IN")
viper.BindEnv("jwt.refresh_expires_in", "JWT_REFRESH_EXPIRES_IN")
// 限流配置
viper.BindEnv("ratelimit.requests", "RATE_LIMIT_REQUESTS")
viper.BindEnv("ratelimit.window", "RATE_LIMIT_WINDOW")
viper.BindEnv("ratelimit.burst", "RATE_LIMIT_BURST")
// 监控配置
viper.BindEnv("monitoring.metrics_enabled", "METRICS_ENABLED")
viper.BindEnv("monitoring.metrics_port", "METRICS_PORT")
viper.BindEnv("monitoring.tracing_enabled", "TRACING_ENABLED")
viper.BindEnv("monitoring.tracing_endpoint", "TRACING_ENDPOINT")
viper.BindEnv("monitoring.sample_rate", "TRACING_SAMPLE_RATE")
// 健康检查配置
viper.BindEnv("health.enabled", "HEALTH_CHECK_ENABLED")
viper.BindEnv("health.interval", "HEALTH_CHECK_INTERVAL")
viper.BindEnv("health.timeout", "HEALTH_CHECK_TIMEOUT")
// 容错配置
viper.BindEnv("resilience.circuit_breaker_enabled", "CIRCUIT_BREAKER_ENABLED")
viper.BindEnv("resilience.circuit_breaker_threshold", "CIRCUIT_BREAKER_THRESHOLD")
viper.BindEnv("resilience.circuit_breaker_timeout", "CIRCUIT_BREAKER_TIMEOUT")
viper.BindEnv("resilience.retry_max_attempts", "RETRY_MAX_ATTEMPTS")
viper.BindEnv("resilience.retry_initial_delay", "RETRY_INITIAL_DELAY")
viper.BindEnv("resilience.retry_max_delay", "RETRY_MAX_DELAY")
// 开发配置
viper.BindEnv("development.debug", "DEBUG")
viper.BindEnv("development.enable_profiler", "ENABLE_PROFILER")
viper.BindEnv("development.enable_cors", "ENABLE_CORS")
viper.BindEnv("development.cors_allowed_origins", "CORS_ALLOWED_ORIGINS")
viper.BindEnv("development.cors_allowed_methods", "CORS_ALLOWED_METHODS")
viper.BindEnv("development.cors_allowed_headers", "CORS_ALLOWED_HEADERS")
// 应用程序配置
viper.BindEnv("app.name", "APP_NAME")
viper.BindEnv("app.version", "APP_VERSION")
viper.BindEnv("app.env", "ENV")
}
// setDefaults 设置默认配置值
func setDefaults() {
// 服务器默认值
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.read_timeout", "30s")
viper.SetDefault("server.write_timeout", "30s")
viper.SetDefault("server.idle_timeout", "120s")
// 数据库默认值
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", "5432")
viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "password")
viper.SetDefault("database.name", "tyapi_db")
viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.timezone", "Asia/Shanghai")
viper.SetDefault("database.max_open_conns", 100)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.conn_max_lifetime", "300s")
// Redis默认值
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", "6379")
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.pool_size", 10)
viper.SetDefault("redis.min_idle_conns", 5)
viper.SetDefault("redis.max_retries", 3)
viper.SetDefault("redis.dial_timeout", "5s")
viper.SetDefault("redis.read_timeout", "3s")
viper.SetDefault("redis.write_timeout", "3s")
// 缓存默认值
viper.SetDefault("cache.default_ttl", "300s")
viper.SetDefault("cache.cleanup_interval", "600s")
viper.SetDefault("cache.max_size", 1000)
// 日志默认值
viper.SetDefault("logger.level", "info")
viper.SetDefault("logger.format", "json")
viper.SetDefault("logger.output", "stdout")
viper.SetDefault("logger.file_path", "logs/app.log")
viper.SetDefault("logger.max_size", 100)
viper.SetDefault("logger.max_backups", 5)
viper.SetDefault("logger.max_age", 30)
viper.SetDefault("logger.compress", true)
// JWT默认值
viper.SetDefault("jwt.secret", "your-super-secret-jwt-key-change-this-in-production")
viper.SetDefault("jwt.expires_in", "24h")
viper.SetDefault("jwt.refresh_expires_in", "168h")
// 限流默认值
viper.SetDefault("ratelimit.requests", 100)
viper.SetDefault("ratelimit.window", "60s")
viper.SetDefault("ratelimit.burst", 10)
// 监控默认值
viper.SetDefault("monitoring.metrics_enabled", true)
viper.SetDefault("monitoring.metrics_port", "9090")
viper.SetDefault("monitoring.tracing_enabled", false)
viper.SetDefault("monitoring.tracing_endpoint", "http://localhost:14268/api/traces")
viper.SetDefault("monitoring.sample_rate", 0.1)
// 健康检查默认值
viper.SetDefault("health.enabled", true)
viper.SetDefault("health.interval", "30s")
viper.SetDefault("health.timeout", "5s")
// 容错默认值
viper.SetDefault("resilience.circuit_breaker_enabled", true)
viper.SetDefault("resilience.circuit_breaker_threshold", 5)
viper.SetDefault("resilience.circuit_breaker_timeout", "60s")
viper.SetDefault("resilience.retry_max_attempts", 3)
viper.SetDefault("resilience.retry_initial_delay", "100ms")
viper.SetDefault("resilience.retry_max_delay", "2s")
// 开发默认值
viper.SetDefault("development.debug", true)
viper.SetDefault("development.enable_profiler", false)
viper.SetDefault("development.enable_cors", true)
viper.SetDefault("development.cors_allowed_origins", "*")
viper.SetDefault("development.cors_allowed_methods", "GET,POST,PUT,DELETE,OPTIONS")
viper.SetDefault("development.cors_allowed_headers", "*")
// 应用程序默认值
viper.SetDefault("app.name", "tyapi-server")
viper.SetDefault("app.version", "1.0.0")
viper.SetDefault("app.env", "development")
}
// validateConfig 验证配置
func validateConfig(config *Config) error {
// 验证必要的配置项
if config.Database.Host == "" {
return fmt.Errorf("数据库主机地址不能为空")
}
if config.Database.User == "" {
return fmt.Errorf("数据库用户名不能为空")
}
if config.Database.Name == "" {
return fmt.Errorf("数据库名称不能为空")
}
if config.JWT.Secret == "" || config.JWT.Secret == "your-super-secret-jwt-key-change-this-in-production" {
if config.App.IsProduction() {
return fmt.Errorf("生产环境必须设置安全的JWT密钥")
}
}
// 验证超时配置
if config.Server.ReadTimeout <= 0 {
return fmt.Errorf("服务器读取超时时间必须大于0")
}
if config.Server.WriteTimeout <= 0 {
return fmt.Errorf("服务器写入超时时间必须大于0")
}
// 验证数据库连接池配置
if config.Database.MaxOpenConns <= 0 {
return fmt.Errorf("数据库最大连接数必须大于0")
}
if config.Database.MaxIdleConns <= 0 {
return fmt.Errorf("数据库最大空闲连接数必须大于0")
}
if config.Database.MaxIdleConns > config.Database.MaxOpenConns {
return fmt.Errorf("数据库最大空闲连接数不能大于最大连接数")
}
return nil
}
// GetEnv 获取环境变量,如果不存在则返回默认值
func GetEnv(key, defaultValue string) string {
if value := os.Getenv(key); value != "" {
return value
}
return defaultValue
}
// ParseDuration 解析时间字符串
func ParseDuration(s string) time.Duration {
d, err := time.ParseDuration(s)
if err != nil {
return 0
}
return d
}
// SplitAndTrim 分割字符串并去除空格
func SplitAndTrim(s, sep string) []string {
parts := strings.Split(s, sep)
result := make([]string, 0, len(parts))
for _, part := range parts {
if trimmed := strings.TrimSpace(part); trimmed != "" {
result = append(result, trimmed)
}
}
return result
}

View File

@@ -0,0 +1,441 @@
package container
import (
"context"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/fx"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/config"
"tyapi-server/internal/domains/user/handlers"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/domains/user/routes"
"tyapi-server/internal/domains/user/services"
"tyapi-server/internal/shared/cache"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/events"
"tyapi-server/internal/shared/health"
"tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/interfaces"
"tyapi-server/internal/shared/middleware"
)
// Container 应用容器
type Container struct {
App *fx.App
}
// NewContainer 创建新的应用容器
func NewContainer() *Container {
app := fx.New(
// 配置模块
fx.Provide(
config.LoadConfig,
),
// 基础设施模块
fx.Provide(
NewLogger,
NewDatabase,
NewRedisClient,
NewRedisCache,
NewEventBus,
NewHealthChecker,
),
// HTTP基础组件
fx.Provide(
NewResponseBuilder,
NewRequestValidator,
NewGinRouter,
),
// 中间件组件
fx.Provide(
NewRequestIDMiddleware,
NewSecurityHeadersMiddleware,
NewResponseTimeMiddleware,
NewCORSMiddleware,
NewRateLimitMiddleware,
NewRequestLoggerMiddleware,
NewJWTAuthMiddleware,
NewOptionalAuthMiddleware,
),
// 用户域组件
fx.Provide(
NewUserRepository,
NewUserService,
NewUserHandler,
NewUserRoutes,
),
// 应用生命周期
fx.Invoke(
RegisterLifecycleHooks,
RegisterMiddlewares,
RegisterRoutes,
),
)
return &Container{App: app}
}
// Start 启动容器
func (c *Container) Start() error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return c.App.Start(ctx)
}
// Stop 停止容器
func (c *Container) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
return c.App.Stop(ctx)
}
// 基础设施构造函数
// NewLogger 创建日志器
func NewLogger(cfg *config.Config) (*zap.Logger, error) {
level, err := zap.ParseAtomicLevel(cfg.Logger.Level)
if err != nil {
level = zap.NewAtomicLevelAt(zap.InfoLevel)
}
config := zap.Config{
Level: level,
Development: cfg.App.IsDevelopment(),
Encoding: cfg.Logger.Format,
EncoderConfig: zap.NewProductionEncoderConfig(),
OutputPaths: []string{cfg.Logger.Output},
ErrorOutputPaths: []string{"stderr"},
}
if cfg.Logger.Format == "" {
config.Encoding = "json"
}
if cfg.Logger.Output == "" {
config.OutputPaths = []string{"stdout"}
}
return config.Build()
}
// NewDatabase 创建数据库连接
func NewDatabase(cfg *config.Config, logger *zap.Logger) (*gorm.DB, error) {
dbConfig := database.Config{
Host: cfg.Database.Host,
Port: cfg.Database.Port,
User: cfg.Database.User,
Password: cfg.Database.Password,
Name: cfg.Database.Name,
SSLMode: cfg.Database.SSLMode,
Timezone: cfg.Database.Timezone,
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: cfg.Database.ConnMaxLifetime,
}
db, err := database.NewConnection(dbConfig)
if err != nil {
return nil, err
}
return db.DB, nil
}
// NewRedisClient 创建Redis客户端
func NewRedisClient(cfg *config.Config, logger *zap.Logger) (*redis.Client, error) {
client := redis.NewClient(&redis.Options{
Addr: cfg.Redis.GetRedisAddr(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
PoolSize: cfg.Redis.PoolSize,
MinIdleConns: cfg.Redis.MinIdleConns,
DialTimeout: cfg.Redis.DialTimeout,
ReadTimeout: cfg.Redis.ReadTimeout,
WriteTimeout: cfg.Redis.WriteTimeout,
})
// 测试连接
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := client.Ping(ctx).Result()
if err != nil {
logger.Error("Failed to connect to Redis", zap.Error(err))
return nil, err
}
logger.Info("Redis connection established")
return client, nil
}
// NewRedisCache 创建Redis缓存服务
func NewRedisCache(client *redis.Client, logger *zap.Logger, cfg *config.Config) interfaces.CacheService {
return cache.NewRedisCache(client, logger, "app")
}
// NewEventBus 创建事件总线
func NewEventBus(logger *zap.Logger, cfg *config.Config) interfaces.EventBus {
return events.NewMemoryEventBus(logger, 5) // 默认5个工作协程
}
// NewHealthChecker 创建健康检查器
func NewHealthChecker(logger *zap.Logger) *health.HealthChecker {
return health.NewHealthChecker(logger)
}
// HTTP组件构造函数
// NewResponseBuilder 创建响应构建器
func NewResponseBuilder() interfaces.ResponseBuilder {
return http.NewResponseBuilder()
}
// NewRequestValidator 创建请求验证器
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
return http.NewRequestValidator(response)
}
// NewGinRouter 创建Gin路由器
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *http.GinRouter {
return http.NewGinRouter(cfg, logger)
}
// 中间件构造函数
// NewRequestIDMiddleware 创建请求ID中间件
func NewRequestIDMiddleware() *middleware.RequestIDMiddleware {
return middleware.NewRequestIDMiddleware()
}
// NewSecurityHeadersMiddleware 创建安全头部中间件
func NewSecurityHeadersMiddleware() *middleware.SecurityHeadersMiddleware {
return middleware.NewSecurityHeadersMiddleware()
}
// NewResponseTimeMiddleware 创建响应时间中间件
func NewResponseTimeMiddleware() *middleware.ResponseTimeMiddleware {
return middleware.NewResponseTimeMiddleware()
}
// NewCORSMiddleware 创建CORS中间件
func NewCORSMiddleware(cfg *config.Config) *middleware.CORSMiddleware {
return middleware.NewCORSMiddleware(cfg)
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg *config.Config) *middleware.RateLimitMiddleware {
return middleware.NewRateLimitMiddleware(cfg)
}
// NewRequestLoggerMiddleware 创建请求日志中间件
func NewRequestLoggerMiddleware(logger *zap.Logger) *middleware.RequestLoggerMiddleware {
return middleware.NewRequestLoggerMiddleware(logger)
}
// NewJWTAuthMiddleware 创建JWT认证中间件
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *middleware.JWTAuthMiddleware {
return middleware.NewJWTAuthMiddleware(cfg, logger)
}
// NewOptionalAuthMiddleware 创建可选认证中间件
func NewOptionalAuthMiddleware(jwtAuth *middleware.JWTAuthMiddleware) *middleware.OptionalAuthMiddleware {
return middleware.NewOptionalAuthMiddleware(jwtAuth)
}
// 用户域构造函数
// NewUserRepository 创建用户仓储
func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *repositories.UserRepository {
return repositories.NewUserRepository(db, cache, logger)
}
// NewUserService 创建用户服务
func NewUserService(
repo *repositories.UserRepository,
eventBus interfaces.EventBus,
logger *zap.Logger,
) *services.UserService {
return services.NewUserService(repo, eventBus, logger)
}
// NewUserHandler 创建用户处理器
func NewUserHandler(
userService *services.UserService,
response interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
jwtAuth *middleware.JWTAuthMiddleware,
) *handlers.UserHandler {
return handlers.NewUserHandler(userService, response, validator, logger, jwtAuth)
}
// NewUserRoutes 创建用户路由
func NewUserRoutes(
handler *handlers.UserHandler,
jwtAuth *middleware.JWTAuthMiddleware,
optionalAuth *middleware.OptionalAuthMiddleware,
) *routes.UserRoutes {
return routes.NewUserRoutes(handler, jwtAuth, optionalAuth)
}
// 注册函数
// RegisterMiddlewares 注册中间件
func RegisterMiddlewares(
router *http.GinRouter,
requestID *middleware.RequestIDMiddleware,
security *middleware.SecurityHeadersMiddleware,
responseTime *middleware.ResponseTimeMiddleware,
cors *middleware.CORSMiddleware,
rateLimit *middleware.RateLimitMiddleware,
requestLogger *middleware.RequestLoggerMiddleware,
) {
// 注册全局中间件
router.RegisterMiddleware(requestID)
router.RegisterMiddleware(security)
router.RegisterMiddleware(responseTime)
router.RegisterMiddleware(cors)
router.RegisterMiddleware(rateLimit)
router.RegisterMiddleware(requestLogger)
}
// RegisterRoutes 注册路由
func RegisterRoutes(
router *http.GinRouter,
userRoutes *routes.UserRoutes,
) {
// 设置默认路由
router.SetupDefaultRoutes()
// 注册用户路由
userRoutes.RegisterRoutes(router.GetEngine())
userRoutes.RegisterPublicRoutes(router.GetEngine())
userRoutes.RegisterAdminRoutes(router.GetEngine())
userRoutes.RegisterHealthRoutes(router.GetEngine())
// 打印路由信息
router.PrintRoutes()
}
// 生命周期钩子
// RegisterLifecycleHooks 注册生命周期钩子
func RegisterLifecycleHooks(
lc fx.Lifecycle,
logger *zap.Logger,
cfg *config.Config,
db *gorm.DB,
cache interfaces.CacheService,
eventBus interfaces.EventBus,
healthChecker *health.HealthChecker,
router *http.GinRouter,
userService *services.UserService,
) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
logger.Info("Starting application services...")
// 注册服务到健康检查器
healthChecker.RegisterService(userService)
// 初始化缓存服务
if err := cache.Initialize(ctx); err != nil {
logger.Error("Failed to initialize cache", zap.Error(err))
return err
}
// 启动事件总线
if err := eventBus.Start(ctx); err != nil {
logger.Error("Failed to start event bus", zap.Error(err))
return err
}
// 启动健康检查(如果启用)
if cfg.Health.Enabled {
go healthChecker.StartPeriodicCheck(ctx, cfg.Health.Interval)
}
// 启动HTTP服务器
go func() {
addr := fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
if err := router.Start(addr); err != nil {
logger.Error("Failed to start HTTP server", zap.Error(err))
}
}()
logger.Info("All services started successfully")
return nil
},
OnStop: func(ctx context.Context) error {
logger.Info("Stopping application services...")
// 停止HTTP服务器
if err := router.Stop(ctx); err != nil {
logger.Error("Failed to stop HTTP server", zap.Error(err))
}
// 停止事件总线
if err := eventBus.Stop(ctx); err != nil {
logger.Error("Failed to stop event bus", zap.Error(err))
}
// 关闭缓存服务
if err := cache.Shutdown(ctx); err != nil {
logger.Error("Failed to shutdown cache service", zap.Error(err))
}
// 关闭数据库连接
if sqlDB, err := db.DB(); err == nil {
if err := sqlDB.Close(); err != nil {
logger.Error("Failed to close database", zap.Error(err))
}
}
logger.Info("All services stopped")
return nil
},
})
}
// ServiceRegistrar 服务注册器接口
type ServiceRegistrar interface {
RegisterServices() fx.Option
}
// DomainModule 领域模块接口
type DomainModule interface {
ServiceRegistrar
GetName() string
GetDependencies() []string
}
// RegisterDomainModule 注册领域模块
func RegisterDomainModule(module DomainModule) fx.Option {
return fx.Options(
fx.Provide(
fx.Annotated{
Name: module.GetName(),
Target: func() DomainModule {
return module
},
},
),
module.RegisterServices(),
)
}
// GetContainer 获取容器实例(用于测试或特殊情况)
func GetContainer(cfg *config.Config) *Container {
return NewContainer()
}

View File

@@ -0,0 +1,173 @@
package dto
import (
"time"
"tyapi-server/internal/domains/user/entities"
)
// CreateUserRequest 创建用户请求
type CreateUserRequest struct {
Username string `json:"username" binding:"required,min=3,max=50" example:"john_doe"`
Email string `json:"email" binding:"required,email" example:"john@example.com"`
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
FirstName string `json:"first_name" binding:"max=50" example:"John"`
LastName string `json:"last_name" binding:"max=50" example:"Doe"`
Phone string `json:"phone" binding:"omitempty,max=20" example:"+86-13800138000"`
}
// UpdateUserRequest 更新用户请求
type UpdateUserRequest struct {
FirstName *string `json:"first_name,omitempty" binding:"omitempty,max=50" example:"John"`
LastName *string `json:"last_name,omitempty" binding:"omitempty,max=50" example:"Doe"`
Phone *string `json:"phone,omitempty" binding:"omitempty,max=20" example:"+86-13800138000"`
Avatar *string `json:"avatar,omitempty" binding:"omitempty,url" example:"https://example.com/avatar.jpg"`
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required" example:"oldpassword123"`
NewPassword string `json:"new_password" binding:"required,min=6,max=128" example:"newpassword123"`
}
// UserResponse 用户响应
type UserResponse struct {
ID string `json:"id" example:"123e4567-e89b-12d3-a456-426614174000"`
Username string `json:"username" example:"john_doe"`
Email string `json:"email" example:"john@example.com"`
FirstName string `json:"first_name" example:"John"`
LastName string `json:"last_name" example:"Doe"`
Phone string `json:"phone" example:"+86-13800138000"`
Avatar string `json:"avatar" example:"https://example.com/avatar.jpg"`
Status entities.UserStatus `json:"status" example:"active"`
LastLoginAt *time.Time `json:"last_login_at,omitempty" example:"2024-01-01T00:00:00Z"`
CreatedAt time.Time `json:"created_at" example:"2024-01-01T00:00:00Z"`
UpdatedAt time.Time `json:"updated_at" example:"2024-01-01T00:00:00Z"`
Profile *UserProfileResponse `json:"profile,omitempty"`
}
// UserProfileResponse 用户档案响应
type UserProfileResponse struct {
Bio string `json:"bio,omitempty" example:"Software Developer"`
Location string `json:"location,omitempty" example:"Beijing, China"`
Website string `json:"website,omitempty" example:"https://johndoe.com"`
Birthday *time.Time `json:"birthday,omitempty" example:"1990-01-01T00:00:00Z"`
Gender string `json:"gender,omitempty" example:"male"`
Timezone string `json:"timezone,omitempty" example:"Asia/Shanghai"`
Language string `json:"language,omitempty" example:"zh-CN"`
}
// UserListRequest 用户列表请求
type UserListRequest struct {
Page int `form:"page" binding:"omitempty,min=1" example:"1"`
PageSize int `form:"page_size" binding:"omitempty,min=1,max=100" example:"20"`
Sort string `form:"sort" binding:"omitempty,oneof=created_at updated_at username email" example:"created_at"`
Order string `form:"order" binding:"omitempty,oneof=asc desc" example:"desc"`
Status entities.UserStatus `form:"status" binding:"omitempty,oneof=active inactive suspended pending" example:"active"`
Search string `form:"search" binding:"omitempty,max=100" example:"john"`
Filters map[string]interface{} `form:"-"`
}
// UserListResponse 用户列表响应
type UserListResponse struct {
Users []*UserResponse `json:"users"`
Pagination PaginationMeta `json:"pagination"`
}
// PaginationMeta 分页元数据
type PaginationMeta struct {
Page int `json:"page" example:"1"`
PageSize int `json:"page_size" example:"20"`
Total int64 `json:"total" example:"100"`
TotalPages int `json:"total_pages" example:"5"`
HasNext bool `json:"has_next" example:"true"`
HasPrev bool `json:"has_prev" example:"false"`
}
// LoginRequest 登录请求
type LoginRequest struct {
Login string `json:"login" binding:"required" example:"john_doe"`
Password string `json:"password" binding:"required" example:"password123"`
}
// LoginResponse 登录响应
type LoginResponse struct {
User *UserResponse `json:"user"`
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
TokenType string `json:"token_type" example:"Bearer"`
ExpiresIn int64 `json:"expires_in" example:"86400"`
}
// UpdateProfileRequest 更新用户档案请求
type UpdateProfileRequest struct {
Bio *string `json:"bio,omitempty" binding:"omitempty,max=500" example:"Software Developer"`
Location *string `json:"location,omitempty" binding:"omitempty,max=100" example:"Beijing, China"`
Website *string `json:"website,omitempty" binding:"omitempty,url" example:"https://johndoe.com"`
Birthday *time.Time `json:"birthday,omitempty" example:"1990-01-01T00:00:00Z"`
Gender *string `json:"gender,omitempty" binding:"omitempty,oneof=male female other" example:"male"`
Timezone *string `json:"timezone,omitempty" binding:"omitempty,max=50" example:"Asia/Shanghai"`
Language *string `json:"language,omitempty" binding:"omitempty,max=10" example:"zh-CN"`
}
// UserStatsResponse 用户统计响应
type UserStatsResponse struct {
TotalUsers int64 `json:"total_users" example:"1000"`
ActiveUsers int64 `json:"active_users" example:"950"`
InactiveUsers int64 `json:"inactive_users" example:"30"`
SuspendedUsers int64 `json:"suspended_users" example:"20"`
NewUsersToday int64 `json:"new_users_today" example:"5"`
NewUsersWeek int64 `json:"new_users_week" example:"25"`
NewUsersMonth int64 `json:"new_users_month" example:"120"`
}
// UserSearchRequest 用户搜索请求
type UserSearchRequest struct {
Query string `form:"q" binding:"required,min=1,max=100" example:"john"`
Page int `form:"page" binding:"omitempty,min=1" example:"1"`
PageSize int `form:"page_size" binding:"omitempty,min=1,max=50" example:"10"`
}
// 转换方法
func (r *CreateUserRequest) ToEntity() *entities.User {
return &entities.User{
Username: r.Username,
Email: r.Email,
Password: r.Password,
FirstName: r.FirstName,
LastName: r.LastName,
Phone: r.Phone,
Status: entities.UserStatusActive,
}
}
func FromEntity(user *entities.User) *UserResponse {
if user == nil {
return nil
}
return &UserResponse{
ID: user.ID,
Username: user.Username,
Email: user.Email,
FirstName: user.FirstName,
LastName: user.LastName,
Phone: user.Phone,
Avatar: user.Avatar,
Status: user.Status,
LastLoginAt: user.LastLoginAt,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
}
}
func FromEntities(users []*entities.User) []*UserResponse {
if users == nil {
return []*UserResponse{}
}
responses := make([]*UserResponse, len(users))
for i, user := range users {
responses[i] = FromEntity(user)
}
return responses
}

View File

@@ -0,0 +1,138 @@
package entities
import (
"time"
"gorm.io/gorm"
)
// User 用户实体
type User struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
Username string `gorm:"uniqueIndex;type:varchar(50);not null" json:"username"`
Email string `gorm:"uniqueIndex;type:varchar(100);not null" json:"email"`
Password string `gorm:"type:varchar(255);not null" json:"-"`
FirstName string `gorm:"type:varchar(50)" json:"first_name"`
LastName string `gorm:"type:varchar(50)" json:"last_name"`
Phone string `gorm:"type:varchar(20)" json:"phone"`
Avatar string `gorm:"type:varchar(255)" json:"avatar"`
Status UserStatus `gorm:"type:varchar(20);default:'active'" json:"status"`
LastLoginAt *time.Time `json:"last_login_at"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 软删除字段
IsDeleted bool `gorm:"default:false" json:"is_deleted"`
// 版本控制
Version int `gorm:"default:1" json:"version"`
}
// UserStatus 用户状态枚举
type UserStatus string
const (
UserStatusActive UserStatus = "active"
UserStatusInactive UserStatus = "inactive"
UserStatusSuspended UserStatus = "suspended"
UserStatusPending UserStatus = "pending"
)
// 实现 Entity 接口
func (u *User) GetID() string {
return u.ID
}
func (u *User) GetCreatedAt() time.Time {
return u.CreatedAt
}
func (u *User) GetUpdatedAt() time.Time {
return u.UpdatedAt
}
// 业务方法
func (u *User) IsActive() bool {
return u.Status == UserStatusActive && !u.IsDeleted
}
func (u *User) GetFullName() string {
if u.FirstName == "" && u.LastName == "" {
return u.Username
}
return u.FirstName + " " + u.LastName
}
func (u *User) CanLogin() bool {
return u.IsActive() && u.Status != UserStatusSuspended
}
func (u *User) MarkAsDeleted() {
u.IsDeleted = true
u.Status = UserStatusInactive
}
func (u *User) Restore() {
u.IsDeleted = false
u.Status = UserStatusActive
}
func (u *User) UpdateLastLogin() {
now := time.Now()
u.LastLoginAt = &now
}
// 验证方法
func (u *User) Validate() error {
if u.Username == "" {
return NewValidationError("username is required")
}
if u.Email == "" {
return NewValidationError("email is required")
}
if u.Password == "" {
return NewValidationError("password is required")
}
return nil
}
// TableName 指定表名
func (User) TableName() string {
return "users"
}
// ValidationError 验证错误
type ValidationError struct {
Message string
}
func (e *ValidationError) Error() string {
return e.Message
}
func NewValidationError(message string) *ValidationError {
return &ValidationError{Message: message}
}
// UserProfile 用户档案(扩展信息)
type UserProfile struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
UserID string `gorm:"type:varchar(36);not null;index" json:"user_id"`
Bio string `gorm:"type:text" json:"bio"`
Location string `gorm:"type:varchar(100)" json:"location"`
Website string `gorm:"type:varchar(255)" json:"website"`
Birthday *time.Time `json:"birthday"`
Gender string `gorm:"type:varchar(10)" json:"gender"`
Timezone string `gorm:"type:varchar(50)" json:"timezone"`
Language string `gorm:"type:varchar(10);default:'zh-CN'" json:"language"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
// 关联关系
User *User `gorm:"foreignKey:UserID;references:ID" json:"user,omitempty"`
}
func (UserProfile) TableName() string {
return "user_profiles"
}

View File

@@ -0,0 +1,299 @@
package events
import (
"encoding/json"
"time"
"tyapi-server/internal/domains/user/entities"
"github.com/google/uuid"
)
// UserEventType 用户事件类型
type UserEventType string
const (
UserCreatedEvent UserEventType = "user.created"
UserUpdatedEvent UserEventType = "user.updated"
UserDeletedEvent UserEventType = "user.deleted"
UserRestoredEvent UserEventType = "user.restored"
UserLoggedInEvent UserEventType = "user.logged_in"
UserLoggedOutEvent UserEventType = "user.logged_out"
UserPasswordChangedEvent UserEventType = "user.password_changed"
UserStatusChangedEvent UserEventType = "user.status_changed"
UserProfileUpdatedEvent UserEventType = "user.profile_updated"
)
// BaseUserEvent 用户事件基础结构
type BaseUserEvent struct {
ID string `json:"id"`
Type string `json:"type"`
Version string `json:"version"`
Timestamp time.Time `json:"timestamp"`
Source string `json:"source"`
AggregateID string `json:"aggregate_id"`
AggregateType string `json:"aggregate_type"`
Metadata map[string]interface{} `json:"metadata"`
Payload interface{} `json:"payload"`
// DDD特有字段
DomainVersion string `json:"domain_version"`
CausationID string `json:"causation_id"`
CorrelationID string `json:"correlation_id"`
}
// 实现 Event 接口
func (e *BaseUserEvent) GetID() string {
return e.ID
}
func (e *BaseUserEvent) GetType() string {
return e.Type
}
func (e *BaseUserEvent) GetVersion() string {
return e.Version
}
func (e *BaseUserEvent) GetTimestamp() time.Time {
return e.Timestamp
}
func (e *BaseUserEvent) GetPayload() interface{} {
return e.Payload
}
func (e *BaseUserEvent) GetMetadata() map[string]interface{} {
return e.Metadata
}
func (e *BaseUserEvent) GetSource() string {
return e.Source
}
func (e *BaseUserEvent) GetAggregateID() string {
return e.AggregateID
}
func (e *BaseUserEvent) GetAggregateType() string {
return e.AggregateType
}
func (e *BaseUserEvent) GetDomainVersion() string {
return e.DomainVersion
}
func (e *BaseUserEvent) GetCausationID() string {
return e.CausationID
}
func (e *BaseUserEvent) GetCorrelationID() string {
return e.CorrelationID
}
func (e *BaseUserEvent) Marshal() ([]byte, error) {
return json.Marshal(e)
}
func (e *BaseUserEvent) Unmarshal(data []byte) error {
return json.Unmarshal(data, e)
}
// UserCreated 用户创建事件
type UserCreated struct {
*BaseUserEvent
User *entities.User `json:"user"`
}
func NewUserCreatedEvent(user *entities.User, correlationID string) *UserCreated {
return &UserCreated{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserCreatedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: user.ID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"email": user.Email,
},
},
User: user,
}
}
func (e *UserCreated) GetPayload() interface{} {
return e.User
}
// UserUpdated 用户更新事件
type UserUpdated struct {
*BaseUserEvent
UserID string `json:"user_id"`
Changes map[string]interface{} `json:"changes"`
OldValues map[string]interface{} `json:"old_values"`
NewValues map[string]interface{} `json:"new_values"`
}
func NewUserUpdatedEvent(userID string, changes, oldValues, newValues map[string]interface{}, correlationID string) *UserUpdated {
return &UserUpdated{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserUpdatedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"changed_fields": len(changes),
},
},
UserID: userID,
Changes: changes,
OldValues: oldValues,
NewValues: newValues,
}
}
// UserDeleted 用户删除事件
type UserDeleted struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
SoftDelete bool `json:"soft_delete"`
}
func NewUserDeletedEvent(userID, username, email string, softDelete bool, correlationID string) *UserDeleted {
return &UserDeleted{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserDeletedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
"email": email,
"soft_delete": softDelete,
},
},
UserID: userID,
Username: username,
Email: email,
SoftDelete: softDelete,
}
}
// UserLoggedIn 用户登录事件
type UserLoggedIn struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
}
func NewUserLoggedInEvent(userID, username, ipAddress, userAgent, correlationID string) *UserLoggedIn {
return &UserLoggedIn{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserLoggedInEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
"ip_address": ipAddress,
"user_agent": userAgent,
},
},
UserID: userID,
Username: username,
IPAddress: ipAddress,
UserAgent: userAgent,
}
}
// UserPasswordChanged 用户密码修改事件
type UserPasswordChanged struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
}
func NewUserPasswordChangedEvent(userID, username, correlationID string) *UserPasswordChanged {
return &UserPasswordChanged{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserPasswordChangedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
},
},
UserID: userID,
Username: username,
}
}
// UserStatusChanged 用户状态变更事件
type UserStatusChanged struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
OldStatus entities.UserStatus `json:"old_status"`
NewStatus entities.UserStatus `json:"new_status"`
}
func NewUserStatusChangedEvent(userID, username string, oldStatus, newStatus entities.UserStatus, correlationID string) *UserStatusChanged {
return &UserStatusChanged{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserStatusChangedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
"old_status": oldStatus,
"new_status": newStatus,
},
},
UserID: userID,
Username: username,
OldStatus: oldStatus,
NewStatus: newStatus,
}
}

View File

@@ -0,0 +1,455 @@
package handlers
import (
"strconv"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/domains/user/dto"
"tyapi-server/internal/domains/user/services"
"tyapi-server/internal/shared/interfaces"
"tyapi-server/internal/shared/middleware"
)
// UserHandler 用户HTTP处理器
type UserHandler struct {
userService *services.UserService
response interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
jwtAuth *middleware.JWTAuthMiddleware
}
// NewUserHandler 创建用户处理器
func NewUserHandler(
userService *services.UserService,
response interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
jwtAuth *middleware.JWTAuthMiddleware,
) *UserHandler {
return &UserHandler{
userService: userService,
response: response,
validator: validator,
logger: logger,
jwtAuth: jwtAuth,
}
}
// GetPath 返回处理器路径
func (h *UserHandler) GetPath() string {
return "/users"
}
// GetMethod 返回HTTP方法
func (h *UserHandler) GetMethod() string {
return "GET" // 主要用于列表,具体方法在路由注册时指定
}
// GetMiddlewares 返回中间件
func (h *UserHandler) GetMiddlewares() []gin.HandlerFunc {
return []gin.HandlerFunc{
// 这里可以添加特定的中间件
}
}
// Handle 主处理函数(用于列表)
func (h *UserHandler) Handle(c *gin.Context) {
h.List(c)
}
// RequiresAuth 是否需要认证
func (h *UserHandler) RequiresAuth() bool {
return true
}
// GetPermissions 获取所需权限
func (h *UserHandler) GetPermissions() []string {
return []string{"user:read"}
}
// REST操作实现
// Create 创建用户
func (h *UserHandler) Create(c *gin.Context) {
var req dto.CreateUserRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return // 响应已在验证器中处理
}
// 创建用户
user, err := h.userService.Create(c.Request.Context(), &req)
if err != nil {
h.logger.Error("Failed to create user", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Created(c, response, "User created successfully")
}
// GetByID 根据ID获取用户
func (h *UserHandler) GetByID(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "User ID is required")
return
}
// 获取用户
user, err := h.userService.GetByID(c.Request.Context(), id)
if err != nil {
h.logger.Error("Failed to get user", zap.Error(err))
h.response.NotFound(c, "User not found")
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Success(c, response)
}
// Update 更新用户
func (h *UserHandler) Update(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "User ID is required")
return
}
var req dto.UpdateUserRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return
}
// 更新用户
user, err := h.userService.Update(c.Request.Context(), id, &req)
if err != nil {
h.logger.Error("Failed to update user", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Success(c, response, "User updated successfully")
}
// Delete 删除用户
func (h *UserHandler) Delete(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "User ID is required")
return
}
// 删除用户
if err := h.userService.Delete(c.Request.Context(), id); err != nil {
h.logger.Error("Failed to delete user", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
h.response.Success(c, nil, "User deleted successfully")
}
// List 获取用户列表
func (h *UserHandler) List(c *gin.Context) {
var req dto.UserListRequest
// 验证查询参数
if err := h.validator.ValidateQuery(c, &req); err != nil {
return
}
// 设置默认值
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 20
}
// 构建查询选项
options := interfaces.ListOptions{
Page: req.Page,
PageSize: req.PageSize,
Sort: req.Sort,
Order: req.Order,
Search: req.Search,
Filters: req.Filters,
}
// 获取用户列表
users, err := h.userService.List(c.Request.Context(), options)
if err != nil {
h.logger.Error("Failed to get user list", zap.Error(err))
h.response.InternalError(c, "Failed to get user list")
return
}
// 获取总数
countOptions := interfaces.CountOptions{
Search: req.Search,
Filters: req.Filters,
}
total, err := h.userService.Count(c.Request.Context(), countOptions)
if err != nil {
h.logger.Error("Failed to count users", zap.Error(err))
h.response.InternalError(c, "Failed to count users")
return
}
// 构建响应
userResponses := dto.FromEntities(users)
pagination := buildPagination(req.Page, req.PageSize, total)
h.response.Paginated(c, userResponses, pagination)
}
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
var req dto.LoginRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return
}
// 用户登录
user, err := h.userService.Login(c.Request.Context(), &req)
if err != nil {
h.logger.Error("Login failed", zap.Error(err))
h.response.Unauthorized(c, "Invalid credentials")
return
}
// 生成JWT token
accessToken, err := h.jwtAuth.GenerateToken(user.ID, user.Username, user.Email)
if err != nil {
h.logger.Error("Failed to generate token", zap.Error(err))
h.response.InternalError(c, "Failed to generate access token")
return
}
// 构建登录响应
loginResponse := &dto.LoginResponse{
User: dto.FromEntity(user),
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: 86400, // 24小时从配置获取
}
h.response.Success(c, loginResponse, "Login successful")
}
// Logout 用户登出
func (h *UserHandler) Logout(c *gin.Context) {
// 简单实现客户端删除token即可
// 如果需要服务端黑名单,可以在这里实现
h.response.Success(c, nil, "Logout successful")
}
// GetProfile 获取当前用户信息
func (h *UserHandler) GetProfile(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "User not authenticated")
return
}
// 获取用户信息
user, err := h.userService.GetByID(c.Request.Context(), userID)
if err != nil {
h.logger.Error("Failed to get user profile", zap.Error(err))
h.response.NotFound(c, "User not found")
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Success(c, response)
}
// UpdateProfile 更新当前用户信息
func (h *UserHandler) UpdateProfile(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "User not authenticated")
return
}
var req dto.UpdateUserRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return
}
// 更新用户
user, err := h.userService.Update(c.Request.Context(), userID, &req)
if err != nil {
h.logger.Error("Failed to update profile", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Success(c, response, "Profile updated successfully")
}
// ChangePassword 修改密码
func (h *UserHandler) ChangePassword(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "User not authenticated")
return
}
var req dto.ChangePasswordRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return
}
// 修改密码
if err := h.userService.ChangePassword(c.Request.Context(), userID, &req); err != nil {
h.logger.Error("Failed to change password", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, nil, "Password changed successfully")
}
// Search 搜索用户
func (h *UserHandler) Search(c *gin.Context) {
var req dto.UserSearchRequest
// 验证查询参数
if err := h.validator.ValidateQuery(c, &req); err != nil {
return
}
// 设置默认值
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 10
}
// 构建查询选项
options := interfaces.ListOptions{
Page: req.Page,
PageSize: req.PageSize,
Search: req.Query,
}
// 搜索用户
users, err := h.userService.Search(c.Request.Context(), req.Query, options)
if err != nil {
h.logger.Error("Failed to search users", zap.Error(err))
h.response.InternalError(c, "Failed to search users")
return
}
// 获取搜索结果总数
countOptions := interfaces.CountOptions{
Search: req.Query,
}
total, err := h.userService.Count(c.Request.Context(), countOptions)
if err != nil {
h.logger.Error("Failed to count search results", zap.Error(err))
h.response.InternalError(c, "Failed to count search results")
return
}
// 构建响应
userResponses := dto.FromEntities(users)
pagination := buildPagination(req.Page, req.PageSize, total)
h.response.Paginated(c, userResponses, pagination)
}
// GetStats 获取用户统计
func (h *UserHandler) GetStats(c *gin.Context) {
stats, err := h.userService.GetStats(c.Request.Context())
if err != nil {
h.logger.Error("Failed to get user stats", zap.Error(err))
h.response.InternalError(c, "Failed to get user statistics")
return
}
h.response.Success(c, stats)
}
// 私有方法
// getCurrentUserID 获取当前用户ID
func (h *UserHandler) getCurrentUserID(c *gin.Context) string {
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}
// parsePageSize 解析页面大小
func (h *UserHandler) parsePageSize(str string, defaultValue int) int {
if str == "" {
return defaultValue
}
if size, err := strconv.Atoi(str); err == nil && size > 0 && size <= 100 {
return size
}
return defaultValue
}
// parsePage 解析页码
func (h *UserHandler) parsePage(str string, defaultValue int) int {
if str == "" {
return defaultValue
}
if page, err := strconv.Atoi(str); err == nil && page > 0 {
return page
}
return defaultValue
}
// buildPagination 构建分页元数据
func buildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
totalPages := int(float64(total) / float64(pageSize))
if float64(total)/float64(pageSize) > float64(totalPages) {
totalPages++
}
if totalPages < 1 {
totalPages = 1
}
return interfaces.PaginationMeta{
Page: page,
PageSize: pageSize,
Total: total,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
}

View File

@@ -0,0 +1,339 @@
package repositories
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/shared/interfaces"
)
// UserRepository 用户仓储实现
type UserRepository struct {
db *gorm.DB
cache interfaces.CacheService
logger *zap.Logger
}
// NewUserRepository 创建用户仓储
func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *UserRepository {
return &UserRepository{
db: db,
cache: cache,
logger: logger,
}
}
// Create 创建用户
func (r *UserRepository) Create(ctx context.Context, entity *entities.User) error {
if err := r.db.WithContext(ctx).Create(entity).Error; err != nil {
r.logger.Error("Failed to create user", zap.Error(err))
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, entity.ID)
return nil
}
// GetByID 根据ID获取用户
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entities.User, error) {
// 先尝试从缓存获取
cacheKey := r.GetCacheKey(id)
var user entities.User
if err := r.cache.Get(ctx, cacheKey, &user); err == nil {
return &user, nil
}
// 从数据库获取
if err := r.db.WithContext(ctx).Where("id = ? AND is_deleted = false", id).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("user not found")
}
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, &user, 1*time.Hour)
return &user, nil
}
// Update 更新用户
func (r *UserRepository) Update(ctx context.Context, entity *entities.User) error {
if err := r.db.WithContext(ctx).Save(entity).Error; err != nil {
r.logger.Error("Failed to update user", zap.Error(err))
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, entity.ID)
return nil
}
// Delete 删除用户
func (r *UserRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("Failed to delete user", zap.Error(err))
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, id)
return nil
}
// CreateBatch 批量创建用户
func (r *UserRepository) CreateBatch(ctx context.Context, entities []*entities.User) error {
if err := r.db.WithContext(ctx).CreateInBatches(entities, 100).Error; err != nil {
r.logger.Error("Failed to create users in batch", zap.Error(err))
return err
}
// 清除列表缓存
r.cache.DeletePattern(ctx, "users:list:*")
return nil
}
// GetByIDs 根据ID列表获取用户
func (r *UserRepository) GetByIDs(ctx context.Context, ids []string) ([]*entities.User, error) {
var users []entities.User
if err := r.db.WithContext(ctx).
Where("id IN ? AND is_deleted = false", ids).
Find(&users).Error; err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.User, len(users))
for i := range users {
result[i] = &users[i]
}
return result, nil
}
// UpdateBatch 批量更新用户
func (r *UserRepository) UpdateBatch(ctx context.Context, entities []*entities.User) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, entity := range entities {
if err := tx.Save(entity).Error; err != nil {
return err
}
}
return nil
})
}
// DeleteBatch 批量删除用户
func (r *UserRepository) DeleteBatch(ctx context.Context, ids []string) error {
if err := r.db.WithContext(ctx).
Where("id IN ?", ids).
Delete(&entities.User{}).Error; err != nil {
return err
}
// 清除相关缓存
for _, id := range ids {
r.invalidateUserCaches(ctx, id)
}
return nil
}
// List 获取用户列表
func (r *UserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]*entities.User, error) {
// 尝试从缓存获取
cacheKey := fmt.Sprintf("users:list:%d:%d:%s", options.Page, options.PageSize, options.Sort)
var users []*entities.User
if err := r.cache.Get(ctx, cacheKey, &users); err == nil {
return users, nil
}
// 从数据库查询
query := r.db.WithContext(ctx).Where("is_deleted = false")
// 应用过滤条件
if options.Search != "" {
query = query.Where("username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用排序
if options.Sort != "" {
order := options.Order
if order == "" {
order = "asc"
}
query = query.Order(fmt.Sprintf("%s %s", options.Sort, order))
} else {
query = query.Order("created_at desc")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
var userEntities []entities.User
if err := query.Find(&userEntities).Error; err != nil {
return nil, err
}
// 转换为指针切片
users = make([]*entities.User, len(userEntities))
for i := range userEntities {
users[i] = &userEntities[i]
}
// 缓存结果
r.cache.Set(ctx, cacheKey, users, 30*time.Minute)
return users, nil
}
// Count 统计用户数量
func (r *UserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
query := r.db.WithContext(ctx).Model(&entities.User{}).Where("is_deleted = false")
// 应用过滤条件
if options.Search != "" {
query = query.Where("username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
var count int64
if err := query.Count(&count).Error; err != nil {
return 0, err
}
return count, nil
}
// Exists 检查用户是否存在
func (r *UserRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ? AND is_deleted = false", id).
Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
// SoftDelete 软删除用户
func (r *UserRepository) SoftDelete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", id).
Update("is_deleted", true).Error; err != nil {
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, id)
return nil
}
// Restore 恢复用户
func (r *UserRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", id).
Update("is_deleted", false).Error; err != nil {
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, id)
return nil
}
// WithTx 使用事务
func (r *UserRepository) WithTx(tx interface{}) interfaces.Repository[*entities.User] {
gormTx, ok := tx.(*gorm.DB)
if !ok {
return r
}
return &UserRepository{
db: gormTx,
cache: r.cache,
logger: r.logger,
}
}
// InvalidateCache 清除缓存
func (r *UserRepository) InvalidateCache(ctx context.Context, keys ...string) error {
return r.cache.Delete(ctx, keys...)
}
// WarmupCache 预热缓存
func (r *UserRepository) WarmupCache(ctx context.Context) error {
// 预热热门用户数据
// 这里可以实现具体的预热逻辑
return nil
}
// GetCacheKey 获取缓存键
func (r *UserRepository) GetCacheKey(id string) string {
return fmt.Sprintf("user:%s", id)
}
// FindByUsername 根据用户名查找用户
func (r *UserRepository) FindByUsername(ctx context.Context, username string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).
Where("username = ? AND is_deleted = false", username).
First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("user not found")
}
return nil, err
}
return &user, nil
}
// FindByEmail 根据邮箱查找用户
func (r *UserRepository) FindByEmail(ctx context.Context, email string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).
Where("email = ? AND is_deleted = false", email).
First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("user not found")
}
return nil, err
}
return &user, nil
}
// invalidateUserCaches 清除用户相关缓存
func (r *UserRepository) invalidateUserCaches(ctx context.Context, userID string) {
keys := []string{
r.GetCacheKey(userID),
}
r.cache.Delete(ctx, keys...)
r.cache.DeletePattern(ctx, "users:list:*")
}

View File

@@ -0,0 +1,133 @@
package routes
import (
"tyapi-server/internal/domains/user/handlers"
"tyapi-server/internal/shared/middleware"
"github.com/gin-gonic/gin"
)
// UserRoutes 用户路由注册器
type UserRoutes struct {
handler *handlers.UserHandler
jwtAuth *middleware.JWTAuthMiddleware
optionalAuth *middleware.OptionalAuthMiddleware
}
// NewUserRoutes 创建用户路由注册器
func NewUserRoutes(
handler *handlers.UserHandler,
jwtAuth *middleware.JWTAuthMiddleware,
optionalAuth *middleware.OptionalAuthMiddleware,
) *UserRoutes {
return &UserRoutes{
handler: handler,
jwtAuth: jwtAuth,
optionalAuth: optionalAuth,
}
}
// RegisterRoutes 注册用户路由
func (r *UserRoutes) RegisterRoutes(router *gin.Engine) {
// API版本组
v1 := router.Group("/api/v1")
// 公开路由(不需要认证)
public := v1.Group("/auth")
{
public.POST("/login", r.handler.Login)
public.POST("/register", r.handler.Create)
}
// 需要认证的路由
protected := v1.Group("/users")
protected.Use(r.jwtAuth.Handle())
{
// 用户管理(管理员)
protected.GET("", r.handler.List)
protected.POST("", r.handler.Create)
protected.GET("/:id", r.handler.GetByID)
protected.PUT("/:id", r.handler.Update)
protected.DELETE("/:id", r.handler.Delete)
// 用户搜索
protected.GET("/search", r.handler.Search)
// 用户统计
protected.GET("/stats", r.handler.GetStats)
}
// 用户个人操作路由
profile := v1.Group("/profile")
profile.Use(r.jwtAuth.Handle())
{
profile.GET("", r.handler.GetProfile)
profile.PUT("", r.handler.UpdateProfile)
profile.POST("/change-password", r.handler.ChangePassword)
profile.POST("/logout", r.handler.Logout)
}
}
// RegisterPublicRoutes 注册公开路由
func (r *UserRoutes) RegisterPublicRoutes(router *gin.Engine) {
v1 := router.Group("/api/v1")
// 公开的用户相关路由
public := v1.Group("/public")
{
// 可选认证的路由(用户可能登录也可能未登录)
public.Use(r.optionalAuth.Handle())
// 这里可以添加一些公开的用户信息查询接口
// 比如根据用户名查看公开信息(如果用户设置为公开)
}
}
// RegisterAdminRoutes 注册管理员路由
func (r *UserRoutes) RegisterAdminRoutes(router *gin.Engine) {
admin := router.Group("/admin/v1")
admin.Use(r.jwtAuth.Handle())
// 这里可以添加管理员权限检查中间件
// 管理员用户管理
users := admin.Group("/users")
{
users.GET("", r.handler.List)
users.GET("/:id", r.handler.GetByID)
users.PUT("/:id", r.handler.Update)
users.DELETE("/:id", r.handler.Delete)
users.GET("/stats", r.handler.GetStats)
users.GET("/search", r.handler.Search)
// 批量操作
users.POST("/batch-delete", r.handleBatchDelete)
users.POST("/batch-update", r.handleBatchUpdate)
}
}
// 批量删除处理器
func (r *UserRoutes) handleBatchDelete(c *gin.Context) {
// 实现批量删除逻辑
// 这里可以接收用户ID列表并调用服务进行批量删除
c.JSON(200, gin.H{"message": "Batch delete not implemented yet"})
}
// 批量更新处理器
func (r *UserRoutes) handleBatchUpdate(c *gin.Context) {
// 实现批量更新逻辑
c.JSON(200, gin.H{"message": "Batch update not implemented yet"})
}
// RegisterHealthRoutes 注册健康检查路由
func (r *UserRoutes) RegisterHealthRoutes(router *gin.Engine) {
health := router.Group("/health")
{
health.GET("/users", func(c *gin.Context) {
// 用户服务健康检查
c.JSON(200, gin.H{
"service": "users",
"status": "healthy",
})
})
}
}

View File

@@ -0,0 +1,469 @@
package services
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
"tyapi-server/internal/domains/user/dto"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/events"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/shared/interfaces"
)
// UserService 用户服务实现
type UserService struct {
repo *repositories.UserRepository
eventBus interfaces.EventBus
logger *zap.Logger
}
// NewUserService 创建用户服务
func NewUserService(
repo *repositories.UserRepository,
eventBus interfaces.EventBus,
logger *zap.Logger,
) *UserService {
return &UserService{
repo: repo,
eventBus: eventBus,
logger: logger,
}
}
// Name 返回服务名称
func (s *UserService) Name() string {
return "user-service"
}
// Initialize 初始化服务
func (s *UserService) Initialize(ctx context.Context) error {
s.logger.Info("User service initialized")
return nil
}
// HealthCheck 健康检查
func (s *UserService) HealthCheck(ctx context.Context) error {
// 简单检查:尝试查询用户数量
_, err := s.repo.Count(ctx, interfaces.CountOptions{})
return err
}
// Shutdown 关闭服务
func (s *UserService) Shutdown(ctx context.Context) error {
s.logger.Info("User service shutdown")
return nil
}
// Create 创建用户
func (s *UserService) Create(ctx context.Context, createDTO interface{}) (*entities.User, error) {
req, ok := createDTO.(*dto.CreateUserRequest)
if !ok {
return nil, fmt.Errorf("invalid DTO type for user creation")
}
// 验证业务规则
if err := s.ValidateCreate(ctx, req); err != nil {
return nil, err
}
// 检查用户名和邮箱是否已存在
if err := s.checkDuplicates(ctx, req.Username, req.Email); err != nil {
return nil, err
}
// 创建用户实体
user := req.ToEntity()
user.ID = uuid.New().String()
// 加密密码
hashedPassword, err := s.hashPassword(req.Password)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
}
user.Password = hashedPassword
// 保存用户
if err := s.repo.Create(ctx, user); err != nil {
s.logger.Error("Failed to create user", zap.Error(err))
return nil, fmt.Errorf("failed to create user: %w", err)
}
// 发布用户创建事件
event := events.NewUserCreatedEvent(user, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user created event", zap.Error(err))
}
s.logger.Info("User created successfully",
zap.String("user_id", user.ID),
zap.String("username", user.Username))
return user, nil
}
// GetByID 根据ID获取用户
func (s *UserService) GetByID(ctx context.Context, id string) (*entities.User, error) {
if id == "" {
return nil, fmt.Errorf("user ID is required")
}
user, err := s.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("user not found: %w", err)
}
return user, nil
}
// Update 更新用户
func (s *UserService) Update(ctx context.Context, id string, updateDTO interface{}) (*entities.User, error) {
req, ok := updateDTO.(*dto.UpdateUserRequest)
if !ok {
return nil, fmt.Errorf("invalid DTO type for user update")
}
// 验证业务规则
if err := s.ValidateUpdate(ctx, id, req); err != nil {
return nil, err
}
// 获取现有用户
user, err := s.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("user not found: %w", err)
}
// 记录变更前的值
oldValues := s.captureUserValues(user)
// 应用更新
s.applyUserUpdates(user, req)
// 保存更新
if err := s.repo.Update(ctx, user); err != nil {
s.logger.Error("Failed to update user", zap.Error(err))
return nil, fmt.Errorf("failed to update user: %w", err)
}
// 发布用户更新事件
newValues := s.captureUserValues(user)
changes := s.findChanges(oldValues, newValues)
if len(changes) > 0 {
event := events.NewUserUpdatedEvent(user.ID, changes, oldValues, newValues, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user updated event", zap.Error(err))
}
}
s.logger.Info("User updated successfully",
zap.String("user_id", user.ID),
zap.Int("changes", len(changes)))
return user, nil
}
// Delete 删除用户
func (s *UserService) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("user ID is required")
}
// 获取用户信息用于事件
user, err := s.repo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("user not found: %w", err)
}
// 软删除用户
if err := s.repo.SoftDelete(ctx, id); err != nil {
s.logger.Error("Failed to delete user", zap.Error(err))
return fmt.Errorf("failed to delete user: %w", err)
}
// 发布用户删除事件
event := events.NewUserDeletedEvent(user.ID, user.Username, user.Email, true, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user deleted event", zap.Error(err))
}
s.logger.Info("User deleted successfully", zap.String("user_id", id))
return nil
}
// List 获取用户列表
func (s *UserService) List(ctx context.Context, options interfaces.ListOptions) ([]*entities.User, error) {
return s.repo.List(ctx, options)
}
// Search 搜索用户
func (s *UserService) Search(ctx context.Context, query string, options interfaces.ListOptions) ([]*entities.User, error) {
// 设置搜索关键字
searchOptions := options
searchOptions.Search = query
return s.repo.List(ctx, searchOptions)
}
// Count 统计用户数量
func (s *UserService) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
return s.repo.Count(ctx, options)
}
// Validate 验证用户实体
func (s *UserService) Validate(ctx context.Context, entity *entities.User) error {
return entity.Validate()
}
// ValidateCreate 验证创建请求
func (s *UserService) ValidateCreate(ctx context.Context, createDTO interface{}) error {
req, ok := createDTO.(*dto.CreateUserRequest)
if !ok {
return fmt.Errorf("invalid DTO type")
}
// 基础验证已经由binding标签处理这里添加业务规则验证
if req.Username == "admin" || req.Username == "root" {
return fmt.Errorf("username '%s' is reserved", req.Username)
}
return nil
}
// ValidateUpdate 验证更新请求
func (s *UserService) ValidateUpdate(ctx context.Context, id string, updateDTO interface{}) error {
_, ok := updateDTO.(*dto.UpdateUserRequest)
if !ok {
return fmt.Errorf("invalid DTO type")
}
if id == "" {
return fmt.Errorf("user ID is required")
}
return nil
}
// 业务方法
// Login 用户登录
func (s *UserService) Login(ctx context.Context, loginReq *dto.LoginRequest) (*entities.User, error) {
// 根据用户名或邮箱查找用户
var user *entities.User
var err error
if s.isEmail(loginReq.Login) {
user, err = s.repo.FindByEmail(ctx, loginReq.Login)
} else {
user, err = s.repo.FindByUsername(ctx, loginReq.Login)
}
if err != nil {
return nil, fmt.Errorf("invalid credentials")
}
// 验证密码
if !s.checkPassword(loginReq.Password, user.Password) {
return nil, fmt.Errorf("invalid credentials")
}
// 检查用户状态
if !user.CanLogin() {
return nil, fmt.Errorf("account is disabled or suspended")
}
// 更新最后登录时间
user.UpdateLastLogin()
if err := s.repo.Update(ctx, user); err != nil {
s.logger.Warn("Failed to update last login time", zap.Error(err))
}
// 发布登录事件
event := events.NewUserLoggedInEvent(
user.ID, user.Username,
s.getClientIP(ctx), s.getUserAgent(ctx),
s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user logged in event", zap.Error(err))
}
s.logger.Info("User logged in successfully",
zap.String("user_id", user.ID),
zap.String("username", user.Username))
return user, nil
}
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error {
// 获取用户
user, err := s.repo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("user not found: %w", err)
}
// 验证旧密码
if !s.checkPassword(req.OldPassword, user.Password) {
return fmt.Errorf("current password is incorrect")
}
// 加密新密码
hashedPassword, err := s.hashPassword(req.NewPassword)
if err != nil {
return fmt.Errorf("failed to hash new password: %w", err)
}
// 更新密码
user.Password = hashedPassword
if err := s.repo.Update(ctx, user); err != nil {
return fmt.Errorf("failed to update password: %w", err)
}
// 发布密码修改事件
event := events.NewUserPasswordChangedEvent(user.ID, user.Username, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish password changed event", zap.Error(err))
}
s.logger.Info("Password changed successfully", zap.String("user_id", userID))
return nil
}
// GetStats 获取用户统计
func (s *UserService) GetStats(ctx context.Context) (*dto.UserStatsResponse, error) {
total, err := s.repo.Count(ctx, interfaces.CountOptions{})
if err != nil {
return nil, err
}
// 这里可以并行查询不同状态的用户数量
// 简化实现,返回基础统计
return &dto.UserStatsResponse{
TotalUsers: total,
ActiveUsers: total, // 简化
InactiveUsers: 0,
SuspendedUsers: 0,
NewUsersToday: 0,
NewUsersWeek: 0,
NewUsersMonth: 0,
}, nil
}
// 私有方法
// checkDuplicates 检查重复的用户名和邮箱
func (s *UserService) checkDuplicates(ctx context.Context, username, email string) error {
// 检查用户名
if existingUser, err := s.repo.FindByUsername(ctx, username); err == nil && existingUser != nil {
return fmt.Errorf("username already exists")
}
// 检查邮箱
if existingUser, err := s.repo.FindByEmail(ctx, email); err == nil && existingUser != nil {
return fmt.Errorf("email already exists")
}
return nil
}
// hashPassword 加密密码
func (s *UserService) hashPassword(password string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hash), nil
}
// checkPassword 验证密码
func (s *UserService) checkPassword(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
}
// isEmail 检查是否为邮箱格式
func (s *UserService) isEmail(str string) bool {
return len(str) > 0 && len(str) < 255 &&
len(str) > 5 &&
str[len(str)-4:] != ".." &&
(len(str) > 6 && str[len(str)-4:] == ".com") ||
(len(str) > 5 && str[len(str)-3:] == ".cn") ||
(len(str) > 6 && str[len(str)-4:] == ".org") ||
(len(str) > 6 && str[len(str)-4:] == ".net")
// 简化的邮箱检查,实际应该使用正则表达式
}
// applyUserUpdates 应用用户更新
func (s *UserService) applyUserUpdates(user *entities.User, req *dto.UpdateUserRequest) {
if req.FirstName != nil {
user.FirstName = *req.FirstName
}
if req.LastName != nil {
user.LastName = *req.LastName
}
if req.Phone != nil {
user.Phone = *req.Phone
}
if req.Avatar != nil {
user.Avatar = *req.Avatar
}
user.UpdatedAt = time.Now()
}
// captureUserValues 捕获用户值用于变更比较
func (s *UserService) captureUserValues(user *entities.User) map[string]interface{} {
return map[string]interface{}{
"first_name": user.FirstName,
"last_name": user.LastName,
"phone": user.Phone,
"avatar": user.Avatar,
}
}
// findChanges 找出变更的字段
func (s *UserService) findChanges(oldValues, newValues map[string]interface{}) map[string]interface{} {
changes := make(map[string]interface{})
for key, newValue := range newValues {
if oldValue, exists := oldValues[key]; !exists || oldValue != newValue {
changes[key] = newValue
}
}
return changes
}
// getCorrelationID 获取关联ID
func (s *UserService) getCorrelationID(ctx context.Context) string {
if id := ctx.Value("correlation_id"); id != nil {
if correlationID, ok := id.(string); ok {
return correlationID
}
}
return uuid.New().String()
}
// getClientIP 获取客户端IP
func (s *UserService) getClientIP(ctx context.Context) string {
if ip := ctx.Value("client_ip"); ip != nil {
if clientIP, ok := ip.(string); ok {
return clientIP
}
}
return "unknown"
}
// getUserAgent 获取用户代理
func (s *UserService) getUserAgent(ctx context.Context) string {
if ua := ctx.Value("user_agent"); ua != nil {
if userAgent, ok := ua.(string); ok {
return userAgent
}
}
return "unknown"
}

284
internal/shared/cache/redis_cache.go vendored Normal file
View File

@@ -0,0 +1,284 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// RedisCache Redis缓存实现
type RedisCache struct {
client *redis.Client
logger *zap.Logger
prefix string
// 统计信息
hits int64
misses int64
}
// NewRedisCache 创建Redis缓存实例
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
return &RedisCache{
client: client,
logger: logger,
prefix: prefix,
}
}
// Name 返回服务名称
func (r *RedisCache) Name() string {
return "redis-cache"
}
// Initialize 初始化服务
func (r *RedisCache) Initialize(ctx context.Context) error {
// 测试连接
_, err := r.client.Ping(ctx).Result()
if err != nil {
r.logger.Error("Failed to connect to Redis", zap.Error(err))
return fmt.Errorf("redis connection failed: %w", err)
}
r.logger.Info("Redis cache service initialized")
return nil
}
// HealthCheck 健康检查
func (r *RedisCache) HealthCheck(ctx context.Context) error {
_, err := r.client.Ping(ctx).Result()
return err
}
// Shutdown 关闭服务
func (r *RedisCache) Shutdown(ctx context.Context) error {
return r.client.Close()
}
// Get 获取缓存值
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
fullKey := r.getFullKey(key)
val, err := r.client.Get(ctx, fullKey).Result()
if err != nil {
if err == redis.Nil {
r.misses++
return fmt.Errorf("cache miss: key %s not found", key)
}
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
return err
}
r.hits++
return json.Unmarshal([]byte(val), dest)
}
// Set 设置缓存值
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
fullKey := r.getFullKey(key)
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
case string:
expiration, _ = time.ParseDuration(v)
default:
expiration = 24 * time.Hour // 默认24小时
}
} else {
expiration = 24 * time.Hour // 默认24小时
}
err = r.client.Set(ctx, fullKey, data, expiration).Err()
if err != nil {
r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
return err
}
return nil
}
// Delete 删除缓存
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
err := r.client.Del(ctx, fullKeys...).Err()
if err != nil {
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
return err
}
return nil
}
// Exists 检查键是否存在
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
fullKey := r.getFullKey(key)
count, err := r.client.Exists(ctx, fullKey).Result()
if err != nil {
return false, err
}
return count > 0, nil
}
// GetMultiple 批量获取
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
if len(keys) == 0 {
return make(map[string]interface{}), nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
values, err := r.client.MGet(ctx, fullKeys...).Result()
if err != nil {
return nil, err
}
result := make(map[string]interface{})
for i, val := range values {
if val != nil {
var data interface{}
if err := json.Unmarshal([]byte(val.(string)), &data); err == nil {
result[keys[i]] = data
}
}
}
return result, nil
}
// SetMultiple 批量设置
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
if len(data) == 0 {
return nil
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
default:
expiration = 24 * time.Hour
}
} else {
expiration = 24 * time.Hour
}
pipe := r.client.Pipeline()
for key, value := range data {
fullKey := r.getFullKey(key)
jsonData, err := json.Marshal(value)
if err != nil {
continue
}
pipe.Set(ctx, fullKey, jsonData, expiration)
}
_, err := pipe.Exec(ctx)
return err
}
// DeletePattern 按模式删除
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
fullPattern := r.getFullKey(pattern)
keys, err := r.client.Keys(ctx, fullPattern).Result()
if err != nil {
return err
}
if len(keys) > 0 {
return r.client.Del(ctx, keys...).Err()
}
return nil
}
// Keys 获取匹配的键
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
fullPattern := r.getFullKey(pattern)
keys, err := r.client.Keys(ctx, fullPattern).Result()
if err != nil {
return nil, err
}
// 移除前缀
result := make([]string, len(keys))
prefixLen := len(r.prefix) + 1 // +1 for ":"
for i, key := range keys {
if len(key) > prefixLen {
result[i] = key[prefixLen:]
} else {
result[i] = key
}
}
return result, nil
}
// Stats 获取缓存统计
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
dbSize, _ := r.client.DBSize(ctx).Result()
return interfaces.CacheStats{
Hits: r.hits,
Misses: r.misses,
Keys: dbSize,
Memory: 0, // 暂时设为0后续可解析Redis info
Connections: 0, // 暂时设为0后续可解析Redis info
}, nil
}
// getFullKey 获取完整键名
func (r *RedisCache) getFullKey(key string) string {
if r.prefix == "" {
return key
}
return fmt.Sprintf("%s:%s", r.prefix, key)
}
// Flush 清空所有缓存
func (r *RedisCache) Flush(ctx context.Context) error {
if r.prefix == "" {
return r.client.FlushDB(ctx).Err()
}
// 只删除带前缀的键
return r.DeletePattern(ctx, "*")
}
// GetClient 获取原始Redis客户端
func (r *RedisCache) GetClient() *redis.Client {
return r.client
}

View File

@@ -0,0 +1,195 @@
package database
import (
"context"
"fmt"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// Config 数据库配置
type Config struct {
Host string
Port string
User string
Password string
Name string
SSLMode string
Timezone string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
}
// DB 数据库包装器
type DB struct {
*gorm.DB
config Config
}
// NewConnection 创建新的数据库连接
func NewConnection(config Config) (*DB, error) {
// 构建DSN
dsn := buildDSN(config)
// 配置GORM
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
NamingStrategy: schema.NamingStrategy{
SingularTable: true, // 使用单数表名
},
DisableForeignKeyConstraintWhenMigrating: true,
}
// 连接数据库
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
if err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
// 获取底层sql.DB
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
}
// 配置连接池
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
}
return &DB{
DB: db,
config: config,
}, nil
}
// buildDSN 构建数据库连接字符串
func buildDSN(config Config) string {
return fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s",
config.Host,
config.User,
config.Password,
config.Name,
config.Port,
config.SSLMode,
config.Timezone,
)
}
// Close 关闭数据库连接
func (db *DB) Close() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// Ping 检查数据库连接
func (db *DB) Ping() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}
// GetStats 获取连接池统计信息
func (db *DB) GetStats() (map[string]interface{}, error) {
sqlDB, err := db.DB.DB()
if err != nil {
return nil, err
}
stats := sqlDB.Stats()
return map[string]interface{}{
"max_open_connections": stats.MaxOpenConnections,
"open_connections": stats.OpenConnections,
"in_use": stats.InUse,
"idle": stats.Idle,
"wait_count": stats.WaitCount,
"wait_duration": stats.WaitDuration,
"max_idle_closed": stats.MaxIdleClosed,
"max_idle_time_closed": stats.MaxIdleTimeClosed,
"max_lifetime_closed": stats.MaxLifetimeClosed,
}, nil
}
// BeginTx 开始事务
func (db *DB) BeginTx() *gorm.DB {
return db.DB.Begin()
}
// Migrate 执行数据库迁移
func (db *DB) Migrate(models ...interface{}) error {
return db.DB.AutoMigrate(models...)
}
// IsHealthy 检查数据库健康状态
func (db *DB) IsHealthy() bool {
return db.Ping() == nil
}
// WithContext 返回带上下文的数据库实例
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
if c, ok := ctx.(context.Context); ok {
return db.DB.WithContext(c)
}
return db.DB
}
// 事务包装器
type TxWrapper struct {
tx *gorm.DB
}
// NewTxWrapper 创建事务包装器
func (db *DB) NewTxWrapper() *TxWrapper {
return &TxWrapper{
tx: db.BeginTx(),
}
}
// Commit 提交事务
func (tx *TxWrapper) Commit() error {
return tx.tx.Commit().Error
}
// Rollback 回滚事务
func (tx *TxWrapper) Rollback() error {
return tx.tx.Rollback().Error
}
// GetDB 获取事务数据库实例
func (tx *TxWrapper) GetDB() *gorm.DB {
return tx.tx
}
// WithTx 在事务中执行函数
func (db *DB) WithTx(fn func(*gorm.DB) error) error {
tx := db.BeginTx()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}

View File

@@ -0,0 +1,313 @@
package events
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// MemoryEventBus 内存事件总线实现
type MemoryEventBus struct {
subscribers map[string][]interfaces.EventHandler
mutex sync.RWMutex
logger *zap.Logger
running bool
stopCh chan struct{}
eventQueue chan eventTask
workerCount int
}
// eventTask 事件任务
type eventTask struct {
event interfaces.Event
handler interfaces.EventHandler
retries int
}
// NewMemoryEventBus 创建内存事件总线
func NewMemoryEventBus(logger *zap.Logger, workerCount int) *MemoryEventBus {
if workerCount <= 0 {
workerCount = 5 // 默认5个工作协程
}
return &MemoryEventBus{
subscribers: make(map[string][]interfaces.EventHandler),
logger: logger,
eventQueue: make(chan eventTask, 1000), // 缓冲1000个事件
workerCount: workerCount,
stopCh: make(chan struct{}),
}
}
// Name 返回服务名称
func (bus *MemoryEventBus) Name() string {
return "memory-event-bus"
}
// Initialize 初始化服务
func (bus *MemoryEventBus) Initialize(ctx context.Context) error {
bus.logger.Info("Memory event bus service initialized")
return nil
}
// HealthCheck 健康检查
func (bus *MemoryEventBus) HealthCheck(ctx context.Context) error {
if !bus.running {
return fmt.Errorf("event bus is not running")
}
return nil
}
// Shutdown 关闭服务
func (bus *MemoryEventBus) Shutdown(ctx context.Context) error {
bus.Stop(ctx)
return nil
}
// Start 启动事件总线
func (bus *MemoryEventBus) Start(ctx context.Context) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
if bus.running {
return nil
}
bus.running = true
// 启动工作协程
for i := 0; i < bus.workerCount; i++ {
go bus.worker(i)
}
bus.logger.Info("Event bus started", zap.Int("workers", bus.workerCount))
return nil
}
// Stop 停止事件总线
func (bus *MemoryEventBus) Stop(ctx context.Context) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
if !bus.running {
return nil
}
bus.running = false
close(bus.stopCh)
// 等待所有工作协程结束或超时
done := make(chan struct{})
go func() {
time.Sleep(5 * time.Second) // 给工作协程5秒时间结束
close(done)
}()
select {
case <-done:
case <-ctx.Done():
}
bus.logger.Info("Event bus stopped")
return nil
}
// Publish 发布事件(同步)
func (bus *MemoryEventBus) Publish(ctx context.Context, event interfaces.Event) error {
bus.mutex.RLock()
handlers := bus.subscribers[event.GetType()]
bus.mutex.RUnlock()
if len(handlers) == 0 {
bus.logger.Debug("No handlers for event type", zap.String("type", event.GetType()))
return nil
}
for _, handler := range handlers {
if handler.IsAsync() {
// 异步处理
select {
case bus.eventQueue <- eventTask{event: event, handler: handler, retries: 0}:
default:
bus.logger.Warn("Event queue is full, dropping event",
zap.String("type", event.GetType()),
zap.String("handler", handler.GetName()))
}
} else {
// 同步处理
if err := bus.handleEventWithRetry(ctx, event, handler); err != nil {
bus.logger.Error("Failed to handle event synchronously",
zap.String("type", event.GetType()),
zap.String("handler", handler.GetName()),
zap.Error(err))
}
}
}
return nil
}
// PublishBatch 批量发布事件
func (bus *MemoryEventBus) PublishBatch(ctx context.Context, events []interfaces.Event) error {
for _, event := range events {
if err := bus.Publish(ctx, event); err != nil {
return err
}
}
return nil
}
// Subscribe 订阅事件
func (bus *MemoryEventBus) Subscribe(eventType string, handler interfaces.EventHandler) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
handlers := bus.subscribers[eventType]
// 检查是否已经订阅
for _, h := range handlers {
if h.GetName() == handler.GetName() {
return fmt.Errorf("handler %s already subscribed to event type %s", handler.GetName(), eventType)
}
}
bus.subscribers[eventType] = append(handlers, handler)
bus.logger.Info("Handler subscribed to event",
zap.String("handler", handler.GetName()),
zap.String("event_type", eventType))
return nil
}
// Unsubscribe 取消订阅
func (bus *MemoryEventBus) Unsubscribe(eventType string, handler interfaces.EventHandler) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
handlers := bus.subscribers[eventType]
for i, h := range handlers {
if h.GetName() == handler.GetName() {
// 删除处理器
bus.subscribers[eventType] = append(handlers[:i], handlers[i+1:]...)
bus.logger.Info("Handler unsubscribed from event",
zap.String("handler", handler.GetName()),
zap.String("event_type", eventType))
return nil
}
}
return fmt.Errorf("handler %s not found for event type %s", handler.GetName(), eventType)
}
// GetSubscribers 获取订阅者
func (bus *MemoryEventBus) GetSubscribers(eventType string) []interfaces.EventHandler {
bus.mutex.RLock()
defer bus.mutex.RUnlock()
handlers := bus.subscribers[eventType]
result := make([]interfaces.EventHandler, len(handlers))
copy(result, handlers)
return result
}
// worker 工作协程
func (bus *MemoryEventBus) worker(id int) {
bus.logger.Debug("Event worker started", zap.Int("worker_id", id))
for {
select {
case task := <-bus.eventQueue:
bus.processEventTask(task)
case <-bus.stopCh:
bus.logger.Debug("Event worker stopped", zap.Int("worker_id", id))
return
}
}
}
// processEventTask 处理事件任务
func (bus *MemoryEventBus) processEventTask(task eventTask) {
ctx := context.Background()
err := bus.handleEventWithRetry(ctx, task.event, task.handler)
if err != nil {
retryConfig := task.handler.GetRetryConfig()
if task.retries < retryConfig.MaxRetries {
// 重试
delay := time.Duration(float64(retryConfig.RetryDelay) *
(1 + retryConfig.BackoffFactor*float64(task.retries)))
if delay > retryConfig.MaxDelay {
delay = retryConfig.MaxDelay
}
go func() {
time.Sleep(delay)
task.retries++
select {
case bus.eventQueue <- task:
default:
bus.logger.Error("Failed to requeue event for retry",
zap.String("type", task.event.GetType()),
zap.String("handler", task.handler.GetName()),
zap.Int("retries", task.retries))
}
}()
} else {
bus.logger.Error("Event processing failed after max retries",
zap.String("type", task.event.GetType()),
zap.String("handler", task.handler.GetName()),
zap.Int("retries", task.retries),
zap.Error(err))
}
}
}
// handleEventWithRetry 处理事件并支持重试
func (bus *MemoryEventBus) handleEventWithRetry(ctx context.Context, event interfaces.Event, handler interfaces.EventHandler) error {
start := time.Now()
defer func() {
duration := time.Since(start)
bus.logger.Debug("Event handled",
zap.String("type", event.GetType()),
zap.String("handler", handler.GetName()),
zap.Duration("duration", duration))
}()
return handler.Handle(ctx, event)
}
// GetStats 获取事件总线统计信息
func (bus *MemoryEventBus) GetStats() map[string]interface{} {
bus.mutex.RLock()
defer bus.mutex.RUnlock()
stats := map[string]interface{}{
"running": bus.running,
"worker_count": bus.workerCount,
"queue_length": len(bus.eventQueue),
"queue_capacity": cap(bus.eventQueue),
"event_types": len(bus.subscribers),
}
// 各事件类型的订阅者数量
eventTypes := make(map[string]int)
for eventType, handlers := range bus.subscribers {
eventTypes[eventType] = len(handlers)
}
stats["subscribers"] = eventTypes
return stats
}

View File

@@ -0,0 +1,282 @@
package health
import (
"context"
"fmt"
"sync"
"time"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
)
// HealthChecker 健康检查器实现
type HealthChecker struct {
services map[string]interfaces.Service
cache map[string]*interfaces.HealthStatus
cacheTTL time.Duration
mutex sync.RWMutex
logger *zap.Logger
}
// NewHealthChecker 创建健康检查器
func NewHealthChecker(logger *zap.Logger) *HealthChecker {
return &HealthChecker{
services: make(map[string]interfaces.Service),
cache: make(map[string]*interfaces.HealthStatus),
cacheTTL: 30 * time.Second, // 缓存30秒
logger: logger,
}
}
// RegisterService 注册服务
func (h *HealthChecker) RegisterService(service interfaces.Service) {
h.mutex.Lock()
defer h.mutex.Unlock()
h.services[service.Name()] = service
h.logger.Info("Registered service for health check", zap.String("service", service.Name()))
}
// CheckHealth 检查单个服务健康状态
func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *interfaces.HealthStatus {
h.mutex.RLock()
service, exists := h.services[serviceName]
if !exists {
h.mutex.RUnlock()
return &interfaces.HealthStatus{
Status: "DOWN",
Message: "Service not found",
Details: map[string]interface{}{"error": "service not registered"},
CheckedAt: time.Now().Unix(),
ResponseTime: 0,
}
}
// 检查缓存
if cached, exists := h.cache[serviceName]; exists {
if time.Since(time.Unix(cached.CheckedAt, 0)) < h.cacheTTL {
h.mutex.RUnlock()
return cached
}
}
h.mutex.RUnlock()
// 执行健康检查
start := time.Now()
status := &interfaces.HealthStatus{
CheckedAt: start.Unix(),
}
// 设置超时上下文
checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
err := service.HealthCheck(checkCtx)
responseTime := time.Since(start).Milliseconds()
status.ResponseTime = responseTime
if err != nil {
status.Status = "DOWN"
status.Message = "Health check failed"
status.Details = map[string]interface{}{
"error": err.Error(),
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Warn("Service health check failed",
zap.String("service", serviceName),
zap.Error(err),
zap.Int64("response_time_ms", responseTime))
} else {
status.Status = "UP"
status.Message = "Service is healthy"
status.Details = map[string]interface{}{
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Debug("Service health check passed",
zap.String("service", serviceName),
zap.Int64("response_time_ms", responseTime))
}
// 更新缓存
h.mutex.Lock()
h.cache[serviceName] = status
h.mutex.Unlock()
return status
}
// CheckAllHealth 检查所有服务的健康状态
func (h *HealthChecker) CheckAllHealth(ctx context.Context) map[string]*interfaces.HealthStatus {
h.mutex.RLock()
serviceNames := make([]string, 0, len(h.services))
for name := range h.services {
serviceNames = append(serviceNames, name)
}
h.mutex.RUnlock()
results := make(map[string]*interfaces.HealthStatus)
var wg sync.WaitGroup
var mutex sync.Mutex
// 并发检查所有服务
for _, serviceName := range serviceNames {
wg.Add(1)
go func(name string) {
defer wg.Done()
status := h.CheckHealth(ctx, name)
mutex.Lock()
results[name] = status
mutex.Unlock()
}(serviceName)
}
wg.Wait()
return results
}
// GetOverallStatus 获取整体健康状态
func (h *HealthChecker) GetOverallStatus(ctx context.Context) *interfaces.HealthStatus {
allStatus := h.CheckAllHealth(ctx)
overall := &interfaces.HealthStatus{
CheckedAt: time.Now().Unix(),
ResponseTime: 0,
Details: make(map[string]interface{}),
}
var totalResponseTime int64
healthyCount := 0
totalCount := len(allStatus)
for serviceName, status := range allStatus {
overall.Details[serviceName] = map[string]interface{}{
"status": status.Status,
"message": status.Message,
"response_time": status.ResponseTime,
}
totalResponseTime += status.ResponseTime
if status.Status == "UP" {
healthyCount++
}
}
if totalCount > 0 {
overall.ResponseTime = totalResponseTime / int64(totalCount)
}
// 确定整体状态
if healthyCount == totalCount {
overall.Status = "UP"
overall.Message = "All services are healthy"
} else if healthyCount == 0 {
overall.Status = "DOWN"
overall.Message = "All services are down"
} else {
overall.Status = "DEGRADED"
overall.Message = fmt.Sprintf("%d of %d services are healthy", healthyCount, totalCount)
}
return overall
}
// GetServiceNames 获取所有注册的服务名称
func (h *HealthChecker) GetServiceNames() []string {
h.mutex.RLock()
defer h.mutex.RUnlock()
names := make([]string, 0, len(h.services))
for name := range h.services {
names = append(names, name)
}
return names
}
// RemoveService 移除服务
func (h *HealthChecker) RemoveService(serviceName string) {
h.mutex.Lock()
defer h.mutex.Unlock()
delete(h.services, serviceName)
delete(h.cache, serviceName)
h.logger.Info("Removed service from health check", zap.String("service", serviceName))
}
// ClearCache 清除缓存
func (h *HealthChecker) ClearCache() {
h.mutex.Lock()
defer h.mutex.Unlock()
h.cache = make(map[string]*interfaces.HealthStatus)
h.logger.Debug("Health check cache cleared")
}
// GetCacheStats 获取缓存统计
func (h *HealthChecker) GetCacheStats() map[string]interface{} {
h.mutex.RLock()
defer h.mutex.RUnlock()
stats := map[string]interface{}{
"total_services": len(h.services),
"cached_results": len(h.cache),
"cache_ttl_seconds": h.cacheTTL.Seconds(),
}
// 计算缓存命中率
if len(h.services) > 0 {
hitRate := float64(len(h.cache)) / float64(len(h.services)) * 100
stats["cache_hit_rate"] = fmt.Sprintf("%.2f%%", hitRate)
}
return stats
}
// SetCacheTTL 设置缓存TTL
func (h *HealthChecker) SetCacheTTL(ttl time.Duration) {
h.mutex.Lock()
defer h.mutex.Unlock()
h.cacheTTL = ttl
h.logger.Info("Updated health check cache TTL", zap.Duration("ttl", ttl))
}
// StartPeriodicCheck 启动定期健康检查
func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
h.logger.Info("Started periodic health check", zap.Duration("interval", interval))
for {
select {
case <-ctx.Done():
h.logger.Info("Stopped periodic health check")
return
case <-ticker.C:
h.performPeriodicCheck(ctx)
}
}
}
// performPeriodicCheck 执行定期检查
func (h *HealthChecker) performPeriodicCheck(ctx context.Context) {
overall := h.GetOverallStatus(ctx)
h.logger.Info("Periodic health check completed",
zap.String("overall_status", overall.Status),
zap.String("message", overall.Message),
zap.Int64("response_time_ms", overall.ResponseTime))
// 如果有服务下线,记录警告
if overall.Status != "UP" {
h.logger.Warn("Some services are not healthy",
zap.String("status", overall.Status),
zap.Any("details", overall.Details))
}
}

View File

@@ -0,0 +1,260 @@
package http
import (
"math"
"net/http"
"time"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
)
// ResponseBuilder 响应构建器实现
type ResponseBuilder struct{}
// NewResponseBuilder 创建响应构建器
func NewResponseBuilder() interfaces.ResponseBuilder {
return &ResponseBuilder{}
}
// Success 成功响应
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
msg := "Success"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: true,
Message: msg,
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusOK, response)
}
// Created 创建成功响应
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
msg := "Created successfully"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: true,
Message: msg,
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusCreated, response)
}
// Error 错误响应
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
// 根据错误类型确定状态码
statusCode := http.StatusInternalServerError
message := "Internal server error"
errorDetail := err.Error()
// 这里可以根据不同的错误类型设置不同的状态码
// 例如ValidationError -> 400, NotFoundError -> 404, etc.
response := interfaces.APIResponse{
Success: false,
Message: message,
Errors: errorDetail,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(statusCode, response)
}
// BadRequest 400错误响应
func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...interface{}) {
response := interfaces.APIResponse{
Success: false,
Message: message,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
if len(errors) > 0 {
response.Errors = errors[0]
}
c.JSON(http.StatusBadRequest, response)
}
// Unauthorized 401错误响应
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
msg := "Unauthorized"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusUnauthorized, response)
}
// Forbidden 403错误响应
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
msg := "Forbidden"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusForbidden, response)
}
// NotFound 404错误响应
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
msg := "Resource not found"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusNotFound, response)
}
// Conflict 409错误响应
func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
response := interfaces.APIResponse{
Success: false,
Message: message,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusConflict, response)
}
// InternalError 500错误响应
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
msg := "Internal server error"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusInternalServerError, response)
}
// Paginated 分页响应
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
response := interfaces.APIResponse{
Success: true,
Message: "Success",
Data: data,
Pagination: &pagination,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusOK, response)
}
// getRequestID 从上下文获取请求ID
func (r *ResponseBuilder) getRequestID(c *gin.Context) string {
if requestID, exists := c.Get("request_id"); exists {
if id, ok := requestID.(string); ok {
return id
}
}
return ""
}
// BuildPagination 构建分页元数据
func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
if totalPages < 1 {
totalPages = 1
}
return interfaces.PaginationMeta{
Page: page,
PageSize: pageSize,
Total: total,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
}
// CustomResponse 自定义响应
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
response := interfaces.APIResponse{
Success: statusCode >= 200 && statusCode < 300,
Message: http.StatusText(statusCode),
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(statusCode, response)
}
// ValidationError 验证错误响应
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
response := interfaces.APIResponse{
Success: false,
Message: "Validation failed",
Errors: errors,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusUnprocessableEntity, response)
}
// TooManyRequests 限流错误响应
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
msg := "Too many requests"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
Meta: map[string]interface{}{
"retry_after": "60s",
},
}
c.JSON(http.StatusTooManyRequests, response)
}

View File

@@ -0,0 +1,258 @@
package http
import (
"context"
"fmt"
"net/http"
"sort"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/config"
"tyapi-server/internal/shared/interfaces"
)
// GinRouter Gin路由器实现
type GinRouter struct {
engine *gin.Engine
config *config.Config
logger *zap.Logger
middlewares []interfaces.Middleware
server *http.Server
}
// NewGinRouter 创建Gin路由器
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *GinRouter {
// 设置Gin模式
if cfg.App.IsProduction() {
gin.SetMode(gin.ReleaseMode)
} else {
gin.SetMode(gin.DebugMode)
}
// 创建Gin引擎
engine := gin.New()
return &GinRouter{
engine: engine,
config: cfg,
logger: logger,
middlewares: make([]interfaces.Middleware, 0),
}
}
// RegisterHandler 注册处理器
func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
// 应用处理器中间件
middlewares := handler.GetMiddlewares()
// 注册路由
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
r.logger.Info("Registered HTTP handler",
zap.String("method", handler.GetMethod()),
zap.String("path", handler.GetPath()))
return nil
}
// RegisterMiddleware 注册中间件
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
r.middlewares = append(r.middlewares, middleware)
r.logger.Info("Registered middleware",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
return nil
}
// RegisterGroup 注册路由组
func (r *GinRouter) RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes {
return r.engine.Group(prefix, middlewares...)
}
// GetRoutes 获取路由信息
func (r *GinRouter) GetRoutes() gin.RoutesInfo {
return r.engine.Routes()
}
// Start 启动路由器
func (r *GinRouter) Start(addr string) error {
// 应用中间件(按优先级排序)
r.applyMiddlewares()
// 创建HTTP服务器
r.server = &http.Server{
Addr: addr,
Handler: r.engine,
ReadTimeout: r.config.Server.ReadTimeout,
WriteTimeout: r.config.Server.WriteTimeout,
IdleTimeout: r.config.Server.IdleTimeout,
}
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
// 启动服务器
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("failed to start server: %w", err)
}
return nil
}
// Stop 停止路由器
func (r *GinRouter) Stop(ctx context.Context) error {
if r.server == nil {
return nil
}
r.logger.Info("Stopping HTTP server...")
// 优雅关闭服务器
if err := r.server.Shutdown(ctx); err != nil {
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
return err
}
r.logger.Info("HTTP server stopped")
return nil
}
// GetEngine 获取Gin引擎
func (r *GinRouter) GetEngine() *gin.Engine {
return r.engine
}
// applyMiddlewares 应用中间件
func (r *GinRouter) applyMiddlewares() {
// 按优先级排序中间件
sort.Slice(r.middlewares, func(i, j int) bool {
return r.middlewares[i].GetPriority() > r.middlewares[j].GetPriority()
})
// 应用全局中间件
for _, middleware := range r.middlewares {
if middleware.IsGlobal() {
r.engine.Use(middleware.Handle())
r.logger.Debug("Applied global middleware",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
}
}
}
// SetupDefaultRoutes 设置默认路由
func (r *GinRouter) SetupDefaultRoutes() {
// 健康检查
r.engine.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"timestamp": time.Now().Unix(),
"service": r.config.App.Name,
"version": r.config.App.Version,
})
})
// API信息
r.engine.GET("/info", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"name": r.config.App.Name,
"version": r.config.App.Version,
"environment": r.config.App.Env,
"timestamp": time.Now().Unix(),
})
})
// 404处理
r.engine.NoRoute(func(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Route not found",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),
})
})
// 405处理
r.engine.NoMethod(func(c *gin.Context) {
c.JSON(http.StatusMethodNotAllowed, gin.H{
"success": false,
"message": "Method not allowed",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),
})
})
}
// PrintRoutes 打印路由信息
func (r *GinRouter) PrintRoutes() {
routes := r.GetRoutes()
r.logger.Info("Registered routes:")
for _, route := range routes {
r.logger.Info("Route",
zap.String("method", route.Method),
zap.String("path", route.Path),
zap.String("handler", route.Handler))
}
}
// GetStats 获取路由器统计信息
func (r *GinRouter) GetStats() map[string]interface{} {
routes := r.GetRoutes()
stats := map[string]interface{}{
"total_routes": len(routes),
"total_middlewares": len(r.middlewares),
"server_config": map[string]interface{}{
"read_timeout": r.config.Server.ReadTimeout,
"write_timeout": r.config.Server.WriteTimeout,
"idle_timeout": r.config.Server.IdleTimeout,
},
}
// 按方法统计路由数量
methodStats := make(map[string]int)
for _, route := range routes {
methodStats[route.Method]++
}
stats["routes_by_method"] = methodStats
// 中间件统计
middlewareStats := make([]map[string]interface{}, 0, len(r.middlewares))
for _, middleware := range r.middlewares {
middlewareStats = append(middlewareStats, map[string]interface{}{
"name": middleware.GetName(),
"priority": middleware.GetPriority(),
"global": middleware.IsGlobal(),
})
}
stats["middlewares"] = middlewareStats
return stats
}
// EnableMetrics 启用指标收集
func (r *GinRouter) EnableMetrics(collector interfaces.MetricsCollector) {
r.engine.Use(func(c *gin.Context) {
start := time.Now()
c.Next()
duration := time.Since(start).Seconds()
collector.RecordHTTPRequest(c.Request.Method, c.FullPath(), c.Writer.Status(), duration)
})
}
// EnableProfiling 启用性能分析
func (r *GinRouter) EnableProfiling() {
if r.config.Development.EnableProfiler {
// 这里可以集成pprof
r.logger.Info("Profiling enabled")
}
}

View File

@@ -0,0 +1,273 @@
package http
import (
"fmt"
"strings"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
// RequestValidator 请求验证器实现
type RequestValidator struct {
validator *validator.Validate
response interfaces.ResponseBuilder
}
// NewRequestValidator 创建请求验证器
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
v := validator.New()
// 注册自定义验证器
registerCustomValidators(v)
return &RequestValidator{
validator: v,
response: response,
}
}
// Validate 验证请求体
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
return err
}
return nil
}
// ValidateQuery 验证查询参数
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindQuery(dto); err != nil {
v.response.BadRequest(c, "Invalid query parameters", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
return err
}
return nil
}
// ValidateParam 验证路径参数
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindUri(dto); err != nil {
v.response.BadRequest(c, "Invalid path parameters", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
return err
}
return nil
}
// BindAndValidate 绑定并验证请求
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
// 绑定请求体
if err := c.ShouldBindJSON(dto); err != nil {
v.response.BadRequest(c, "Invalid request body", err.Error())
return err
}
// 验证数据
return v.Validate(c, dto)
}
// formatValidationErrors 格式化验证错误
func (v *RequestValidator) formatValidationErrors(err error) map[string][]string {
errors := make(map[string][]string)
if validationErrors, ok := err.(validator.ValidationErrors); ok {
for _, fieldError := range validationErrors {
fieldName := v.getFieldName(fieldError)
errorMessage := v.getErrorMessage(fieldError)
if _, exists := errors[fieldName]; !exists {
errors[fieldName] = []string{}
}
errors[fieldName] = append(errors[fieldName], errorMessage)
}
}
return errors
}
// getFieldName 获取字段名JSON标签优先
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
// 可以通过反射获取JSON标签这里简化处理
fieldName := fieldError.Field()
// 转换为snake_case可选
return v.toSnakeCase(fieldName)
}
// getErrorMessage 获取错误消息
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
field := fieldError.Field()
tag := fieldError.Tag()
param := fieldError.Param()
switch tag {
case "required":
return fmt.Sprintf("%s is required", field)
case "email":
return fmt.Sprintf("%s must be a valid email address", field)
case "min":
return fmt.Sprintf("%s must be at least %s characters", field, param)
case "max":
return fmt.Sprintf("%s must be at most %s characters", field, param)
case "len":
return fmt.Sprintf("%s must be exactly %s characters", field, param)
case "gt":
return fmt.Sprintf("%s must be greater than %s", field, param)
case "gte":
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
case "lt":
return fmt.Sprintf("%s must be less than %s", field, param)
case "lte":
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
case "oneof":
return fmt.Sprintf("%s must be one of [%s]", field, param)
case "url":
return fmt.Sprintf("%s must be a valid URL", field)
case "alpha":
return fmt.Sprintf("%s must contain only alphabetic characters", field)
case "alphanum":
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
case "numeric":
return fmt.Sprintf("%s must be numeric", field)
case "phone":
return fmt.Sprintf("%s must be a valid phone number", field)
case "username":
return fmt.Sprintf("%s must be a valid username", field)
default:
return fmt.Sprintf("%s is invalid", field)
}
}
// toSnakeCase 转换为snake_case
func (v *RequestValidator) toSnakeCase(str string) string {
var result strings.Builder
for i, r := range str {
if i > 0 && (r >= 'A' && r <= 'Z') {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// registerCustomValidators 注册自定义验证器
func registerCustomValidators(v *validator.Validate) {
// 注册手机号验证器
v.RegisterValidation("phone", validatePhone)
// 注册用户名验证器
v.RegisterValidation("username", validateUsername)
// 注册密码强度验证器
v.RegisterValidation("strong_password", validateStrongPassword)
}
// validatePhone 验证手机号
func validatePhone(fl validator.FieldLevel) bool {
phone := fl.Field().String()
if phone == "" {
return true // 空值由required标签处理
}
// 简单的手机号验证(可根据需要完善)
if len(phone) < 10 || len(phone) > 15 {
return false
}
// 检查是否以+开头或全是数字
if strings.HasPrefix(phone, "+") {
phone = phone[1:]
}
for _, r := range phone {
if r < '0' || r > '9' {
if r != '-' && r != ' ' && r != '(' && r != ')' {
return false
}
}
}
return true
}
// validateUsername 验证用户名
func validateUsername(fl validator.FieldLevel) bool {
username := fl.Field().String()
if username == "" {
return true // 空值由required标签处理
}
// 用户名规则3-30个字符只能包含字母、数字、下划线不能以数字开头
if len(username) < 3 || len(username) > 30 {
return false
}
// 不能以数字开头
if username[0] >= '0' && username[0] <= '9' {
return false
}
// 只能包含字母、数字、下划线
for _, r := range username {
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
return false
}
}
return true
}
// validateStrongPassword 验证密码强度
func validateStrongPassword(fl validator.FieldLevel) bool {
password := fl.Field().String()
if password == "" {
return true // 空值由required标签处理
}
// 密码强度规则至少8个字符包含大小写字母、数字
if len(password) < 8 {
return false
}
hasUpper := false
hasLower := false
hasDigit := false
for _, r := range password {
switch {
case r >= 'A' && r <= 'Z':
hasUpper = true
case r >= 'a' && r <= 'z':
hasLower = true
case r >= '0' && r <= '9':
hasDigit = true
}
}
return hasUpper && hasLower && hasDigit
}
// ValidateStruct 直接验证结构体不通过HTTP上下文
func (v *RequestValidator) ValidateStruct(dto interface{}) error {
return v.validator.Struct(dto)
}
// GetValidator 获取原始验证器(用于特殊情况)
func (v *RequestValidator) GetValidator() *validator.Validate {
return v.validator
}

View File

@@ -0,0 +1,92 @@
package interfaces
import (
"context"
"time"
)
// Event 事件接口
type Event interface {
// 事件基础信息
GetID() string
GetType() string
GetVersion() string
GetTimestamp() time.Time
// 事件数据
GetPayload() interface{}
GetMetadata() map[string]interface{}
// 事件来源
GetSource() string
GetAggregateID() string
GetAggregateType() string
// 序列化
Marshal() ([]byte, error)
Unmarshal(data []byte) error
}
// EventHandler 事件处理器接口
type EventHandler interface {
// 处理器标识
GetName() string
GetEventTypes() []string
// 事件处理
Handle(ctx context.Context, event Event) error
// 处理器配置
IsAsync() bool
GetRetryConfig() RetryConfig
}
// DomainEvent 领域事件基础接口
type DomainEvent interface {
Event
// 领域特定信息
GetDomainVersion() string
GetCausationID() string
GetCorrelationID() string
}
// RetryConfig 重试配置
type RetryConfig struct {
MaxRetries int `json:"max_retries"`
RetryDelay time.Duration `json:"retry_delay"`
BackoffFactor float64 `json:"backoff_factor"`
MaxDelay time.Duration `json:"max_delay"`
}
// EventStore 事件存储接口
type EventStore interface {
// 事件存储
SaveEvent(ctx context.Context, event Event) error
SaveEvents(ctx context.Context, events []Event) error
// 事件查询
GetEvents(ctx context.Context, aggregateID string, fromVersion int) ([]Event, error)
GetEventsByType(ctx context.Context, eventType string, limit int) ([]Event, error)
GetEventsSince(ctx context.Context, timestamp time.Time, limit int) ([]Event, error)
// 快照支持
SaveSnapshot(ctx context.Context, aggregateID string, snapshot interface{}) error
GetSnapshot(ctx context.Context, aggregateID string) (interface{}, error)
}
// EventBus 事件总线接口
type EventBus interface {
// 事件发布
Publish(ctx context.Context, event Event) error
PublishBatch(ctx context.Context, events []Event) error
// 事件订阅
Subscribe(eventType string, handler EventHandler) error
Unsubscribe(eventType string, handler EventHandler) error
// 订阅管理
GetSubscribers(eventType string) []EventHandler
Start(ctx context.Context) error
Stop(ctx context.Context) error
}

View File

@@ -0,0 +1,152 @@
package interfaces
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
)
// HTTPHandler HTTP处理器接口
type HTTPHandler interface {
// 处理器信息
GetPath() string
GetMethod() string
GetMiddlewares() []gin.HandlerFunc
// 处理函数
Handle(c *gin.Context)
// 权限验证
RequiresAuth() bool
GetPermissions() []string
}
// RESTHandler REST风格处理器接口
type RESTHandler interface {
HTTPHandler
// CRUD操作
Create(c *gin.Context)
GetByID(c *gin.Context)
Update(c *gin.Context)
Delete(c *gin.Context)
List(c *gin.Context)
}
// Middleware 中间件接口
type Middleware interface {
// 中间件名称
GetName() string
// 中间件优先级
GetPriority() int
// 中间件处理函数
Handle() gin.HandlerFunc
// 是否全局中间件
IsGlobal() bool
}
// Router 路由器接口
type Router interface {
// 路由注册
RegisterHandler(handler HTTPHandler) error
RegisterMiddleware(middleware Middleware) error
RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes
// 路由管理
GetRoutes() gin.RoutesInfo
Start(addr string) error
Stop(ctx context.Context) error
// 引擎获取
GetEngine() *gin.Engine
}
// ResponseBuilder 响应构建器接口
type ResponseBuilder interface {
// 成功响应
Success(c *gin.Context, data interface{}, message ...string)
Created(c *gin.Context, data interface{}, message ...string)
// 错误响应
Error(c *gin.Context, err error)
BadRequest(c *gin.Context, message string, errors ...interface{})
Unauthorized(c *gin.Context, message ...string)
Forbidden(c *gin.Context, message ...string)
NotFound(c *gin.Context, message ...string)
Conflict(c *gin.Context, message string)
InternalError(c *gin.Context, message ...string)
// 分页响应
Paginated(c *gin.Context, data interface{}, pagination PaginationMeta)
}
// RequestValidator 请求验证器接口
type RequestValidator interface {
// 验证请求
Validate(c *gin.Context, dto interface{}) error
ValidateQuery(c *gin.Context, dto interface{}) error
ValidateParam(c *gin.Context, dto interface{}) error
// 绑定和验证
BindAndValidate(c *gin.Context, dto interface{}) error
}
// PaginationMeta 分页元数据
type PaginationMeta struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Total int64 `json:"total"`
TotalPages int `json:"total_pages"`
HasNext bool `json:"has_next"`
HasPrev bool `json:"has_prev"`
}
// APIResponse 标准API响应结构
type APIResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
Errors interface{} `json:"errors,omitempty"`
Pagination *PaginationMeta `json:"pagination,omitempty"`
Meta map[string]interface{} `json:"meta,omitempty"`
RequestID string `json:"request_id"`
Timestamp int64 `json:"timestamp"`
}
// HealthChecker 健康检查器接口
type HealthChecker interface {
// 健康检查
CheckHealth(ctx context.Context) HealthStatus
GetName() string
GetDependencies() []string
}
// HealthStatus 健康状态
type HealthStatus struct {
Status string `json:"status"` // UP, DOWN, DEGRADED
Message string `json:"message"`
Details map[string]interface{} `json:"details"`
CheckedAt int64 `json:"checked_at"`
ResponseTime int64 `json:"response_time_ms"`
}
// MetricsCollector 指标收集器接口
type MetricsCollector interface {
// HTTP指标
RecordHTTPRequest(method, path string, status int, duration float64)
RecordHTTPDuration(method, path string, duration float64)
// 业务指标
IncrementCounter(name string, labels map[string]string)
RecordGauge(name string, value float64, labels map[string]string)
RecordHistogram(name string, value float64, labels map[string]string)
// 自定义指标
RegisterCounter(name, help string, labels []string) error
RegisterGauge(name, help string, labels []string) error
RegisterHistogram(name, help string, labels []string, buckets []float64) error
// 指标导出
GetHandler() http.Handler
}

View File

@@ -0,0 +1,74 @@
package interfaces
import (
"context"
"time"
)
// Entity 通用实体接口
type Entity interface {
GetID() string
GetCreatedAt() time.Time
GetUpdatedAt() time.Time
}
// BaseRepository 基础仓储接口
type BaseRepository interface {
// 基础操作
Delete(ctx context.Context, id string) error
Count(ctx context.Context, options CountOptions) (int64, error)
Exists(ctx context.Context, id string) (bool, error)
// 软删除支持
SoftDelete(ctx context.Context, id string) error
Restore(ctx context.Context, id string) error
}
// Repository 通用仓储接口,支持泛型
type Repository[T any] interface {
BaseRepository
// 基础CRUD操作
Create(ctx context.Context, entity T) error
GetByID(ctx context.Context, id string) (T, error)
Update(ctx context.Context, entity T) error
// 批量操作
CreateBatch(ctx context.Context, entities []T) error
GetByIDs(ctx context.Context, ids []string) ([]T, error)
UpdateBatch(ctx context.Context, entities []T) error
DeleteBatch(ctx context.Context, ids []string) error
// 查询操作
List(ctx context.Context, options ListOptions) ([]T, error)
// 事务支持
WithTx(tx interface{}) Repository[T]
}
// ListOptions 列表查询选项
type ListOptions struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Sort string `json:"sort"`
Order string `json:"order"`
Filters map[string]interface{} `json:"filters"`
Search string `json:"search"`
Include []string `json:"include"`
}
// CountOptions 计数查询选项
type CountOptions struct {
Filters map[string]interface{} `json:"filters"`
Search string `json:"search"`
}
// CachedRepository 支持缓存的仓储接口
type CachedRepository[T Entity] interface {
Repository[T]
// 缓存操作
InvalidateCache(ctx context.Context, keys ...string) error
WarmupCache(ctx context.Context) error
GetCacheKey(id string) string
}

View File

@@ -0,0 +1,101 @@
package interfaces
import (
"context"
)
// Service 通用服务接口
type Service interface {
// 服务名称
Name() string
// 服务初始化
Initialize(ctx context.Context) error
// 服务健康检查
HealthCheck(ctx context.Context) error
// 服务关闭
Shutdown(ctx context.Context) error
}
// DomainService 领域服务接口,支持泛型
type DomainService[T Entity] interface {
Service
// 基础业务操作
Create(ctx context.Context, dto interface{}) (*T, error)
GetByID(ctx context.Context, id string) (*T, error)
Update(ctx context.Context, id string, dto interface{}) (*T, error)
Delete(ctx context.Context, id string) error
// 列表和查询
List(ctx context.Context, options ListOptions) ([]*T, error)
Search(ctx context.Context, query string, options ListOptions) ([]*T, error)
Count(ctx context.Context, options CountOptions) (int64, error)
// 业务规则验证
Validate(ctx context.Context, entity *T) error
ValidateCreate(ctx context.Context, dto interface{}) error
ValidateUpdate(ctx context.Context, id string, dto interface{}) error
}
// EventService 事件服务接口
type EventService interface {
Service
// 事件发布
Publish(ctx context.Context, event Event) error
PublishBatch(ctx context.Context, events []Event) error
// 事件订阅
Subscribe(eventType string, handler EventHandler) error
Unsubscribe(eventType string, handler EventHandler) error
// 异步处理
PublishAsync(ctx context.Context, event Event) error
}
// CacheService 缓存服务接口
type CacheService interface {
Service
// 基础缓存操作
Get(ctx context.Context, key string, dest interface{}) error
Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error
Delete(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
// 批量操作
GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error)
SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error
// 模式操作
DeletePattern(ctx context.Context, pattern string) error
Keys(ctx context.Context, pattern string) ([]string, error)
// 缓存统计
Stats(ctx context.Context) (CacheStats, error)
}
// CacheStats 缓存统计信息
type CacheStats struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
Keys int64 `json:"keys"`
Memory int64 `json:"memory"`
Connections int64 `json:"connections"`
}
// TransactionService 事务服务接口
type TransactionService interface {
Service
// 事务操作
Begin(ctx context.Context) (Transaction, error)
RunInTransaction(ctx context.Context, fn func(Transaction) error) error
}
// Transaction 事务接口
type Transaction interface {
Commit() error
Rollback() error
GetDB() interface{}
}

View File

@@ -0,0 +1,241 @@
package logger
import (
"context"
"fmt"
"os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Logger 日志接口
type Logger interface {
Debug(msg string, fields ...zapcore.Field)
Info(msg string, fields ...zapcore.Field)
Warn(msg string, fields ...zapcore.Field)
Error(msg string, fields ...zapcore.Field)
Fatal(msg string, fields ...zapcore.Field)
Panic(msg string, fields ...zapcore.Field)
With(fields ...zapcore.Field) Logger
WithContext(ctx context.Context) Logger
Sync() error
}
// ZapLogger Zap日志实现
type ZapLogger struct {
logger *zap.Logger
}
// Config 日志配置
type Config struct {
Level string
Format string
Output string
FilePath string
MaxSize int
MaxBackups int
MaxAge int
Compress bool
}
// NewLogger 创建新的日志实例
func NewLogger(config Config) (Logger, error) {
// 设置日志级别
level, err := zapcore.ParseLevel(config.Level)
if err != nil {
return nil, fmt.Errorf("无效的日志级别: %w", err)
}
// 配置编码器
var encoder zapcore.Encoder
encoderConfig := getEncoderConfig()
switch config.Format {
case "json":
encoder = zapcore.NewJSONEncoder(encoderConfig)
case "console":
encoder = zapcore.NewConsoleEncoder(encoderConfig)
default:
encoder = zapcore.NewJSONEncoder(encoderConfig)
}
// 配置输出
var writeSyncer zapcore.WriteSyncer
switch config.Output {
case "stdout":
writeSyncer = zapcore.AddSync(os.Stdout)
case "stderr":
writeSyncer = zapcore.AddSync(os.Stderr)
case "file":
if config.FilePath == "" {
config.FilePath = "logs/app.log"
}
// 确保目录存在
if err := os.MkdirAll("logs", 0755); err != nil {
return nil, fmt.Errorf("创建日志目录失败: %w", err)
}
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, fmt.Errorf("打开日志文件失败: %w", err)
}
writeSyncer = zapcore.AddSync(file)
default:
writeSyncer = zapcore.AddSync(os.Stdout)
}
// 创建核心
core := zapcore.NewCore(encoder, writeSyncer, level)
// 创建logger
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
return &ZapLogger{
logger: logger,
}, nil
}
// getEncoderConfig 获取编码器配置
func getEncoderConfig() zapcore.EncoderConfig {
return zapcore.EncoderConfig{
TimeKey: "timestamp",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
FunctionKey: zapcore.OmitKey,
MessageKey: "message",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.LowercaseLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.StringDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
}
// Debug 调试日志
func (l *ZapLogger) Debug(msg string, fields ...zapcore.Field) {
l.logger.Debug(msg, fields...)
}
// Info 信息日志
func (l *ZapLogger) Info(msg string, fields ...zapcore.Field) {
l.logger.Info(msg, fields...)
}
// Warn 警告日志
func (l *ZapLogger) Warn(msg string, fields ...zapcore.Field) {
l.logger.Warn(msg, fields...)
}
// Error 错误日志
func (l *ZapLogger) Error(msg string, fields ...zapcore.Field) {
l.logger.Error(msg, fields...)
}
// Fatal 致命错误日志
func (l *ZapLogger) Fatal(msg string, fields ...zapcore.Field) {
l.logger.Fatal(msg, fields...)
}
// Panic 恐慌日志
func (l *ZapLogger) Panic(msg string, fields ...zapcore.Field) {
l.logger.Panic(msg, fields...)
}
// With 添加字段
func (l *ZapLogger) With(fields ...zapcore.Field) Logger {
return &ZapLogger{
logger: l.logger.With(fields...),
}
}
// WithContext 从上下文添加字段
func (l *ZapLogger) WithContext(ctx context.Context) Logger {
// 从上下文中提取常用字段
fields := []zapcore.Field{}
if traceID := getTraceIDFromContext(ctx); traceID != "" {
fields = append(fields, zap.String("trace_id", traceID))
}
if userID := getUserIDFromContext(ctx); userID != "" {
fields = append(fields, zap.String("user_id", userID))
}
if requestID := getRequestIDFromContext(ctx); requestID != "" {
fields = append(fields, zap.String("request_id", requestID))
}
return l.With(fields...)
}
// Sync 同步日志
func (l *ZapLogger) Sync() error {
return l.logger.Sync()
}
// getTraceIDFromContext 从上下文获取追踪ID
func getTraceIDFromContext(ctx context.Context) string {
if traceID := ctx.Value("trace_id"); traceID != nil {
if id, ok := traceID.(string); ok {
return id
}
}
return ""
}
// getUserIDFromContext 从上下文获取用户ID
func getUserIDFromContext(ctx context.Context) string {
if userID := ctx.Value("user_id"); userID != nil {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}
// getRequestIDFromContext 从上下文获取请求ID
func getRequestIDFromContext(ctx context.Context) string {
if requestID := ctx.Value("request_id"); requestID != nil {
if id, ok := requestID.(string); ok {
return id
}
}
return ""
}
// Field 创建日志字段的便捷函数
func String(key, val string) zapcore.Field {
return zap.String(key, val)
}
func Int(key string, val int) zapcore.Field {
return zap.Int(key, val)
}
func Int64(key string, val int64) zapcore.Field {
return zap.Int64(key, val)
}
func Float64(key string, val float64) zapcore.Field {
return zap.Float64(key, val)
}
func Bool(key string, val bool) zapcore.Field {
return zap.Bool(key, val)
}
func Error(err error) zapcore.Field {
return zap.Error(err)
}
func Any(key string, val interface{}) zapcore.Field {
return zap.Any(key, val)
}
func Duration(key string, val interface{}) zapcore.Field {
return zap.Any(key, val)
}

View File

@@ -0,0 +1,261 @@
package middleware
import (
"net/http"
"strings"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
)
// JWTAuthMiddleware JWT认证中间件
type JWTAuthMiddleware struct {
config *config.Config
logger *zap.Logger
}
// NewJWTAuthMiddleware 创建JWT认证中间件
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
return &JWTAuthMiddleware{
config: cfg,
logger: logger,
}
}
// GetName 返回中间件名称
func (m *JWTAuthMiddleware) GetName() string {
return "jwt_auth"
}
// GetPriority 返回中间件优先级
func (m *JWTAuthMiddleware) GetPriority() int {
return 60 // 中等优先级,在日志之后,业务处理之前
}
// Handle 返回中间件处理函数
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
m.respondUnauthorized(c, "Missing authorization header")
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
m.respondUnauthorized(c, "Invalid authorization header format")
return
}
// 提取token
tokenString := authHeader[len(bearerPrefix):]
if tokenString == "" {
m.respondUnauthorized(c, "Missing token")
return
}
// 验证token
claims, err := m.validateToken(tokenString)
if err != nil {
m.logger.Warn("Invalid token",
zap.Error(err),
zap.String("request_id", c.GetString("request_id")))
m.respondUnauthorized(c, "Invalid token")
return
}
// 将用户信息添加到上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("token_claims", claims)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *JWTAuthMiddleware) IsGlobal() bool {
return false // 不是全局中间件,需要手动应用到需要认证的路由
}
// JWTClaims JWT声明结构
type JWTClaims struct {
UserID string `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
jwt.RegisteredClaims
}
// validateToken 验证JWT token
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(m.config.JWT.Secret), nil
})
if err != nil {
return nil, err
}
claims, ok := token.Claims.(*JWTClaims)
if !ok || !token.Valid {
return nil, jwt.ErrSignatureInvalid
}
return claims, nil
}
// respondUnauthorized 返回未授权响应
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "Unauthorized",
"error": message,
"request_id": c.GetString("request_id"),
"timestamp": time.Now().Unix(),
})
c.Abort()
}
// GenerateToken 生成JWT token
func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) {
now := time.Now()
claims := &JWTClaims{
UserID: userID,
Username: username,
Email: email,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "tyapi-server",
Subject: userID,
Audience: []string{"tyapi-client"},
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
NotBefore: jwt.NewNumericDate(now),
IssuedAt: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(m.config.JWT.Secret))
}
// GenerateRefreshToken 生成刷新token
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
now := time.Now()
claims := &jwt.RegisteredClaims{
Issuer: "tyapi-server",
Subject: userID,
Audience: []string{"tyapi-refresh"},
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)),
NotBefore: jwt.NewNumericDate(now),
IssuedAt: jwt.NewNumericDate(now),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(m.config.JWT.Secret))
}
// ValidateRefreshToken 验证刷新token
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(m.config.JWT.Secret), nil
})
if err != nil {
return "", err
}
claims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok || !token.Valid {
return "", jwt.ErrSignatureInvalid
}
// 检查是否为刷新token
if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
return "", jwt.ErrSignatureInvalid
}
return claims.Subject, nil
}
// OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
type OptionalAuthMiddleware struct {
jwtAuth *JWTAuthMiddleware
}
// NewOptionalAuthMiddleware 创建可选认证中间件
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
return &OptionalAuthMiddleware{
jwtAuth: jwtAuth,
}
}
// GetName 返回中间件名称
func (m *OptionalAuthMiddleware) GetName() string {
return "optional_auth"
}
// GetPriority 返回中间件优先级
func (m *OptionalAuthMiddleware) GetPriority() int {
return 60 // 与JWT认证中间件相同
}
// Handle 返回中间件处理函数
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// 没有认证头部,设置匿名用户标识
c.Set("is_authenticated", false)
c.Next()
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
c.Set("is_authenticated", false)
c.Next()
return
}
// 提取并验证token
tokenString := authHeader[len(bearerPrefix):]
claims, err := m.jwtAuth.validateToken(tokenString)
if err != nil {
// token无效但不返回错误设置为未认证
c.Set("is_authenticated", false)
c.Next()
return
}
// token有效设置用户信息
c.Set("is_authenticated", true)
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("token_claims", claims)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *OptionalAuthMiddleware) IsGlobal() bool {
return false
}

View File

@@ -0,0 +1,104 @@
package middleware
import (
"tyapi-server/internal/config"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
// CORSMiddleware CORS中间件
type CORSMiddleware struct {
config *config.Config
}
// NewCORSMiddleware 创建CORS中间件
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
return &CORSMiddleware{
config: cfg,
}
}
// GetName 返回中间件名称
func (m *CORSMiddleware) GetName() string {
return "cors"
}
// GetPriority 返回中间件优先级
func (m *CORSMiddleware) GetPriority() int {
return 100 // 高优先级,最先执行
}
// Handle 返回中间件处理函数
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
if !m.config.Development.EnableCors {
// 如果没有启用CORS返回空处理函数
return func(c *gin.Context) {
c.Next()
}
}
config := cors.Config{
AllowAllOrigins: false,
AllowOrigins: m.getAllowedOrigins(),
AllowMethods: m.getAllowedMethods(),
AllowHeaders: m.getAllowedHeaders(),
ExposeHeaders: []string{
"Content-Length",
"Content-Type",
"X-Request-ID",
"X-Response-Time",
},
AllowCredentials: true,
MaxAge: 86400, // 24小时
}
return cors.New(config)
}
// IsGlobal 是否为全局中间件
func (m *CORSMiddleware) IsGlobal() bool {
return true
}
// getAllowedOrigins 获取允许的来源
func (m *CORSMiddleware) getAllowedOrigins() []string {
if m.config.Development.CorsOrigins == "" {
return []string{"http://localhost:3000", "http://localhost:8080"}
}
// TODO: 解析配置中的origins字符串
return []string{m.config.Development.CorsOrigins}
}
// getAllowedMethods 获取允许的方法
func (m *CORSMiddleware) getAllowedMethods() []string {
if m.config.Development.CorsMethods == "" {
return []string{
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
}
}
// TODO: 解析配置中的methods字符串
return []string{m.config.Development.CorsMethods}
}
// getAllowedHeaders 获取允许的头部
func (m *CORSMiddleware) getAllowedHeaders() []string {
if m.config.Development.CorsHeaders == "" {
return []string{
"Origin",
"Content-Length",
"Content-Type",
"Authorization",
"X-Requested-With",
"Accept",
"Accept-Encoding",
"Accept-Language",
"X-Request-ID",
}
}
// TODO: 解析配置中的headers字符串
return []string{m.config.Development.CorsHeaders}
}

View File

@@ -0,0 +1,166 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
config *config.Config
limiters map[string]*rate.Limiter
mutex sync.RWMutex
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
return &RateLimitMiddleware{
config: cfg,
limiters: make(map[string]*rate.Limiter),
}
}
// GetName 返回中间件名称
func (m *RateLimitMiddleware) GetName() string {
return "ratelimit"
}
// GetPriority 返回中间件优先级
func (m *RateLimitMiddleware) GetPriority() int {
return 90 // 高优先级
}
// Handle 返回中间件处理函数
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取客户端标识IP地址
clientID := m.getClientID(c)
// 获取或创建限流器
limiter := m.getLimiter(clientID)
// 检查是否允许请求
if !limiter.Allow() {
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Header("Retry-After", "60")
c.JSON(http.StatusTooManyRequests, gin.H{
"success": false,
"message": "Rate limit exceeded",
"error": "Too many requests",
})
c.Abort()
return
}
// 添加限流头部信息
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *RateLimitMiddleware) IsGlobal() bool {
return true
}
// getClientID 获取客户端标识
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
// 优先使用X-Forwarded-For头部
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
return xff
}
// 使用X-Real-IP头部
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
// 使用RemoteAddr
return c.ClientIP()
}
// getLimiter 获取或创建限流器
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
m.mutex.RLock()
limiter, exists := m.limiters[clientID]
m.mutex.RUnlock()
if exists {
return limiter
}
m.mutex.Lock()
defer m.mutex.Unlock()
// 双重检查
if limiter, exists := m.limiters[clientID]; exists {
return limiter
}
// 创建新的限流器
// rate.Every计算每个请求之间的间隔
rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst)
m.limiters[clientID] = limiter
// 启动清理协程(仅第一次创建时)
if len(m.limiters) == 1 {
go m.cleanupRoutine()
}
return limiter
}
// cleanupRoutine 定期清理不活跃的限流器
func (m *RateLimitMiddleware) cleanupRoutine() {
ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.cleanup()
}
}
}
// cleanup 清理不活跃的限流器
func (m *RateLimitMiddleware) cleanup() {
m.mutex.Lock()
defer m.mutex.Unlock()
now := time.Now()
for clientID, limiter := range m.limiters {
// 如果限流器在过去1小时内没有被使用则删除它
if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour {
delete(m.limiters, clientID)
}
}
}
// GetStats 获取限流统计
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
m.mutex.RLock()
defer m.mutex.RUnlock()
return map[string]interface{}{
"active_limiters": len(m.limiters),
"rate_limit": map[string]interface{}{
"requests": m.config.RateLimit.Requests,
"window": m.config.RateLimit.Window,
"burst": m.config.RateLimit.Burst,
},
}
}

View File

@@ -0,0 +1,241 @@
package middleware
import (
"bytes"
"io"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// RequestLoggerMiddleware 请求日志中间件
type RequestLoggerMiddleware struct {
logger *zap.Logger
}
// NewRequestLoggerMiddleware 创建请求日志中间件
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
return &RequestLoggerMiddleware{
logger: logger,
}
}
// GetName 返回中间件名称
func (m *RequestLoggerMiddleware) GetName() string {
return "request_logger"
}
// GetPriority 返回中间件优先级
func (m *RequestLoggerMiddleware) GetPriority() int {
return 80 // 中等优先级
}
// Handle 返回中间件处理函数
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// 使用zap logger记录请求信息
m.logger.Info("HTTP Request",
zap.String("client_ip", param.ClientIP),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.String("protocol", param.Request.Proto),
zap.Int("status_code", param.StatusCode),
zap.Duration("latency", param.Latency),
zap.String("user_agent", param.Request.UserAgent()),
zap.Int("body_size", param.BodySize),
zap.String("referer", param.Request.Referer()),
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
)
// 返回空字符串因为我们已经用zap记录了
return ""
})
}
// IsGlobal 是否为全局中间件
func (m *RequestLoggerMiddleware) IsGlobal() bool {
return true
}
// RequestIDMiddleware 请求ID中间件
type RequestIDMiddleware struct{}
// NewRequestIDMiddleware 创建请求ID中间件
func NewRequestIDMiddleware() *RequestIDMiddleware {
return &RequestIDMiddleware{}
}
// GetName 返回中间件名称
func (m *RequestIDMiddleware) GetName() string {
return "request_id"
}
// GetPriority 返回中间件优先级
func (m *RequestIDMiddleware) GetPriority() int {
return 95 // 最高优先级,第一个执行
}
// Handle 返回中间件处理函数
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取或生成请求ID
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = uuid.New().String()
}
// 设置请求ID到上下文和响应头
c.Set("request_id", requestID)
c.Header("X-Request-ID", requestID)
// 添加到响应头,方便客户端追踪
c.Writer.Header().Set("X-Request-ID", requestID)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *RequestIDMiddleware) IsGlobal() bool {
return true
}
// SecurityHeadersMiddleware 安全头部中间件
type SecurityHeadersMiddleware struct{}
// NewSecurityHeadersMiddleware 创建安全头部中间件
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
return &SecurityHeadersMiddleware{}
}
// GetName 返回中间件名称
func (m *SecurityHeadersMiddleware) GetName() string {
return "security_headers"
}
// GetPriority 返回中间件优先级
func (m *SecurityHeadersMiddleware) GetPriority() int {
return 85 // 高优先级
}
// Handle 返回中间件处理函数
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置安全头部
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Content-Security-Policy", "default-src 'self'")
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
return true
}
// ResponseTimeMiddleware 响应时间中间件
type ResponseTimeMiddleware struct{}
// NewResponseTimeMiddleware 创建响应时间中间件
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
return &ResponseTimeMiddleware{}
}
// GetName 返回中间件名称
func (m *ResponseTimeMiddleware) GetName() string {
return "response_time"
}
// GetPriority 返回中间件优先级
func (m *ResponseTimeMiddleware) GetPriority() int {
return 75 // 中等优先级
}
// Handle 返回中间件处理函数
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
// 计算响应时间并添加到头部
duration := time.Since(start)
c.Header("X-Response-Time", duration.String())
// 记录到上下文中,供其他中间件使用
c.Set("response_time", duration)
}
}
// IsGlobal 是否为全局中间件
func (m *ResponseTimeMiddleware) IsGlobal() bool {
return true
}
// RequestBodyLoggerMiddleware 请求体日志中间件(用于调试)
type RequestBodyLoggerMiddleware struct {
logger *zap.Logger
enable bool
}
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
return &RequestBodyLoggerMiddleware{
logger: logger,
enable: enable,
}
}
// GetName 返回中间件名称
func (m *RequestBodyLoggerMiddleware) GetName() string {
return "request_body_logger"
}
// GetPriority 返回中间件优先级
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
return 70 // 较低优先级
}
// Handle 返回中间件处理函数
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
if !m.enable {
return func(c *gin.Context) {
c.Next()
}
}
return func(c *gin.Context) {
// 只记录POST, PUT, PATCH请求的body
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
if c.Request.Body != nil {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err == nil {
// 重新设置body供后续处理使用
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
m.logger.Debug("Request Body",
zap.String("method", c.Request.Method),
zap.String("path", c.Request.URL.Path),
zap.String("body", string(bodyBytes)),
zap.String("request_id", c.GetString("request_id")),
)
}
}
}
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
return false // 可选中间件,不是全局的
}