feat(架构): 完善基础架构设计

This commit is contained in:
2025-07-02 16:17:59 +08:00
parent 03e615a8fd
commit 5b4392894f
89 changed files with 18555 additions and 3521 deletions

View File

@@ -50,6 +50,16 @@ func (a *Application) Run() error {
// 打印启动信息
a.printBanner()
// 检查是否需要自动迁移
if a.config.Database.AutoMigrate {
a.logger.Info("Auto migration is enabled, running database migrations...")
if err := a.RunMigrations(); err != nil {
a.logger.Error("Auto migration failed", zap.Error(err))
return fmt.Errorf("auto migration failed: %w", err)
}
a.logger.Info("Auto migration completed successfully")
}
// 启动容器
a.logger.Info("Starting application container...")
if err := a.container.Start(); err != nil {
@@ -92,10 +102,10 @@ func (a *Application) RunMigrations() error {
func (a *Application) printBanner() {
banner := fmt.Sprintf(`
╔══════════════════════════════════════════════════════════════╗
║ %s
║ Version: %s
║ Environment: %s
║ Port: %s
║ %s
║ Version: %s
║ Environment: %s
║ Port: %s
╚══════════════════════════════════════════════════════════════╝
`,
a.config.App.Name,
@@ -151,9 +161,20 @@ func (a *Application) createDatabaseConnection() (*gorm.DB, error) {
// autoMigrate 自动迁移
func (a *Application) autoMigrate(db *gorm.DB) error {
// 如果需要删除某些表,可以在这里手动删除
// 注意:这会永久删除数据,请谨慎使用!
/*
// 删除不再需要的表(示例,请根据实际情况使用)
if err := db.Migrator().DropTable(&entities.FavoriteItem{}); err != nil {
a.logger.Warn("Failed to drop table", zap.Error(err))
// 继续执行,不阻断迁移
}
*/
// 迁移用户相关表
return db.AutoMigrate(
&entities.User{},
&entities.SMSCode{},
// 后续可以添加其他实体
)
}

View File

@@ -12,6 +12,7 @@ type Config struct {
Cache CacheConfig `mapstructure:"cache"`
Logger LoggerConfig `mapstructure:"logger"`
JWT JWTConfig `mapstructure:"jwt"`
SMS SMSConfig `mapstructure:"sms"`
RateLimit RateLimitConfig `mapstructure:"ratelimit"`
Monitoring MonitoringConfig `mapstructure:"monitoring"`
Health HealthConfig `mapstructure:"health"`
@@ -42,6 +43,7 @@ type DatabaseConfig struct {
MaxOpenConns int `mapstructure:"max_open_conns"`
MaxIdleConns int `mapstructure:"max_idle_conns"`
ConnMaxLifetime time.Duration `mapstructure:"conn_max_lifetime"`
AutoMigrate bool `mapstructure:"auto_migrate"`
}
// RedisConfig Redis配置
@@ -67,14 +69,15 @@ type CacheConfig struct {
// LoggerConfig 日志配置
type LoggerConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
Output string `mapstructure:"output"`
FilePath string `mapstructure:"file_path"`
MaxSize int `mapstructure:"max_size"`
MaxBackups int `mapstructure:"max_backups"`
MaxAge int `mapstructure:"max_age"`
Compress bool `mapstructure:"compress"`
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
Output string `mapstructure:"output"`
FilePath string `mapstructure:"file_path"`
MaxSize int `mapstructure:"max_size"`
MaxBackups int `mapstructure:"max_backups"`
MaxAge int `mapstructure:"max_age"`
Compress bool `mapstructure:"compress"`
UseColor bool `mapstructure:"use_color"`
}
// JWTConfig JWT配置
@@ -119,12 +122,12 @@ type ResilienceConfig struct {
// DevelopmentConfig 开发配置
type DevelopmentConfig struct {
Debug bool `mapstructure:"debug"`
EnableProfiler bool `mapstructure:"enable_profiler"`
EnableCors bool `mapstructure:"enable_cors"`
CorsOrigins string `mapstructure:"cors_allowed_origins"`
CorsMethods string `mapstructure:"cors_allowed_methods"`
CorsHeaders string `mapstructure:"cors_allowed_headers"`
Debug bool `mapstructure:"debug"`
EnableProfiler bool `mapstructure:"enable_profiler"`
EnableCors bool `mapstructure:"enable_cors"`
CorsOrigins string `mapstructure:"cors_allowed_origins"`
CorsMethods string `mapstructure:"cors_allowed_methods"`
CorsHeaders string `mapstructure:"cors_allowed_headers"`
}
// AppConfig 应用程序配置
@@ -134,6 +137,26 @@ type AppConfig struct {
Env string `mapstructure:"env"`
}
// SMSConfig 短信配置
type SMSConfig struct {
AccessKeyID string `mapstructure:"access_key_id"`
AccessKeySecret string `mapstructure:"access_key_secret"`
EndpointURL string `mapstructure:"endpoint_url"`
SignName string `mapstructure:"sign_name"`
TemplateCode string `mapstructure:"template_code"`
CodeLength int `mapstructure:"code_length"`
ExpireTime time.Duration `mapstructure:"expire_time"`
RateLimit SMSRateLimit `mapstructure:"rate_limit"`
MockEnabled bool `mapstructure:"mock_enabled"` // 是否启用模拟短信服务
}
// SMSRateLimit 短信限流配置
type SMSRateLimit struct {
DailyLimit int `mapstructure:"daily_limit"` // 每日发送限制
HourlyLimit int `mapstructure:"hourly_limit"` // 每小时发送限制
MinInterval time.Duration `mapstructure:"min_interval"` // 最小发送间隔
}
// GetDSN 获取数据库DSN连接字符串
func (d DatabaseConfig) GetDSN() string {
return "host=" + d.Host +
@@ -163,4 +186,4 @@ func (a AppConfig) IsDevelopment() bool {
// IsStaging 检查是否为测试环境
func (a AppConfig) IsStaging() bool {
return a.Env == "staging"
}
}

View File

@@ -3,6 +3,7 @@ package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
@@ -11,228 +12,173 @@ import (
// LoadConfig 加载应用程序配置
func LoadConfig() (*Config, error) {
// 设置配置文件名和路径
viper.SetConfigName("config")
viper.SetConfigType("yaml")
viper.AddConfigPath(".")
viper.AddConfigPath("./configs")
viper.AddConfigPath("$HOME/.tyapi")
// 1⃣ 获取环境变量决定配置文件
env := getEnvironment()
fmt.Printf("🔧 当前运行环境: %s\n", env)
// 设置环境变量前缀
viper.SetEnvPrefix("")
viper.AutomaticEnv()
// 2⃣ 加载基础配置文件
baseConfig := viper.New()
baseConfig.SetConfigName("config")
baseConfig.SetConfigType("yaml")
baseConfig.AddConfigPath(".")
baseConfig.AddConfigPath("./configs")
baseConfig.AddConfigPath("$HOME/.tyapi")
// 配置环境变量键名映射
setupEnvKeyMapping()
// 设置默认值
setDefaults()
// 尝试读取配置文件(可选)
if err := viper.ReadInConfig(); err != nil {
// 读取基础配置文件
if err := baseConfig.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
return nil, fmt.Errorf("读取基础配置文件失败: %w", err)
}
// 配置文件存在时使用环境变量和默认值
return nil, fmt.Errorf("未找到 config.yaml 文件,请确保配置文件存在")
}
fmt.Printf("✅ 已加载配置文件: %s\n", baseConfig.ConfigFileUsed())
// 3⃣ 加载环境特定配置文件
envConfigFile := findEnvConfigFile(env)
if envConfigFile != "" {
// 创建一个新的viper实例来读取环境配置
envConfig := viper.New()
envConfig.SetConfigFile(envConfigFile)
if err := envConfig.ReadInConfig(); err != nil {
fmt.Printf("⚠️ 环境配置文件加载警告: %v\n", err)
} else {
fmt.Printf("✅ 已加载环境配置: %s\n", envConfigFile)
// 将环境配置合并到基础配置中
if err := mergeConfigs(baseConfig, envConfig.AllSettings()); err != nil {
return nil, fmt.Errorf("合并配置失败: %w", err)
}
}
} else {
fmt.Printf("⚠️ 未找到环境配置文件 env.%s.yaml\n", env)
}
// 4⃣ 设置环境变量前缀和自动读取
baseConfig.SetEnvPrefix("")
baseConfig.AutomaticEnv()
baseConfig.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
// 5⃣ 解析配置到结构体
var config Config
if err := viper.Unmarshal(&config); err != nil {
if err := baseConfig.Unmarshal(&config); err != nil {
return nil, fmt.Errorf("解析配置失败: %w", err)
}
// 验证配置
// 6 验证配置
if err := validateConfig(&config); err != nil {
return nil, fmt.Errorf("配置验证失败: %w", err)
}
// 7⃣ 输出配置摘要
printConfigSummary(&config, env)
return &config, nil
}
// setupEnvKeyMapping 设置环境变量到配置键的映射
func setupEnvKeyMapping() {
// 服务器配置
viper.BindEnv("server.port", "SERVER_PORT")
viper.BindEnv("server.mode", "SERVER_MODE")
viper.BindEnv("server.host", "SERVER_HOST")
viper.BindEnv("server.read_timeout", "SERVER_READ_TIMEOUT")
viper.BindEnv("server.write_timeout", "SERVER_WRITE_TIMEOUT")
viper.BindEnv("server.idle_timeout", "SERVER_IDLE_TIMEOUT")
// mergeConfigs 递归合并配置
func mergeConfigs(baseConfig *viper.Viper, overrideSettings map[string]interface{}) error {
for key, val := range overrideSettings {
// 如果值是一个嵌套的map则递归合并
if subMap, ok := val.(map[string]interface{}); ok {
// 创建子键路径
subKey := key
// 数据库配置
viper.BindEnv("database.host", "DB_HOST")
viper.BindEnv("database.port", "DB_PORT")
viper.BindEnv("database.user", "DB_USER")
viper.BindEnv("database.password", "DB_PASSWORD")
viper.BindEnv("database.name", "DB_NAME")
viper.BindEnv("database.sslmode", "DB_SSLMODE")
viper.BindEnv("database.timezone", "DB_TIMEZONE")
viper.BindEnv("database.max_open_conns", "DB_MAX_OPEN_CONNS")
viper.BindEnv("database.max_idle_conns", "DB_MAX_IDLE_CONNS")
viper.BindEnv("database.conn_max_lifetime", "DB_CONN_MAX_LIFETIME")
// Redis配置
viper.BindEnv("redis.host", "REDIS_HOST")
viper.BindEnv("redis.port", "REDIS_PORT")
viper.BindEnv("redis.password", "REDIS_PASSWORD")
viper.BindEnv("redis.db", "REDIS_DB")
viper.BindEnv("redis.pool_size", "REDIS_POOL_SIZE")
viper.BindEnv("redis.min_idle_conns", "REDIS_MIN_IDLE_CONNS")
viper.BindEnv("redis.max_retries", "REDIS_MAX_RETRIES")
viper.BindEnv("redis.dial_timeout", "REDIS_DIAL_TIMEOUT")
viper.BindEnv("redis.read_timeout", "REDIS_READ_TIMEOUT")
viper.BindEnv("redis.write_timeout", "REDIS_WRITE_TIMEOUT")
// 缓存配置
viper.BindEnv("cache.default_ttl", "CACHE_DEFAULT_TTL")
viper.BindEnv("cache.cleanup_interval", "CACHE_CLEANUP_INTERVAL")
viper.BindEnv("cache.max_size", "CACHE_MAX_SIZE")
// 日志配置
viper.BindEnv("logger.level", "LOG_LEVEL")
viper.BindEnv("logger.format", "LOG_FORMAT")
viper.BindEnv("logger.output", "LOG_OUTPUT")
viper.BindEnv("logger.file_path", "LOG_FILE_PATH")
viper.BindEnv("logger.max_size", "LOG_MAX_SIZE")
viper.BindEnv("logger.max_backups", "LOG_MAX_BACKUPS")
viper.BindEnv("logger.max_age", "LOG_MAX_AGE")
viper.BindEnv("logger.compress", "LOG_COMPRESS")
// JWT配置
viper.BindEnv("jwt.secret", "JWT_SECRET")
viper.BindEnv("jwt.expires_in", "JWT_EXPIRES_IN")
viper.BindEnv("jwt.refresh_expires_in", "JWT_REFRESH_EXPIRES_IN")
// 限流配置
viper.BindEnv("ratelimit.requests", "RATE_LIMIT_REQUESTS")
viper.BindEnv("ratelimit.window", "RATE_LIMIT_WINDOW")
viper.BindEnv("ratelimit.burst", "RATE_LIMIT_BURST")
// 监控配置
viper.BindEnv("monitoring.metrics_enabled", "METRICS_ENABLED")
viper.BindEnv("monitoring.metrics_port", "METRICS_PORT")
viper.BindEnv("monitoring.tracing_enabled", "TRACING_ENABLED")
viper.BindEnv("monitoring.tracing_endpoint", "TRACING_ENDPOINT")
viper.BindEnv("monitoring.sample_rate", "TRACING_SAMPLE_RATE")
// 健康检查配置
viper.BindEnv("health.enabled", "HEALTH_CHECK_ENABLED")
viper.BindEnv("health.interval", "HEALTH_CHECK_INTERVAL")
viper.BindEnv("health.timeout", "HEALTH_CHECK_TIMEOUT")
// 容错配置
viper.BindEnv("resilience.circuit_breaker_enabled", "CIRCUIT_BREAKER_ENABLED")
viper.BindEnv("resilience.circuit_breaker_threshold", "CIRCUIT_BREAKER_THRESHOLD")
viper.BindEnv("resilience.circuit_breaker_timeout", "CIRCUIT_BREAKER_TIMEOUT")
viper.BindEnv("resilience.retry_max_attempts", "RETRY_MAX_ATTEMPTS")
viper.BindEnv("resilience.retry_initial_delay", "RETRY_INITIAL_DELAY")
viper.BindEnv("resilience.retry_max_delay", "RETRY_MAX_DELAY")
// 开发配置
viper.BindEnv("development.debug", "DEBUG")
viper.BindEnv("development.enable_profiler", "ENABLE_PROFILER")
viper.BindEnv("development.enable_cors", "ENABLE_CORS")
viper.BindEnv("development.cors_allowed_origins", "CORS_ALLOWED_ORIGINS")
viper.BindEnv("development.cors_allowed_methods", "CORS_ALLOWED_METHODS")
viper.BindEnv("development.cors_allowed_headers", "CORS_ALLOWED_HEADERS")
// 应用程序配置
viper.BindEnv("app.name", "APP_NAME")
viper.BindEnv("app.version", "APP_VERSION")
viper.BindEnv("app.env", "ENV")
// 递归合并子配置
for subK, subV := range subMap {
fullKey := fmt.Sprintf("%s.%s", subKey, subK)
baseConfig.Set(fullKey, subV)
}
} else {
// 直接设置值
baseConfig.Set(key, val)
}
}
return nil
}
// setDefaults 设置默认配置值
func setDefaults() {
// 服务器默认值
viper.SetDefault("server.port", "8080")
viper.SetDefault("server.mode", "debug")
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.read_timeout", "30s")
viper.SetDefault("server.write_timeout", "30s")
viper.SetDefault("server.idle_timeout", "120s")
// findEnvConfigFile 查找环境特定的配置文件
func findEnvConfigFile(env string) string {
// 尝试查找的配置文件路径
possiblePaths := []string{
fmt.Sprintf("configs/env.%s.yaml", env),
fmt.Sprintf("configs/env.%s.yml", env),
fmt.Sprintf("configs/env.%s", env),
fmt.Sprintf("env.%s.yaml", env),
fmt.Sprintf("env.%s.yml", env),
fmt.Sprintf("env.%s", env),
}
// 数据库默认值
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", "5432")
viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "password")
viper.SetDefault("database.name", "tyapi_db")
viper.SetDefault("database.sslmode", "disable")
viper.SetDefault("database.timezone", "Asia/Shanghai")
viper.SetDefault("database.max_open_conns", 100)
viper.SetDefault("database.max_idle_conns", 10)
viper.SetDefault("database.conn_max_lifetime", "300s")
// 如果有自定义环境文件路径
if customEnvFile := os.Getenv("ENV_FILE"); customEnvFile != "" {
possiblePaths = append([]string{customEnvFile}, possiblePaths...)
}
// Redis默认值
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", "6379")
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.pool_size", 10)
viper.SetDefault("redis.min_idle_conns", 5)
viper.SetDefault("redis.max_retries", 3)
viper.SetDefault("redis.dial_timeout", "5s")
viper.SetDefault("redis.read_timeout", "3s")
viper.SetDefault("redis.write_timeout", "3s")
for _, path := range possiblePaths {
if _, err := os.Stat(path); err == nil {
absPath, _ := filepath.Abs(path)
return absPath
}
}
// 缓存默认值
viper.SetDefault("cache.default_ttl", "300s")
viper.SetDefault("cache.cleanup_interval", "600s")
viper.SetDefault("cache.max_size", 1000)
return ""
}
// 日志默认值
viper.SetDefault("logger.level", "info")
viper.SetDefault("logger.format", "json")
viper.SetDefault("logger.output", "stdout")
viper.SetDefault("logger.file_path", "logs/app.log")
viper.SetDefault("logger.max_size", 100)
viper.SetDefault("logger.max_backups", 5)
viper.SetDefault("logger.max_age", 30)
viper.SetDefault("logger.compress", true)
// getEnvironment 获取当前环境
func getEnvironment() string {
var env string
var source string
// JWT默认值
viper.SetDefault("jwt.secret", "your-super-secret-jwt-key-change-this-in-production")
viper.SetDefault("jwt.expires_in", "24h")
viper.SetDefault("jwt.refresh_expires_in", "168h")
// 优先级CONFIG_ENV > ENV > APP_ENV > 默认值
if env = os.Getenv("CONFIG_ENV"); env != "" {
source = "CONFIG_ENV"
} else if env = os.Getenv("ENV"); env != "" {
source = "ENV"
} else if env = os.Getenv("APP_ENV"); env != "" {
source = "APP_ENV"
} else {
env = "development"
source = "默认值"
}
// 限流默认值
viper.SetDefault("ratelimit.requests", 100)
viper.SetDefault("ratelimit.window", "60s")
viper.SetDefault("ratelimit.burst", 10)
fmt.Printf("🌍 环境检测: %s (来源: %s)\n", env, source)
// 监控默认
viper.SetDefault("monitoring.metrics_enabled", true)
viper.SetDefault("monitoring.metrics_port", "9090")
viper.SetDefault("monitoring.tracing_enabled", false)
viper.SetDefault("monitoring.tracing_endpoint", "http://localhost:14268/api/traces")
viper.SetDefault("monitoring.sample_rate", 0.1)
// 验证环境
validEnvs := []string{"development", "production", "testing"}
isValid := false
for _, validEnv := range validEnvs {
if env == validEnv {
isValid = true
break
}
}
// 健康检查默认值
viper.SetDefault("health.enabled", true)
viper.SetDefault("health.interval", "30s")
viper.SetDefault("health.timeout", "5s")
if !isValid {
fmt.Printf("⚠️ 警告: 未识别的环境 '%s',将使用默认环境 'development'\n", env)
return "development"
}
// 容错默认值
viper.SetDefault("resilience.circuit_breaker_enabled", true)
viper.SetDefault("resilience.circuit_breaker_threshold", 5)
viper.SetDefault("resilience.circuit_breaker_timeout", "60s")
viper.SetDefault("resilience.retry_max_attempts", 3)
viper.SetDefault("resilience.retry_initial_delay", "100ms")
viper.SetDefault("resilience.retry_max_delay", "2s")
return env
}
// 开发默认值
viper.SetDefault("development.debug", true)
viper.SetDefault("development.enable_profiler", false)
viper.SetDefault("development.enable_cors", true)
viper.SetDefault("development.cors_allowed_origins", "*")
viper.SetDefault("development.cors_allowed_methods", "GET,POST,PUT,DELETE,OPTIONS")
viper.SetDefault("development.cors_allowed_headers", "*")
// 应用程序默认值
viper.SetDefault("app.name", "tyapi-server")
viper.SetDefault("app.version", "1.0.0")
viper.SetDefault("app.env", "development")
// printConfigSummary 打印配置摘要
func printConfigSummary(config *Config, env string) {
fmt.Printf("\n🔧 配置摘要:\n")
fmt.Printf(" 🌍 环境: %s\n", env)
fmt.Printf(" 📄 配置模板: config.yaml\n")
fmt.Printf(" 📱 应用名称: %s\n", config.App.Name)
fmt.Printf(" 🔖 版本: %s\n", config.App.Version)
fmt.Printf(" 🌐 服务端口: %s\n", config.Server.Port)
fmt.Printf(" 🗄️ 数据库: %s@%s:%s/%s\n",
config.Database.User,
config.Database.Host,
config.Database.Port,
config.Database.Name)
fmt.Printf(" 📊 追踪状态: %v (端点: %s)\n",
config.Monitoring.TracingEnabled,
config.Monitoring.TracingEndpoint)
fmt.Printf(" 📈 采样率: %.1f%%\n", config.Monitoring.SampleRate*100)
fmt.Printf("\n")
}
// validateConfig 验证配置

View File

@@ -3,11 +3,14 @@ package container
import (
"context"
"fmt"
nethttp "net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
"go.uber.org/fx"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gorm.io/gorm"
"tyapi-server/internal/config"
@@ -19,9 +22,15 @@ import (
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/events"
"tyapi-server/internal/shared/health"
"tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/hooks"
sharedhttp "tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/interfaces"
"tyapi-server/internal/shared/metrics"
"tyapi-server/internal/shared/middleware"
"tyapi-server/internal/shared/resilience"
"tyapi-server/internal/shared/saga"
"tyapi-server/internal/shared/sms"
"tyapi-server/internal/shared/tracing"
)
// Container 应用容器
@@ -40,11 +49,24 @@ func NewContainer() *Container {
// 基础设施模块
fx.Provide(
NewLogger,
NewDatabase,
// 使用带追踪的组件
NewTracedDatabase,
NewRedisClient,
NewRedisCache,
fx.Annotate(NewTracedRedisCache, fx.As(new(interfaces.CacheService))),
NewEventBus,
NewHealthChecker,
NewSMSService,
),
// 高级特性模块
fx.Provide(
NewTracer,
NewPrometheusMetrics,
NewBusinessMetrics,
NewCircuitBreakerWrapper,
NewRetryerWrapper,
NewSagaManager,
NewHookSystem,
),
// HTTP基础组件
@@ -64,14 +86,22 @@ func NewContainer() *Container {
NewRequestLoggerMiddleware,
NewJWTAuthMiddleware,
NewOptionalAuthMiddleware,
NewTracingMiddleware,
NewMetricsMiddleware,
NewTraceIDMiddleware,
NewErrorTrackingMiddleware,
NewRequestBodyLoggerMiddleware,
),
// 用户域组件
fx.Provide(
NewUserRepository,
NewSMSCodeRepository,
NewSMSCodeService,
NewUserService,
// 使用带自动追踪的用户服务
fx.Annotate(NewTracedUserService, fx.As(new(interfaces.UserService))),
NewUserHandler,
NewUserRoutes,
),
// 应用生命周期
@@ -101,7 +131,7 @@ func (c *Container) Stop() error {
return c.App.Stop(ctx)
}
// 基础设施构造函数
// ================ 基础设施构造函数 ================
// NewLogger 创建日志器
func NewLogger(cfg *config.Config) (*zap.Logger, error) {
@@ -110,18 +140,27 @@ func NewLogger(cfg *config.Config) (*zap.Logger, error) {
level = zap.NewAtomicLevelAt(zap.InfoLevel)
}
config := zap.Config{
Level: level,
Development: cfg.App.IsDevelopment(),
Encoding: cfg.Logger.Format,
EncoderConfig: zap.NewProductionEncoderConfig(),
OutputPaths: []string{cfg.Logger.Output},
ErrorOutputPaths: []string{"stderr"},
var config zap.Config
if cfg.App.IsDevelopment() {
config = zap.NewDevelopmentConfig()
config.Level = level
config.Encoding = "console"
if cfg.Logger.UseColor {
config.EncoderConfig.EncodeLevel = zapcore.LowercaseColorLevelEncoder
}
} else {
config = zap.NewProductionConfig()
config.Level = level
config.Encoding = cfg.Logger.Format
if config.Encoding == "" {
config.Encoding = "json"
}
}
if cfg.Logger.Format == "" {
config.Encoding = "json"
}
config.OutputPaths = []string{cfg.Logger.Output}
config.ErrorOutputPaths = []string{"stderr"}
if cfg.Logger.Output == "" {
config.OutputPaths = []string{"stdout"}
}
@@ -152,6 +191,25 @@ func NewDatabase(cfg *config.Config, logger *zap.Logger) (*gorm.DB, error) {
return db.DB, nil
}
// NewTracedDatabase 创建带追踪的数据库连接
func NewTracedDatabase(cfg *config.Config, tracer *tracing.Tracer, logger *zap.Logger) (*gorm.DB, error) {
// 先创建基础数据库连接
db, err := NewDatabase(cfg, logger)
if err != nil {
return nil, err
}
// 创建并注册GORM追踪插件
tracingPlugin := tracing.NewGormTracingPlugin(tracer, logger)
if err := db.Use(tracingPlugin); err != nil {
logger.Error("注册GORM追踪插件失败", zap.Error(err))
return nil, err
}
logger.Info("GORM自动追踪已启用")
return db, nil
}
// NewRedisClient 创建Redis客户端
func NewRedisClient(cfg *config.Config, logger *zap.Logger) (*redis.Client, error) {
client := redis.NewClient(&redis.Options{
@@ -165,17 +223,16 @@ func NewRedisClient(cfg *config.Config, logger *zap.Logger) (*redis.Client, erro
WriteTimeout: cfg.Redis.WriteTimeout,
})
// 测试连接
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_, err := client.Ping(ctx).Result()
if err != nil {
logger.Error("Failed to connect to Redis", zap.Error(err))
logger.Error("Redis连接失败", zap.Error(err))
return nil, err
}
logger.Info("Redis connection established")
logger.Info("Redis连接已建立")
return client, nil
}
@@ -184,9 +241,14 @@ func NewRedisCache(client *redis.Client, logger *zap.Logger, cfg *config.Config)
return cache.NewRedisCache(client, logger, "app")
}
// NewTracedRedisCache 创建带追踪的Redis缓存服务
func NewTracedRedisCache(client *redis.Client, tracer *tracing.Tracer, logger *zap.Logger, cfg *config.Config) interfaces.CacheService {
return tracing.NewTracedRedisCache(client, tracer, logger, "app")
}
// NewEventBus 创建事件总线
func NewEventBus(logger *zap.Logger, cfg *config.Config) interfaces.EventBus {
return events.NewMemoryEventBus(logger, 5) // 默认5个工作协程
return events.NewMemoryEventBus(logger, 5)
}
// NewHealthChecker 创建健康检查器
@@ -194,142 +256,293 @@ func NewHealthChecker(logger *zap.Logger) *health.HealthChecker {
return health.NewHealthChecker(logger)
}
// HTTP组件构造函数
// NewSMSService 创建短信服务
func NewSMSService(cfg *config.Config, logger *zap.Logger) (sms.Service, error) {
if cfg.SMS.MockEnabled {
logger.Info("使用模拟短信服务 (mock_enabled=true)")
return sms.NewMockSMSService(logger), nil
}
logger.Info("使用阿里云短信服务")
return sms.NewAliSMSService(cfg.SMS, logger)
}
// ================ HTTP组件构造函数 ================
// NewResponseBuilder 创建响应构建器
func NewResponseBuilder() interfaces.ResponseBuilder {
return http.NewResponseBuilder()
return sharedhttp.NewResponseBuilder()
}
// NewRequestValidator 创建请求验证器
// NewRequestValidator 创建中文请求验证器
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
return http.NewRequestValidator(response)
return sharedhttp.NewRequestValidatorZh(response)
}
// NewGinRouter 创建Gin路由器
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *http.GinRouter {
return http.NewGinRouter(cfg, logger)
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *sharedhttp.GinRouter {
return sharedhttp.NewGinRouter(cfg, logger)
}
// 中间件构造函数
// ================ 中间件构造函数 ================
// NewRequestIDMiddleware 创建请求ID中间件
func NewRequestIDMiddleware() *middleware.RequestIDMiddleware {
return middleware.NewRequestIDMiddleware()
}
// NewSecurityHeadersMiddleware 创建安全头部中间件
func NewSecurityHeadersMiddleware() *middleware.SecurityHeadersMiddleware {
return middleware.NewSecurityHeadersMiddleware()
}
// NewResponseTimeMiddleware 创建响应时间中间件
func NewResponseTimeMiddleware() *middleware.ResponseTimeMiddleware {
return middleware.NewResponseTimeMiddleware()
}
// NewCORSMiddleware 创建CORS中间件
func NewCORSMiddleware(cfg *config.Config) *middleware.CORSMiddleware {
return middleware.NewCORSMiddleware(cfg)
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg *config.Config) *middleware.RateLimitMiddleware {
return middleware.NewRateLimitMiddleware(cfg)
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *middleware.RateLimitMiddleware {
return middleware.NewRateLimitMiddleware(cfg, response)
}
// NewRequestLoggerMiddleware 创建请求日志中间件
func NewRequestLoggerMiddleware(logger *zap.Logger) *middleware.RequestLoggerMiddleware {
return middleware.NewRequestLoggerMiddleware(logger)
func NewRequestLoggerMiddleware(logger *zap.Logger, cfg *config.Config, tracer *tracing.Tracer) *middleware.RequestLoggerMiddleware {
return middleware.NewRequestLoggerMiddleware(logger, cfg.App.IsDevelopment(), tracer)
}
// NewJWTAuthMiddleware 创建JWT认证中间件
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *middleware.JWTAuthMiddleware {
return middleware.NewJWTAuthMiddleware(cfg, logger)
}
// NewOptionalAuthMiddleware 创建可选认证中间件
func NewOptionalAuthMiddleware(jwtAuth *middleware.JWTAuthMiddleware) *middleware.OptionalAuthMiddleware {
return middleware.NewOptionalAuthMiddleware(jwtAuth)
}
// 用户域构造函数
func NewTraceIDMiddleware(tracer *tracing.Tracer) *middleware.TraceIDMiddleware {
return middleware.NewTraceIDMiddleware(tracer)
}
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *middleware.ErrorTrackingMiddleware {
return middleware.NewErrorTrackingMiddleware(logger, tracer)
}
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, cfg *config.Config, tracer *tracing.Tracer) *middleware.RequestBodyLoggerMiddleware {
return middleware.NewRequestBodyLoggerMiddleware(logger, cfg.App.IsDevelopment(), tracer)
}
// ================ 高级特性构造函数 ================
// NewTracer 创建链路追踪器
func NewTracer(cfg *config.Config, logger *zap.Logger) *tracing.Tracer {
tracingConfig := tracing.TracerConfig{
ServiceName: cfg.App.Name,
ServiceVersion: cfg.App.Version,
Environment: cfg.App.Env,
Endpoint: cfg.Monitoring.TracingEndpoint,
SampleRate: cfg.Monitoring.SampleRate,
Enabled: cfg.Monitoring.TracingEnabled,
}
return tracing.NewTracer(tracingConfig, logger)
}
// NewPrometheusMetrics 创建Prometheus指标收集器
func NewPrometheusMetrics(logger *zap.Logger, cfg *config.Config) interfaces.MetricsCollector {
if !cfg.Monitoring.MetricsEnabled {
return &NoopMetricsCollector{}
}
return metrics.NewPrometheusMetrics(logger)
}
// NewBusinessMetrics 创建业务指标收集器
func NewBusinessMetrics(prometheusMetrics interfaces.MetricsCollector, logger *zap.Logger) *metrics.BusinessMetrics {
return metrics.NewBusinessMetrics(prometheusMetrics, logger)
}
// NewCircuitBreakerWrapper 创建熔断器包装器
func NewCircuitBreakerWrapper(logger *zap.Logger, cfg *config.Config) *resilience.Wrapper {
return resilience.NewWrapper(logger)
}
// NewRetryerWrapper 创建重试器包装器
func NewRetryerWrapper(logger *zap.Logger) *resilience.RetryerWrapper {
return resilience.NewRetryerWrapper(logger)
}
// NewSagaManager 创建Saga管理器
func NewSagaManager(cfg *config.Config, logger *zap.Logger) *saga.SagaManager {
sagaConfig := saga.SagaConfig{
DefaultTimeout: 30 * time.Second,
DefaultMaxRetries: 3,
Parallel: false,
}
return saga.NewSagaManager(sagaConfig, logger)
}
// NewHookSystem 创建钩子系统
func NewHookSystem(logger *zap.Logger) *hooks.HookSystem {
hookConfig := hooks.HookConfig{
DefaultTimeout: 30 * time.Second,
TrackDuration: true,
ErrorStrategy: hooks.ContinueOnError,
}
return hooks.NewHookSystem(hookConfig, logger)
}
// NewTracingMiddleware 创建追踪中间件
func NewTracingMiddleware(tracer *tracing.Tracer) *TracingMiddleware {
return &TracingMiddleware{tracer: tracer}
}
// NewMetricsMiddleware 创建指标中间件
func NewMetricsMiddleware(metricsCollector interfaces.MetricsCollector) *MetricsMiddleware {
return &MetricsMiddleware{metrics: metricsCollector}
}
// ================ 用户域构造函数 ================
// NewUserRepository 创建用户仓储
func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *repositories.UserRepository {
return repositories.NewUserRepository(db, cache, logger)
}
// NewUserService 创建用户服务
func NewSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *repositories.SMSCodeRepository {
return repositories.NewSMSCodeRepository(db, cache, logger)
}
func NewSMSCodeService(
repo *repositories.SMSCodeRepository,
smsClient sms.Service,
cache interfaces.CacheService,
cfg *config.Config,
logger *zap.Logger,
) *services.SMSCodeService {
return services.NewSMSCodeService(repo, smsClient, cache, cfg.SMS, logger)
}
func NewUserService(
repo *repositories.UserRepository,
smsCodeService *services.SMSCodeService,
eventBus interfaces.EventBus,
logger *zap.Logger,
) *services.UserService {
return services.NewUserService(repo, eventBus, logger)
return services.NewUserService(repo, smsCodeService, eventBus, logger)
}
// NewTracedUserService 创建带自动追踪的用户服务
func NewTracedUserService(
baseService *services.UserService,
tracer *tracing.Tracer,
logger *zap.Logger,
) interfaces.UserService {
serviceWrapper := tracing.NewServiceWrapper(tracer, logger)
return tracing.NewTracedUserService(baseService, serviceWrapper)
}
// NewUserHandler 创建用户处理器
func NewUserHandler(
userService *services.UserService,
userService interfaces.UserService,
smsCodeService *services.SMSCodeService,
response interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
jwtAuth *middleware.JWTAuthMiddleware,
) *handlers.UserHandler {
return handlers.NewUserHandler(userService, response, validator, logger, jwtAuth)
return handlers.NewUserHandler(userService, smsCodeService, response, validator, logger, jwtAuth)
}
// NewUserRoutes 创建用户路由
func NewUserRoutes(
handler *handlers.UserHandler,
jwtAuth *middleware.JWTAuthMiddleware,
optionalAuth *middleware.OptionalAuthMiddleware,
) *routes.UserRoutes {
return routes.NewUserRoutes(handler, jwtAuth, optionalAuth)
// ================ 中间件定义 ================
// TracingMiddleware 追踪中间件
type TracingMiddleware struct {
tracer *tracing.Tracer
}
// 注册函数
func (tm *TracingMiddleware) GetName() string { return "tracing" }
func (tm *TracingMiddleware) GetPriority() int { return 1 }
func (tm *TracingMiddleware) IsGlobal() bool { return true }
func (tm *TracingMiddleware) Handle() gin.HandlerFunc {
return tm.tracer.TraceMiddleware()
}
// MetricsMiddleware 指标中间件
type MetricsMiddleware struct {
metrics interfaces.MetricsCollector
}
func (mm *MetricsMiddleware) GetName() string { return "metrics" }
func (mm *MetricsMiddleware) GetPriority() int { return 2 }
func (mm *MetricsMiddleware) IsGlobal() bool { return true }
func (mm *MetricsMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
duration := time.Since(start).Seconds()
mm.metrics.RecordHTTPRequest(c.Request.Method, c.FullPath(), c.Writer.Status(), duration)
mm.metrics.RecordHTTPDuration(c.Request.Method, c.FullPath(), duration)
}
}
// NoopMetricsCollector 空的指标收集器实现
type NoopMetricsCollector struct{}
func (n *NoopMetricsCollector) RecordHTTPRequest(method, path string, status int, duration float64) {}
func (n *NoopMetricsCollector) RecordHTTPDuration(method, path string, duration float64) {}
func (n *NoopMetricsCollector) IncrementCounter(name string, labels map[string]string) {}
func (n *NoopMetricsCollector) RecordGauge(name string, value float64, labels map[string]string) {}
func (n *NoopMetricsCollector) RecordHistogram(name string, value float64, labels map[string]string) {
}
func (n *NoopMetricsCollector) RegisterCounter(name, help string, labels []string) error { return nil }
func (n *NoopMetricsCollector) RegisterGauge(name, help string, labels []string) error { return nil }
func (n *NoopMetricsCollector) RegisterHistogram(name, help string, labels []string, buckets []float64) error {
return nil
}
func (n *NoopMetricsCollector) GetHandler() nethttp.Handler { return nil }
// ================ 注册函数 ================
// RegisterMiddlewares 注册中间件
func RegisterMiddlewares(
router *http.GinRouter,
router *sharedhttp.GinRouter,
requestID *middleware.RequestIDMiddleware,
security *middleware.SecurityHeadersMiddleware,
responseTime *middleware.ResponseTimeMiddleware,
cors *middleware.CORSMiddleware,
rateLimit *middleware.RateLimitMiddleware,
requestLogger *middleware.RequestLoggerMiddleware,
tracingMiddleware *TracingMiddleware,
metricsMiddleware *MetricsMiddleware,
traceIDMiddleware *middleware.TraceIDMiddleware,
errorTrackingMiddleware *middleware.ErrorTrackingMiddleware,
requestBodyLogger *middleware.RequestBodyLoggerMiddleware,
) {
// 注册全局中间件
router.RegisterMiddleware(requestID)
router.RegisterMiddleware(security)
router.RegisterMiddleware(responseTime)
router.RegisterMiddleware(cors)
router.RegisterMiddleware(rateLimit)
router.RegisterMiddleware(requestLogger)
router.RegisterMiddleware(tracingMiddleware)
router.RegisterMiddleware(metricsMiddleware)
router.RegisterMiddleware(traceIDMiddleware)
router.RegisterMiddleware(errorTrackingMiddleware)
router.RegisterMiddleware(requestBodyLogger)
}
// RegisterRoutes 注册路由
func RegisterRoutes(
router *http.GinRouter,
userRoutes *routes.UserRoutes,
router *sharedhttp.GinRouter,
userHandler *handlers.UserHandler,
jwtAuth *middleware.JWTAuthMiddleware,
metricsCollector interfaces.MetricsCollector,
) {
// 设置默认路由
router.SetupDefaultRoutes()
// 注册用户路由
userRoutes.RegisterRoutes(router.GetEngine())
userRoutes.RegisterPublicRoutes(router.GetEngine())
userRoutes.RegisterAdminRoutes(router.GetEngine())
userRoutes.RegisterHealthRoutes(router.GetEngine())
if handler := metricsCollector.GetHandler(); handler != nil {
router.GetEngine().GET("/metrics", gin.WrapH(handler))
}
// 打印路由信息
routes.UserRoutes(router.GetEngine(), userHandler, jwtAuth)
router.PrintRoutes()
}
// 生命周期钩子
// RegisterLifecycleHooks 注册生命周期钩子
func RegisterLifecycleHooks(
lc fx.Lifecycle,
@@ -339,29 +552,60 @@ func RegisterLifecycleHooks(
cache interfaces.CacheService,
eventBus interfaces.EventBus,
healthChecker *health.HealthChecker,
router *http.GinRouter,
router *sharedhttp.GinRouter,
userService *services.UserService,
tracer *tracing.Tracer,
prometheusMetrics interfaces.MetricsCollector,
businessMetrics *metrics.BusinessMetrics,
circuitBreakerWrapper *resilience.Wrapper,
retryerWrapper *resilience.RetryerWrapper,
sagaManager *saga.SagaManager,
hookSystem *hooks.HookSystem,
) {
lc.Append(fx.Hook{
OnStart: func(ctx context.Context) error {
logger.Info("Starting application services...")
logger.Info("正在启动应用服务...")
// 初始化高级特性
if err := tracer.Initialize(ctx); err != nil {
logger.Error("初始化追踪器失败", zap.Error(err))
return err
}
if err := hookSystem.Initialize(ctx); err != nil {
logger.Error("初始化钩子系统失败", zap.Error(err))
return err
}
if err := sagaManager.Initialize(ctx); err != nil {
logger.Error("初始化事务管理器失败", zap.Error(err))
return err
}
// 注册服务到健康检查器
healthChecker.RegisterService(userService)
healthChecker.RegisterService(sagaManager)
// 初始化缓存服务
// 启动基础服务
if err := cache.Initialize(ctx); err != nil {
logger.Error("Failed to initialize cache", zap.Error(err))
logger.Error("初始化缓存失败", zap.Error(err))
return err
}
// 启动事件总线
if err := eventBus.Start(ctx); err != nil {
logger.Error("Failed to start event bus", zap.Error(err))
logger.Error("启动事件总线失败", zap.Error(err))
return err
}
// 启动健康检查(如果启用)
if err := hookSystem.Start(ctx); err != nil {
logger.Error("启动钩子系统失败", zap.Error(err))
return err
}
// 注册应用钩子
registerApplicationHooks(hookSystem, businessMetrics, logger)
// 启动健康检查
if cfg.Health.Enabled {
go healthChecker.StartPeriodicCheck(ctx, cfg.Health.Interval)
}
@@ -370,44 +614,65 @@ func RegisterLifecycleHooks(
go func() {
addr := fmt.Sprintf("%s:%s", cfg.Server.Host, cfg.Server.Port)
if err := router.Start(addr); err != nil {
logger.Error("Failed to start HTTP server", zap.Error(err))
logger.Error("启动HTTP服务器失败", zap.Error(err))
}
}()
logger.Info("All services started successfully")
logger.Info("所有服务已成功启动")
return nil
},
OnStop: func(ctx context.Context) error {
logger.Info("Stopping application services...")
logger.Info("正在停止应用服务...")
// 停止HTTP服务
if err := router.Stop(ctx); err != nil {
logger.Error("Failed to stop HTTP server", zap.Error(err))
}
// 按顺序关闭服务
router.Stop(ctx)
hookSystem.Shutdown(ctx)
sagaManager.Shutdown(ctx)
eventBus.Stop(ctx)
tracer.Shutdown(ctx)
cache.Shutdown(ctx)
// 停止事件总线
if err := eventBus.Stop(ctx); err != nil {
logger.Error("Failed to stop event bus", zap.Error(err))
}
// 关闭缓存服务
if err := cache.Shutdown(ctx); err != nil {
logger.Error("Failed to shutdown cache service", zap.Error(err))
}
// 关闭数据库连接
if sqlDB, err := db.DB(); err == nil {
if err := sqlDB.Close(); err != nil {
logger.Error("Failed to close database", zap.Error(err))
}
sqlDB.Close()
}
logger.Info("All services stopped")
logger.Info("所有服务已停止")
return nil
},
})
}
// registerApplicationHooks 注册应用钩子
func registerApplicationHooks(hookSystem *hooks.HookSystem, businessMetrics *metrics.BusinessMetrics, logger *zap.Logger) {
userCreatedHook := &hooks.Hook{
Name: "metrics.user_created",
Priority: 1,
Async: false,
Timeout: 5 * time.Second,
Func: func(ctx context.Context, data interface{}) error {
businessMetrics.RecordUserCreated("register")
logger.Debug("记录用户创建指标")
return nil
},
}
hookSystem.Register("user.created", userCreatedHook)
userLoginHook := &hooks.Hook{
Name: "metrics.user_login",
Priority: 1,
Async: false,
Timeout: 5 * time.Second,
Func: func(ctx context.Context, data interface{}) error {
businessMetrics.RecordUserLogin("web", "success")
logger.Debug("记录用户登录指标")
return nil
},
}
hookSystem.Register("user.logged_in", userLoginHook)
logger.Info("应用钩子已成功注册")
}
// ServiceRegistrar 服务注册器接口
type ServiceRegistrar interface {
RegisterServices() fx.Option

View File

@@ -0,0 +1,72 @@
package dto
import (
"time"
"tyapi-server/internal/domains/user/entities"
)
// SendCodeRequest 发送验证码请求
type SendCodeRequest struct {
Phone string `json:"phone" binding:"required,len=11" example:"13800138000"`
Scene entities.SMSScene `json:"scene" binding:"required,oneof=register login change_password reset_password bind unbind" example:"register"`
}
// SendCodeResponse 发送验证码响应
type SendCodeResponse struct {
Message string `json:"message" example:"验证码发送成功"`
ExpiresAt time.Time `json:"expires_at" example:"2024-01-01T00:05:00Z"`
}
// VerifyCodeRequest 验证验证码请求
type VerifyCodeRequest struct {
Phone string `json:"phone" binding:"required,len=11" example:"13800138000"`
Code string `json:"code" binding:"required,len=6" example:"123456"`
Scene entities.SMSScene `json:"scene" binding:"required,oneof=register login change_password reset_password bind unbind" example:"register"`
}
// SMSCodeResponse SMS验证码记录响应
type SMSCodeResponse struct {
ID string `json:"id" example:"123e4567-e89b-12d3-a456-426614174000"`
Phone string `json:"phone" example:"13800138000"`
Scene entities.SMSScene `json:"scene" example:"register"`
Used bool `json:"used" example:"false"`
ExpiresAt time.Time `json:"expires_at" example:"2024-01-01T00:05:00Z"`
CreatedAt time.Time `json:"created_at" example:"2024-01-01T00:00:00Z"`
}
// SMSCodeListRequest SMS验证码列表请求
type SMSCodeListRequest struct {
Phone string `form:"phone" binding:"omitempty,len=11" example:"13800138000"`
Scene entities.SMSScene `form:"scene" binding:"omitempty,oneof=register login change_password reset_password bind unbind" example:"register"`
Page int `form:"page" binding:"omitempty,min=1" example:"1"`
PageSize int `form:"page_size" binding:"omitempty,min=1,max=100" example:"20"`
}
// 转换方法
func FromSMSCodeEntity(smsCode *entities.SMSCode) *SMSCodeResponse {
if smsCode == nil {
return nil
}
return &SMSCodeResponse{
ID: smsCode.ID,
Phone: smsCode.Phone,
Scene: smsCode.Scene,
Used: smsCode.Used,
ExpiresAt: smsCode.ExpiresAt,
CreatedAt: smsCode.CreatedAt,
}
}
func FromSMSCodeEntities(smsCodes []*entities.SMSCode) []*SMSCodeResponse {
if smsCodes == nil {
return []*SMSCodeResponse{}
}
responses := make([]*SMSCodeResponse, len(smsCodes))
for i, smsCode := range smsCodes {
responses[i] = FromSMSCodeEntity(smsCode)
}
return responses
}

View File

@@ -6,88 +6,40 @@ import (
"tyapi-server/internal/domains/user/entities"
)
// CreateUserRequest 创建用户请求
type CreateUserRequest struct {
Username string `json:"username" binding:"required,min=3,max=50" example:"john_doe"`
Email string `json:"email" binding:"required,email" example:"john@example.com"`
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
FirstName string `json:"first_name" binding:"max=50" example:"John"`
LastName string `json:"last_name" binding:"max=50" example:"Doe"`
Phone string `json:"phone" binding:"omitempty,max=20" example:"+86-13800138000"`
// RegisterRequest 用户注册请求
type RegisterRequest struct {
Phone string `json:"phone" binding:"required,len=11" example:"13800138000"`
Password string `json:"password" binding:"required,min=6,max=128" example:"password123"`
ConfirmPassword string `json:"confirm_password" binding:"required,eqfield=Password" example:"password123"`
Code string `json:"code" binding:"required,len=6" example:"123456"`
}
// UpdateUserRequest 更新用户请求
type UpdateUserRequest struct {
FirstName *string `json:"first_name,omitempty" binding:"omitempty,max=50" example:"John"`
LastName *string `json:"last_name,omitempty" binding:"omitempty,max=50" example:"Doe"`
Phone *string `json:"phone,omitempty" binding:"omitempty,max=20" example:"+86-13800138000"`
Avatar *string `json:"avatar,omitempty" binding:"omitempty,url" example:"https://example.com/avatar.jpg"`
// LoginWithPasswordRequest 密码登录请求
type LoginWithPasswordRequest struct {
Phone string `json:"phone" binding:"required,len=11" example:"13800138000"`
Password string `json:"password" binding:"required" example:"password123"`
}
// LoginWithSMSRequest 短信验证码登录请求
type LoginWithSMSRequest struct {
Phone string `json:"phone" binding:"required,len=11" example:"13800138000"`
Code string `json:"code" binding:"required,len=6" example:"123456"`
}
// ChangePasswordRequest 修改密码请求
type ChangePasswordRequest struct {
OldPassword string `json:"old_password" binding:"required" example:"oldpassword123"`
NewPassword string `json:"new_password" binding:"required,min=6,max=128" example:"newpassword123"`
OldPassword string `json:"old_password" binding:"required" example:"oldpassword123"`
NewPassword string `json:"new_password" binding:"required,min=6,max=128" example:"newpassword123"`
ConfirmNewPassword string `json:"confirm_new_password" binding:"required,eqfield=NewPassword" example:"newpassword123"`
Code string `json:"code" binding:"required,len=6" example:"123456"`
}
// UserResponse 用户响应
type UserResponse struct {
ID string `json:"id" example:"123e4567-e89b-12d3-a456-426614174000"`
Username string `json:"username" example:"john_doe"`
Email string `json:"email" example:"john@example.com"`
FirstName string `json:"first_name" example:"John"`
LastName string `json:"last_name" example:"Doe"`
Phone string `json:"phone" example:"+86-13800138000"`
Avatar string `json:"avatar" example:"https://example.com/avatar.jpg"`
Status entities.UserStatus `json:"status" example:"active"`
LastLoginAt *time.Time `json:"last_login_at,omitempty" example:"2024-01-01T00:00:00Z"`
CreatedAt time.Time `json:"created_at" example:"2024-01-01T00:00:00Z"`
UpdatedAt time.Time `json:"updated_at" example:"2024-01-01T00:00:00Z"`
Profile *UserProfileResponse `json:"profile,omitempty"`
}
// UserProfileResponse 用户档案响应
type UserProfileResponse struct {
Bio string `json:"bio,omitempty" example:"Software Developer"`
Location string `json:"location,omitempty" example:"Beijing, China"`
Website string `json:"website,omitempty" example:"https://johndoe.com"`
Birthday *time.Time `json:"birthday,omitempty" example:"1990-01-01T00:00:00Z"`
Gender string `json:"gender,omitempty" example:"male"`
Timezone string `json:"timezone,omitempty" example:"Asia/Shanghai"`
Language string `json:"language,omitempty" example:"zh-CN"`
}
// UserListRequest 用户列表请求
type UserListRequest struct {
Page int `form:"page" binding:"omitempty,min=1" example:"1"`
PageSize int `form:"page_size" binding:"omitempty,min=1,max=100" example:"20"`
Sort string `form:"sort" binding:"omitempty,oneof=created_at updated_at username email" example:"created_at"`
Order string `form:"order" binding:"omitempty,oneof=asc desc" example:"desc"`
Status entities.UserStatus `form:"status" binding:"omitempty,oneof=active inactive suspended pending" example:"active"`
Search string `form:"search" binding:"omitempty,max=100" example:"john"`
Filters map[string]interface{} `form:"-"`
}
// UserListResponse 用户列表响应
type UserListResponse struct {
Users []*UserResponse `json:"users"`
Pagination PaginationMeta `json:"pagination"`
}
// PaginationMeta 分页元数据
type PaginationMeta struct {
Page int `json:"page" example:"1"`
PageSize int `json:"page_size" example:"20"`
Total int64 `json:"total" example:"100"`
TotalPages int `json:"total_pages" example:"5"`
HasNext bool `json:"has_next" example:"true"`
HasPrev bool `json:"has_prev" example:"false"`
}
// LoginRequest 登录请求
type LoginRequest struct {
Login string `json:"login" binding:"required" example:"john_doe"`
Password string `json:"password" binding:"required" example:"password123"`
ID string `json:"id" example:"123e4567-e89b-12d3-a456-426614174000"`
Phone string `json:"phone" example:"13800138000"`
CreatedAt time.Time `json:"created_at" example:"2024-01-01T00:00:00Z"`
UpdatedAt time.Time `json:"updated_at" example:"2024-01-01T00:00:00Z"`
}
// LoginResponse 登录响应
@@ -96,47 +48,27 @@ type LoginResponse struct {
AccessToken string `json:"access_token" example:"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9..."`
TokenType string `json:"token_type" example:"Bearer"`
ExpiresIn int64 `json:"expires_in" example:"86400"`
}
// UpdateProfileRequest 更新用户档案请求
type UpdateProfileRequest struct {
Bio *string `json:"bio,omitempty" binding:"omitempty,max=500" example:"Software Developer"`
Location *string `json:"location,omitempty" binding:"omitempty,max=100" example:"Beijing, China"`
Website *string `json:"website,omitempty" binding:"omitempty,url" example:"https://johndoe.com"`
Birthday *time.Time `json:"birthday,omitempty" example:"1990-01-01T00:00:00Z"`
Gender *string `json:"gender,omitempty" binding:"omitempty,oneof=male female other" example:"male"`
Timezone *string `json:"timezone,omitempty" binding:"omitempty,max=50" example:"Asia/Shanghai"`
Language *string `json:"language,omitempty" binding:"omitempty,max=10" example:"zh-CN"`
}
// UserStatsResponse 用户统计响应
type UserStatsResponse struct {
TotalUsers int64 `json:"total_users" example:"1000"`
ActiveUsers int64 `json:"active_users" example:"950"`
InactiveUsers int64 `json:"inactive_users" example:"30"`
SuspendedUsers int64 `json:"suspended_users" example:"20"`
NewUsersToday int64 `json:"new_users_today" example:"5"`
NewUsersWeek int64 `json:"new_users_week" example:"25"`
NewUsersMonth int64 `json:"new_users_month" example:"120"`
}
// UserSearchRequest 用户搜索请求
type UserSearchRequest struct {
Query string `form:"q" binding:"required,min=1,max=100" example:"john"`
Page int `form:"page" binding:"omitempty,min=1" example:"1"`
PageSize int `form:"page_size" binding:"omitempty,min=1,max=50" example:"10"`
LoginMethod string `json:"login_method" example:"password"` // password 或 sms
}
// 转换方法
func (r *CreateUserRequest) ToEntity() *entities.User {
func (r *RegisterRequest) ToEntity() *entities.User {
return &entities.User{
Username: r.Username,
Email: r.Email,
Password: r.Password,
FirstName: r.FirstName,
LastName: r.LastName,
Phone: r.Phone,
Status: entities.UserStatusActive,
Phone: r.Phone,
Password: r.Password,
}
}
func (r *LoginWithPasswordRequest) ToEntity() *entities.User {
return &entities.User{
Phone: r.Phone,
Password: r.Password,
}
}
func (r *LoginWithSMSRequest) ToEntity() *entities.User {
return &entities.User{
Phone: r.Phone,
}
}
@@ -146,28 +78,9 @@ func FromEntity(user *entities.User) *UserResponse {
}
return &UserResponse{
ID: user.ID,
Username: user.Username,
Email: user.Email,
FirstName: user.FirstName,
LastName: user.LastName,
Phone: user.Phone,
Avatar: user.Avatar,
Status: user.Status,
LastLoginAt: user.LastLoginAt,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
ID: user.ID,
Phone: user.Phone,
CreatedAt: user.CreatedAt,
UpdatedAt: user.UpdatedAt,
}
}
func FromEntities(users []*entities.User) []*UserResponse {
if users == nil {
return []*UserResponse{}
}
responses := make([]*UserResponse, len(users))
for i, user := range users {
responses[i] = FromEntity(user)
}
return responses
}

View File

@@ -0,0 +1,88 @@
package entities
import (
"time"
"gorm.io/gorm"
)
// SMSCode 短信验证码记录
type SMSCode struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
Phone string `gorm:"type:varchar(20);not null;index" json:"phone"`
Code string `gorm:"type:varchar(10);not null" json:"-"` // 不返回给前端
Scene SMSScene `gorm:"type:varchar(20);not null" json:"scene"`
Used bool `gorm:"default:false" json:"used"`
ExpiresAt time.Time `gorm:"not null" json:"expires_at"`
UsedAt *time.Time `json:"used_at,omitempty"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 额外信息
IP string `gorm:"type:varchar(45)" json:"ip"`
UserAgent string `gorm:"type:varchar(500)" json:"user_agent"`
}
// SMSScene 短信验证码使用场景
type SMSScene string
const (
SMSSceneRegister SMSScene = "register" // 注册
SMSSceneLogin SMSScene = "login" // 登录
SMSSceneChangePassword SMSScene = "change_password" // 修改密码
SMSSceneResetPassword SMSScene = "reset_password" // 重置密码
SMSSceneBind SMSScene = "bind" // 绑定手机号
SMSSceneUnbind SMSScene = "unbind" // 解绑手机号
)
// 实现 Entity 接口
func (s *SMSCode) GetID() string {
return s.ID
}
func (s *SMSCode) GetCreatedAt() time.Time {
return s.CreatedAt
}
func (s *SMSCode) GetUpdatedAt() time.Time {
return s.UpdatedAt
}
// Validate 验证短信验证码
func (s *SMSCode) Validate() error {
if s.Phone == "" {
return &ValidationError{Message: "手机号不能为空"}
}
if s.Code == "" {
return &ValidationError{Message: "验证码不能为空"}
}
if s.Scene == "" {
return &ValidationError{Message: "使用场景不能为空"}
}
if s.ExpiresAt.IsZero() {
return &ValidationError{Message: "过期时间不能为空"}
}
return nil
}
// 业务方法
func (s *SMSCode) IsExpired() bool {
return time.Now().After(s.ExpiresAt)
}
func (s *SMSCode) IsValid() bool {
return !s.Used && !s.IsExpired()
}
func (s *SMSCode) MarkAsUsed() {
s.Used = true
now := time.Now()
s.UsedAt = &now
}
// TableName 指定表名
func (SMSCode) TableName() string {
return "sms_codes"
}

View File

@@ -8,37 +8,14 @@ import (
// User 用户实体
type User struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
Username string `gorm:"uniqueIndex;type:varchar(50);not null" json:"username"`
Email string `gorm:"uniqueIndex;type:varchar(100);not null" json:"email"`
Password string `gorm:"type:varchar(255);not null" json:"-"`
FirstName string `gorm:"type:varchar(50)" json:"first_name"`
LastName string `gorm:"type:varchar(50)" json:"last_name"`
Phone string `gorm:"type:varchar(20)" json:"phone"`
Avatar string `gorm:"type:varchar(255)" json:"avatar"`
Status UserStatus `gorm:"type:varchar(20);default:'active'" json:"status"`
LastLoginAt *time.Time `json:"last_login_at"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
// 软删除字段
IsDeleted bool `gorm:"default:false" json:"is_deleted"`
// 版本控制
Version int `gorm:"default:1" json:"version"`
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
Phone string `gorm:"uniqueIndex;type:varchar(20);not null" json:"phone"`
Password string `gorm:"type:varchar(255);not null" json:"-"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
DeletedAt gorm.DeletedAt `gorm:"index" json:"-"`
}
// UserStatus 用户状态枚举
type UserStatus string
const (
UserStatusActive UserStatus = "active"
UserStatusInactive UserStatus = "inactive"
UserStatusSuspended UserStatus = "suspended"
UserStatusPending UserStatus = "pending"
)
// 实现 Entity 接口
func (u *User) GetID() string {
return u.ID
@@ -52,47 +29,13 @@ func (u *User) GetUpdatedAt() time.Time {
return u.UpdatedAt
}
// 业务方法
func (u *User) IsActive() bool {
return u.Status == UserStatusActive && !u.IsDeleted
}
func (u *User) GetFullName() string {
if u.FirstName == "" && u.LastName == "" {
return u.Username
}
return u.FirstName + " " + u.LastName
}
func (u *User) CanLogin() bool {
return u.IsActive() && u.Status != UserStatusSuspended
}
func (u *User) MarkAsDeleted() {
u.IsDeleted = true
u.Status = UserStatusInactive
}
func (u *User) Restore() {
u.IsDeleted = false
u.Status = UserStatusActive
}
func (u *User) UpdateLastLogin() {
now := time.Now()
u.LastLoginAt = &now
}
// 验证方法
func (u *User) Validate() error {
if u.Username == "" {
return NewValidationError("username is required")
}
if u.Email == "" {
return NewValidationError("email is required")
if u.Phone == "" {
return NewValidationError("手机号不能为空")
}
if u.Password == "" {
return NewValidationError("password is required")
return NewValidationError("密码不能为空")
}
return nil
}
@@ -114,25 +57,3 @@ func (e *ValidationError) Error() string {
func NewValidationError(message string) *ValidationError {
return &ValidationError{Message: message}
}
// UserProfile 用户档案(扩展信息)
type UserProfile struct {
ID string `gorm:"primaryKey;type:varchar(36)" json:"id"`
UserID string `gorm:"type:varchar(36);not null;index" json:"user_id"`
Bio string `gorm:"type:text" json:"bio"`
Location string `gorm:"type:varchar(100)" json:"location"`
Website string `gorm:"type:varchar(255)" json:"website"`
Birthday *time.Time `json:"birthday"`
Gender string `gorm:"type:varchar(10)" json:"gender"`
Timezone string `gorm:"type:varchar(50)" json:"timezone"`
Language string `gorm:"type:varchar(10);default:'zh-CN'" json:"language"`
CreatedAt time.Time `gorm:"autoCreateTime" json:"created_at"`
UpdatedAt time.Time `gorm:"autoUpdateTime" json:"updated_at"`
// 关联关系
User *User `gorm:"foreignKey:UserID;references:ID" json:"user,omitempty"`
}
func (UserProfile) TableName() string {
return "user_profiles"
}

View File

@@ -13,15 +13,9 @@ import (
type UserEventType string
const (
UserCreatedEvent UserEventType = "user.created"
UserUpdatedEvent UserEventType = "user.updated"
UserDeletedEvent UserEventType = "user.deleted"
UserRestoredEvent UserEventType = "user.restored"
UserRegisteredEvent UserEventType = "user.registered"
UserLoggedInEvent UserEventType = "user.logged_in"
UserLoggedOutEvent UserEventType = "user.logged_out"
UserPasswordChangedEvent UserEventType = "user.password_changed"
UserStatusChangedEvent UserEventType = "user.status_changed"
UserProfileUpdatedEvent UserEventType = "user.profile_updated"
)
// BaseUserEvent 用户事件基础结构
@@ -99,17 +93,17 @@ func (e *BaseUserEvent) Unmarshal(data []byte) error {
return json.Unmarshal(data, e)
}
// UserCreated 用户创建事件
type UserCreated struct {
// UserRegistered 用户注册事件
type UserRegistered struct {
*BaseUserEvent
User *entities.User `json:"user"`
}
func NewUserCreatedEvent(user *entities.User, correlationID string) *UserCreated {
return &UserCreated{
func NewUserRegisteredEvent(user *entities.User, correlationID string) *UserRegistered {
return &UserRegistered{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserCreatedEvent),
Type: string(UserRegisteredEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
@@ -118,97 +112,28 @@ func NewUserCreatedEvent(user *entities.User, correlationID string) *UserCreated
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": user.ID,
"username": user.Username,
"email": user.Email,
"user_id": user.ID,
"phone": user.Phone,
},
},
User: user,
}
}
func (e *UserCreated) GetPayload() interface{} {
func (e *UserRegistered) GetPayload() interface{} {
return e.User
}
// UserUpdated 用户更新事件
type UserUpdated struct {
*BaseUserEvent
UserID string `json:"user_id"`
Changes map[string]interface{} `json:"changes"`
OldValues map[string]interface{} `json:"old_values"`
NewValues map[string]interface{} `json:"new_values"`
}
func NewUserUpdatedEvent(userID string, changes, oldValues, newValues map[string]interface{}, correlationID string) *UserUpdated {
return &UserUpdated{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserUpdatedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"changed_fields": len(changes),
},
},
UserID: userID,
Changes: changes,
OldValues: oldValues,
NewValues: newValues,
}
}
// UserDeleted 用户删除事件
type UserDeleted struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
SoftDelete bool `json:"soft_delete"`
}
func NewUserDeletedEvent(userID, username, email string, softDelete bool, correlationID string) *UserDeleted {
return &UserDeleted{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserDeletedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
"email": email,
"soft_delete": softDelete,
},
},
UserID: userID,
Username: username,
Email: email,
SoftDelete: softDelete,
}
}
// UserLoggedIn 用户登录事件
type UserLoggedIn struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
Phone string `json:"phone"`
IPAddress string `json:"ip_address"`
UserAgent string `json:"user_agent"`
}
func NewUserLoggedInEvent(userID, username, ipAddress, userAgent, correlationID string) *UserLoggedIn {
func NewUserLoggedInEvent(userID, phone, ipAddress, userAgent, correlationID string) *UserLoggedIn {
return &UserLoggedIn{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
@@ -222,13 +147,13 @@ func NewUserLoggedInEvent(userID, username, ipAddress, userAgent, correlationID
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
"phone": phone,
"ip_address": ipAddress,
"user_agent": userAgent,
},
},
UserID: userID,
Username: username,
Phone: phone,
IPAddress: ipAddress,
UserAgent: userAgent,
}
@@ -237,11 +162,11 @@ func NewUserLoggedInEvent(userID, username, ipAddress, userAgent, correlationID
// UserPasswordChanged 用户密码修改事件
type UserPasswordChanged struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
UserID string `json:"user_id"`
Phone string `json:"phone"`
}
func NewUserPasswordChangedEvent(userID, username, correlationID string) *UserPasswordChanged {
func NewUserPasswordChangedEvent(userID, phone, correlationID string) *UserPasswordChanged {
return &UserPasswordChanged{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
@@ -254,46 +179,11 @@ func NewUserPasswordChangedEvent(userID, username, correlationID string) *UserPa
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
"user_id": userID,
"phone": phone,
},
},
UserID: userID,
Username: username,
}
}
// UserStatusChanged 用户状态变更事件
type UserStatusChanged struct {
*BaseUserEvent
UserID string `json:"user_id"`
Username string `json:"username"`
OldStatus entities.UserStatus `json:"old_status"`
NewStatus entities.UserStatus `json:"new_status"`
}
func NewUserStatusChangedEvent(userID, username string, oldStatus, newStatus entities.UserStatus, correlationID string) *UserStatusChanged {
return &UserStatusChanged{
BaseUserEvent: &BaseUserEvent{
ID: uuid.New().String(),
Type: string(UserStatusChangedEvent),
Version: "1.0",
Timestamp: time.Now(),
Source: "user-service",
AggregateID: userID,
AggregateType: "User",
DomainVersion: "1.0",
CorrelationID: correlationID,
Metadata: map[string]interface{}{
"user_id": userID,
"username": username,
"old_status": oldStatus,
"new_status": newStatus,
},
},
UserID: userID,
Username: username,
OldStatus: oldStatus,
NewStatus: newStatus,
UserID: userID,
Phone: phone,
}
}

View File

@@ -1,7 +1,7 @@
package handlers
import (
"strconv"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
@@ -14,211 +14,123 @@ import (
// UserHandler 用户HTTP处理器
type UserHandler struct {
userService *services.UserService
response interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
jwtAuth *middleware.JWTAuthMiddleware
userService interfaces.UserService
smsCodeService *services.SMSCodeService
response interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
jwtAuth *middleware.JWTAuthMiddleware
}
// NewUserHandler 创建用户处理器
func NewUserHandler(
userService *services.UserService,
userService interfaces.UserService,
smsCodeService *services.SMSCodeService,
response interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
jwtAuth *middleware.JWTAuthMiddleware,
) *UserHandler {
return &UserHandler{
userService: userService,
response: response,
validator: validator,
logger: logger,
jwtAuth: jwtAuth,
userService: userService,
smsCodeService: smsCodeService,
response: response,
validator: validator,
logger: logger,
jwtAuth: jwtAuth,
}
}
// GetPath 返回处理器路径
func (h *UserHandler) GetPath() string {
return "/users"
}
// GetMethod 返回HTTP方法
func (h *UserHandler) GetMethod() string {
return "GET" // 主要用于列表,具体方法在路由注册时指定
}
// GetMiddlewares 返回中间件
func (h *UserHandler) GetMiddlewares() []gin.HandlerFunc {
return []gin.HandlerFunc{
// 这里可以添加特定的中间件
}
}
// Handle 主处理函数(用于列表)
func (h *UserHandler) Handle(c *gin.Context) {
h.List(c)
}
// RequiresAuth 是否需要认证
func (h *UserHandler) RequiresAuth() bool {
return true
}
// GetPermissions 获取所需权限
func (h *UserHandler) GetPermissions() []string {
return []string{"user:read"}
}
// REST操作实现
// Create 创建用户
func (h *UserHandler) Create(c *gin.Context) {
var req dto.CreateUserRequest
// SendCode 发送验证码
// @Summary 发送短信验证码
// @Description 向指定手机号发送验证码,支持注册、登录、修改密码等场景
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body dto.SendCodeRequest true "发送验证码请求"
// @Success 200 {object} dto.SendCodeResponse "验证码发送成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 429 {object} map[string]interface{} "请求频率限制"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /users/send-code [post]
func (h *UserHandler) SendCode(c *gin.Context) {
var req dto.SendCodeRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return // 响应已在验证器中处理
}
// 创建用户
user, err := h.userService.Create(c.Request.Context(), &req)
if err != nil {
h.logger.Error("Failed to create user", zap.Error(err))
// 获取客户端信息
clientIP := c.ClientIP()
userAgent := c.GetHeader("User-Agent")
// 发送验证码
if err := h.smsCodeService.SendCode(c.Request.Context(), req.Phone, req.Scene, clientIP, userAgent); err != nil {
h.logger.Error("发送验证码失败",
zap.String("phone", req.Phone),
zap.String("scene", string(req.Scene)),
zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Created(c, response, "User created successfully")
}
// GetByID 根据ID获取用户
func (h *UserHandler) GetByID(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "User ID is required")
return
}
// 获取用户
user, err := h.userService.GetByID(c.Request.Context(), id)
if err != nil {
h.logger.Error("Failed to get user", zap.Error(err))
h.response.NotFound(c, "User not found")
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Success(c, response)
}
// Update 更新用户
func (h *UserHandler) Update(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "User ID is required")
return
}
var req dto.UpdateUserRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return
}
// 更新用户
user, err := h.userService.Update(c.Request.Context(), id, &req)
if err != nil {
h.logger.Error("Failed to update user", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Success(c, response, "User updated successfully")
}
// Delete 删除用户
func (h *UserHandler) Delete(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "User ID is required")
return
}
// 删除用户
if err := h.userService.Delete(c.Request.Context(), id); err != nil {
h.logger.Error("Failed to delete user", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
h.response.Success(c, nil, "User deleted successfully")
}
// List 获取用户列表
func (h *UserHandler) List(c *gin.Context) {
var req dto.UserListRequest
// 验证查询参数
if err := h.validator.ValidateQuery(c, &req); err != nil {
return
}
// 设置默认值
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 20
}
// 构建查询选项
options := interfaces.ListOptions{
Page: req.Page,
PageSize: req.PageSize,
Sort: req.Sort,
Order: req.Order,
Search: req.Search,
Filters: req.Filters,
}
// 获取用户列表
users, err := h.userService.List(c.Request.Context(), options)
if err != nil {
h.logger.Error("Failed to get user list", zap.Error(err))
h.response.InternalError(c, "Failed to get user list")
return
}
// 获取总数
countOptions := interfaces.CountOptions{
Search: req.Search,
Filters: req.Filters,
}
total, err := h.userService.Count(c.Request.Context(), countOptions)
if err != nil {
h.logger.Error("Failed to count users", zap.Error(err))
h.response.InternalError(c, "Failed to count users")
return
}
// 构建响应
userResponses := dto.FromEntities(users)
pagination := buildPagination(req.Page, req.PageSize, total)
response := &dto.SendCodeResponse{
Message: "验证码发送成功",
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟过期
}
h.response.Paginated(c, userResponses, pagination)
h.response.Success(c, response, "验证码发送成功")
}
// Login 用户登录
func (h *UserHandler) Login(c *gin.Context) {
var req dto.LoginRequest
// Register 用户注册
// @Summary 用户注册
// @Description 使用手机号、密码和验证码进行用户注册,需要确认密码
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body dto.RegisterRequest true "用户注册请求"
// @Success 201 {object} dto.UserResponse "注册成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
// @Failure 409 {object} map[string]interface{} "手机号已存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /users/register [post]
func (h *UserHandler) Register(c *gin.Context) {
var req dto.RegisterRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return // 响应已在验证器中处理
}
// 注册用户
user, err := h.userService.Register(c.Request.Context(), &req)
if err != nil {
h.logger.Error("注册用户失败", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Created(c, response, "用户注册成功")
}
// LoginWithPassword 密码登录
// @Summary 用户密码登录
// @Description 使用手机号和密码进行用户登录返回JWT令牌
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body dto.LoginWithPasswordRequest true "密码登录请求"
// @Success 200 {object} dto.LoginResponse "登录成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "认证失败"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /users/login-password [post]
func (h *UserHandler) LoginWithPassword(c *gin.Context) {
var req dto.LoginWithPasswordRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
@@ -226,18 +138,18 @@ func (h *UserHandler) Login(c *gin.Context) {
}
// 用户登录
user, err := h.userService.Login(c.Request.Context(), &req)
user, err := h.userService.LoginWithPassword(c.Request.Context(), &req)
if err != nil {
h.logger.Error("Login failed", zap.Error(err))
h.response.Unauthorized(c, "Invalid credentials")
h.logger.Error("密码登录失败", zap.Error(err))
h.response.Unauthorized(c, "用户名或密码错误")
return
}
// 生成JWT token
accessToken, err := h.jwtAuth.GenerateToken(user.ID, user.Username, user.Email)
accessToken, err := h.jwtAuth.GenerateToken(user.ID, user.Phone, user.Phone)
if err != nil {
h.logger.Error("Failed to generate token", zap.Error(err))
h.response.InternalError(c, "Failed to generate access token")
h.logger.Error("生成令牌失败", zap.Error(err))
h.response.InternalError(c, "生成访问令牌失败")
return
}
@@ -247,72 +159,109 @@ func (h *UserHandler) Login(c *gin.Context) {
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: 86400, // 24小时从配置获取
LoginMethod: "password",
}
h.response.Success(c, loginResponse, "Login successful")
h.response.Success(c, loginResponse, "登录成功")
}
// Logout 用户登出
func (h *UserHandler) Logout(c *gin.Context) {
// 简单实现客户端删除token即可
// 如果需要服务端黑名单,可以在这里实现
h.response.Success(c, nil, "Logout successful")
}
// GetProfile 获取当前用户信息
func (h *UserHandler) GetProfile(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "User not authenticated")
return
}
// 获取用户信息
user, err := h.userService.GetByID(c.Request.Context(), userID)
if err != nil {
h.logger.Error("Failed to get user profile", zap.Error(err))
h.response.NotFound(c, "User not found")
return
}
// 返回响应
response := dto.FromEntity(user)
h.response.Success(c, response)
}
// UpdateProfile 更新当前用户信息
func (h *UserHandler) UpdateProfile(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "User not authenticated")
return
}
var req dto.UpdateUserRequest
// LoginWithSMS 短信验证码登录
// @Summary 用户短信验证码登录
// @Description 使用手机号和短信验证码进行用户登录返回JWT令牌
// @Tags 用户认证
// @Accept json
// @Produce json
// @Param request body dto.LoginWithSMSRequest true "短信登录请求"
// @Success 200 {object} dto.LoginResponse "登录成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
// @Failure 401 {object} map[string]interface{} "认证失败"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /users/login-sms [post]
func (h *UserHandler) LoginWithSMS(c *gin.Context) {
var req dto.LoginWithSMSRequest
// 验证请求体
if err := h.validator.BindAndValidate(c, &req); err != nil {
return
}
// 更新用户
user, err := h.userService.Update(c.Request.Context(), userID, &req)
// 用户登录
user, err := h.userService.LoginWithSMS(c.Request.Context(), &req)
if err != nil {
h.logger.Error("Failed to update profile", zap.Error(err))
h.response.BadRequest(c, err.Error())
h.logger.Error("短信登录失败", zap.Error(err))
h.response.Unauthorized(c, err.Error())
return
}
// 返回响应
// 生成JWT token
accessToken, err := h.jwtAuth.GenerateToken(user.ID, user.Phone, user.Phone)
if err != nil {
h.logger.Error("生成令牌失败", zap.Error(err))
h.response.InternalError(c, "生成访问令牌失败")
return
}
// 构建登录响应
loginResponse := &dto.LoginResponse{
User: dto.FromEntity(user),
AccessToken: accessToken,
TokenType: "Bearer",
ExpiresIn: 86400, // 24小时从配置获取
LoginMethod: "sms",
}
h.response.Success(c, loginResponse, "登录成功")
}
// GetProfile 获取当前用户信息
// @Summary 获取当前用户信息
// @Description 根据JWT令牌获取当前登录用户的详细信息
// @Tags 用户管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} dto.UserResponse "用户信息"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "用户不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /users/me [get]
func (h *UserHandler) GetProfile(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未认证")
return
}
// 获取用户信息
user, err := h.userService.GetByID(c.Request.Context(), userID)
if err != nil {
h.logger.Error("获取用户资料失败", zap.Error(err))
h.response.NotFound(c, "用户不存在")
return
}
// 返回用户信息
response := dto.FromEntity(user)
h.response.Success(c, response, "Profile updated successfully")
h.response.Success(c, response, "获取用户资料成功")
}
// ChangePassword 修改密码
// @Summary 修改密码
// @Description 使用旧密码、新密码确认和验证码修改当前用户的密码
// @Tags 用户管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body dto.ChangePasswordRequest true "修改密码请求"
// @Success 200 {object} map[string]interface{} "密码修改成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /users/me/password [put]
func (h *UserHandler) ChangePassword(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "User not authenticated")
h.response.Unauthorized(c, "用户未认证")
return
}
@@ -325,78 +274,14 @@ func (h *UserHandler) ChangePassword(c *gin.Context) {
// 修改密码
if err := h.userService.ChangePassword(c.Request.Context(), userID, &req); err != nil {
h.logger.Error("Failed to change password", zap.Error(err))
h.logger.Error("修改密码失败", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, nil, "Password changed successfully")
h.response.Success(c, nil, "密码修改成功")
}
// Search 搜索用户
func (h *UserHandler) Search(c *gin.Context) {
var req dto.UserSearchRequest
// 验证查询参数
if err := h.validator.ValidateQuery(c, &req); err != nil {
return
}
// 设置默认值
if req.Page <= 0 {
req.Page = 1
}
if req.PageSize <= 0 {
req.PageSize = 10
}
// 构建查询选项
options := interfaces.ListOptions{
Page: req.Page,
PageSize: req.PageSize,
Search: req.Query,
}
// 搜索用户
users, err := h.userService.Search(c.Request.Context(), req.Query, options)
if err != nil {
h.logger.Error("Failed to search users", zap.Error(err))
h.response.InternalError(c, "Failed to search users")
return
}
// 获取搜索结果总数
countOptions := interfaces.CountOptions{
Search: req.Query,
}
total, err := h.userService.Count(c.Request.Context(), countOptions)
if err != nil {
h.logger.Error("Failed to count search results", zap.Error(err))
h.response.InternalError(c, "Failed to count search results")
return
}
// 构建响应
userResponses := dto.FromEntities(users)
pagination := buildPagination(req.Page, req.PageSize, total)
h.response.Paginated(c, userResponses, pagination)
}
// GetStats 获取用户统计
func (h *UserHandler) GetStats(c *gin.Context) {
stats, err := h.userService.GetStats(c.Request.Context())
if err != nil {
h.logger.Error("Failed to get user stats", zap.Error(err))
h.response.InternalError(c, "Failed to get user statistics")
return
}
h.response.Success(c, stats)
}
// 私有方法
// getCurrentUserID 获取当前用户ID
func (h *UserHandler) getCurrentUserID(c *gin.Context) string {
if userID, exists := c.Get("user_id"); exists {
@@ -406,50 +291,3 @@ func (h *UserHandler) getCurrentUserID(c *gin.Context) string {
}
return ""
}
// parsePageSize 解析页面大小
func (h *UserHandler) parsePageSize(str string, defaultValue int) int {
if str == "" {
return defaultValue
}
if size, err := strconv.Atoi(str); err == nil && size > 0 && size <= 100 {
return size
}
return defaultValue
}
// parsePage 解析页码
func (h *UserHandler) parsePage(str string, defaultValue int) int {
if str == "" {
return defaultValue
}
if page, err := strconv.Atoi(str); err == nil && page > 0 {
return page
}
return defaultValue
}
// buildPagination 构建分页元数据
func buildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
totalPages := int(float64(total) / float64(pageSize))
if float64(total)/float64(pageSize) > float64(totalPages) {
totalPages++
}
if totalPages < 1 {
totalPages = 1
}
return interfaces.PaginationMeta{
Page: page,
PageSize: pageSize,
Total: total,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
}

View File

@@ -0,0 +1,120 @@
package repositories
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/shared/interfaces"
)
// SMSCodeRepository 短信验证码仓储
type SMSCodeRepository struct {
db *gorm.DB
cache interfaces.CacheService
logger *zap.Logger
}
// NewSMSCodeRepository 创建短信验证码仓储
func NewSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) *SMSCodeRepository {
return &SMSCodeRepository{
db: db,
cache: cache,
logger: logger,
}
}
// Create 创建短信验证码记录
func (r *SMSCodeRepository) Create(ctx context.Context, smsCode *entities.SMSCode) error {
if err := r.db.WithContext(ctx).Create(smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
return err
}
// 缓存验证码
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
r.cache.Set(ctx, cacheKey, smsCode, 5*time.Minute)
return nil
}
// GetValidCode 获取有效的验证码
func (r *SMSCodeRepository) GetValidCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
// 先从缓存查找
cacheKey := r.buildCacheKey(phone, scene)
var smsCode entities.SMSCode
if err := r.cache.Get(ctx, cacheKey, &smsCode); err == nil {
return &smsCode, nil
}
// 从数据库查找最新的有效验证码
if err := r.db.WithContext(ctx).
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
phone, scene, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
return &smsCode, nil
}
// MarkAsUsed 标记验证码为已使用
func (r *SMSCodeRepository) MarkAsUsed(ctx context.Context, id string) error {
now := time.Now()
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("id = ?", id).
Update("used_at", now).Error; err != nil {
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
return err
}
r.logger.Info("验证码已标记为使用", zap.String("code_id", id))
return nil
}
// CleanupExpired 清理过期的验证码
func (r *SMSCodeRepository) CleanupExpired(ctx context.Context) error {
result := r.db.WithContext(ctx).
Where("expires_at < ?", time.Now()).
Delete(&entities.SMSCode{})
if result.Error != nil {
r.logger.Error("清理过期验证码失败", zap.Error(result.Error))
return result.Error
}
if result.RowsAffected > 0 {
r.logger.Info("清理过期验证码完成", zap.Int64("count", result.RowsAffected))
}
return nil
}
// CountRecentCodes 统计最近发送的验证码数量
func (r *SMSCodeRepository) CountRecentCodes(ctx context.Context, phone string, scene entities.SMSScene, duration time.Duration) (int64, error) {
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?",
phone, scene, time.Now().Add(-duration)).
Count(&count).Error; err != nil {
r.logger.Error("统计最近验证码数量失败", zap.Error(err))
return 0, err
}
return count, nil
}
// buildCacheKey 构建缓存键
func (r *SMSCodeRepository) buildCacheKey(phone string, scene entities.SMSScene) string {
return fmt.Sprintf("sms_code:%s:%s", phone, string(scene))
}

View File

@@ -2,6 +2,7 @@ package repositories
import (
"context"
"errors"
"fmt"
"time"
@@ -12,6 +13,12 @@ import (
"tyapi-server/internal/shared/interfaces"
)
// 定义错误常量
var (
// ErrUserNotFound 用户不存在错误
ErrUserNotFound = errors.New("用户不存在")
)
// UserRepository 用户仓储实现
type UserRepository struct {
db *gorm.DB
@@ -29,311 +36,150 @@ func NewUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.L
}
// Create 创建用户
func (r *UserRepository) Create(ctx context.Context, entity *entities.User) error {
if err := r.db.WithContext(ctx).Create(entity).Error; err != nil {
r.logger.Error("Failed to create user", zap.Error(err))
func (r *UserRepository) Create(ctx context.Context, user *entities.User) error {
if err := r.db.WithContext(ctx).Create(user).Error; err != nil {
r.logger.Error("创建用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, entity.ID)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
return nil
}
// GetByID 根据ID获取用户
func (r *UserRepository) GetByID(ctx context.Context, id string) (*entities.User, error) {
// 尝试从缓存获取
cacheKey := r.GetCacheKey(id)
// 尝试从缓存获取
cacheKey := fmt.Sprintf("user:id:%s", id)
var user entities.User
if err := r.cache.Get(ctx, cacheKey, &user); err == nil {
return &user, nil
}
// 从数据库获取
if err := r.db.WithContext(ctx).Where("id = ? AND is_deleted = false", id).First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("user not found")
// 从数据库查询
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据ID查询用户失败", zap.Error(err))
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, &user, 1*time.Hour)
r.cache.Set(ctx, cacheKey, &user, 10*time.Minute)
return &user, nil
}
// FindByPhone 根据手机号查找用户
func (r *UserRepository) FindByPhone(ctx context.Context, phone string) (*entities.User, error) {
// 尝试从缓存获取
cacheKey := fmt.Sprintf("user:phone:%s", phone)
var user entities.User
if err := r.cache.Get(ctx, cacheKey, &user); err == nil {
return &user, nil
}
// 从数据库查询
if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据手机号查询用户失败", zap.Error(err))
return nil, err
}
// 缓存结果
r.cache.Set(ctx, cacheKey, &user, 10*time.Minute)
return &user, nil
}
// Update 更新用户
func (r *UserRepository) Update(ctx context.Context, entity *entities.User) error {
if err := r.db.WithContext(ctx).Save(entity).Error; err != nil {
r.logger.Error("Failed to update user", zap.Error(err))
func (r *UserRepository) Update(ctx context.Context, user *entities.User) error {
if err := r.db.WithContext(ctx).Save(user).Error; err != nil {
r.logger.Error("更新用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, entity.ID)
r.deleteCacheByID(ctx, user.ID)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
return nil
}
// Delete 删除用户
func (r *UserRepository) Delete(ctx context.Context, id string) error {
// 先获取用户信息用于清除缓存
user, err := r.GetByID(ctx, id)
if err != nil {
return err
}
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("Failed to delete user", zap.Error(err))
r.logger.Error("删除用户失败", zap.Error(err))
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, id)
r.deleteCacheByID(ctx, id)
r.deleteCacheByPhone(ctx, user.Phone)
r.logger.Info("用户删除成功", zap.String("user_id", id))
return nil
}
// CreateBatch 批量创建用户
func (r *UserRepository) CreateBatch(ctx context.Context, entities []*entities.User) error {
if err := r.db.WithContext(ctx).CreateInBatches(entities, 100).Error; err != nil {
r.logger.Error("Failed to create users in batch", zap.Error(err))
return err
}
// 清除列表缓存
r.cache.DeletePattern(ctx, "users:list:*")
return nil
}
// GetByIDs 根据ID列表获取用户
func (r *UserRepository) GetByIDs(ctx context.Context, ids []string) ([]*entities.User, error) {
var users []entities.User
if err := r.db.WithContext(ctx).
Where("id IN ? AND is_deleted = false", ids).
Find(&users).Error; err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.User, len(users))
for i := range users {
result[i] = &users[i]
}
return result, nil
}
// UpdateBatch 批量更新用户
func (r *UserRepository) UpdateBatch(ctx context.Context, entities []*entities.User) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for _, entity := range entities {
if err := tx.Save(entity).Error; err != nil {
return err
}
}
return nil
})
}
// DeleteBatch 批量删除用户
func (r *UserRepository) DeleteBatch(ctx context.Context, ids []string) error {
if err := r.db.WithContext(ctx).
Where("id IN ?", ids).
Delete(&entities.User{}).Error; err != nil {
return err
}
// 清除相关缓存
for _, id := range ids {
r.invalidateUserCaches(ctx, id)
}
return nil
}
// List 获取用户列表
func (r *UserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]*entities.User, error) {
// 尝试从缓存获取
cacheKey := fmt.Sprintf("users:list:%d:%d:%s", options.Page, options.PageSize, options.Sort)
// List 分页获取用户列表
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*entities.User, error) {
var users []*entities.User
if err := r.cache.Get(ctx, cacheKey, &users); err == nil {
return users, nil
}
// 从数据库查询
query := r.db.WithContext(ctx).Where("is_deleted = false")
// 应用过滤条件
if options.Search != "" {
query = query.Where("username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用排序
if options.Sort != "" {
order := options.Order
if order == "" {
order = "asc"
}
query = query.Order(fmt.Sprintf("%s %s", options.Sort, order))
} else {
query = query.Order("created_at desc")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
var userEntities []entities.User
if err := query.Find(&userEntities).Error; err != nil {
if err := r.db.WithContext(ctx).Offset(offset).Limit(limit).Find(&users).Error; err != nil {
r.logger.Error("查询用户列表失败", zap.Error(err))
return nil, err
}
// 转换为指针切片
users = make([]*entities.User, len(userEntities))
for i := range userEntities {
users[i] = &userEntities[i]
}
// 缓存结果
r.cache.Set(ctx, cacheKey, users, 30*time.Minute)
return users, nil
}
// Count 统计用户数量
func (r *UserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
query := r.db.WithContext(ctx).Model(&entities.User{}).Where("is_deleted = false")
// 应用过滤条件
if options.Search != "" {
query = query.Where("username ILIKE ? OR email ILIKE ? OR first_name ILIKE ? OR last_name ILIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
// Count 获取用户总数
func (r *UserRepository) Count(ctx context.Context) (int64, error) {
var count int64
if err := query.Count(&count).Error; err != nil {
if err := r.db.WithContext(ctx).Model(&entities.User{}).Count(&count).Error; err != nil {
r.logger.Error("统计用户数量失败", zap.Error(err))
return 0, err
}
return count, nil
}
// Exists 检查用户是否存在
func (r *UserRepository) Exists(ctx context.Context, id string) (bool, error) {
// ExistsByPhone 检查手机号是否存在
func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
var count int64
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ? AND is_deleted = false", id).
Count(&count).Error; err != nil {
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("phone = ?", phone).Count(&count).Error; err != nil {
r.logger.Error("检查手机号是否存在失败", zap.Error(err))
return false, err
}
return count > 0, nil
}
// SoftDelete 软删除用户
func (r *UserRepository) SoftDelete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", id).
Update("is_deleted", true).Error; err != nil {
return err
}
// 私有辅助方法
// 清除相关缓存
r.invalidateUserCaches(ctx, id)
return nil
}
// Restore 恢复用户
func (r *UserRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).
Model(&entities.User{}).
Where("id = ?", id).
Update("is_deleted", false).Error; err != nil {
return err
}
// 清除相关缓存
r.invalidateUserCaches(ctx, id)
return nil
}
// WithTx 使用事务
func (r *UserRepository) WithTx(tx interface{}) interfaces.Repository[*entities.User] {
gormTx, ok := tx.(*gorm.DB)
if !ok {
return r
}
return &UserRepository{
db: gormTx,
cache: r.cache,
logger: r.logger,
// deleteCacheByID 根据ID删除缓存
func (r *UserRepository) deleteCacheByID(ctx context.Context, id string) {
cacheKey := fmt.Sprintf("user:id:%s", id)
if err := r.cache.Delete(ctx, cacheKey); err != nil {
r.logger.Warn("删除用户ID缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
}
}
// InvalidateCache 清除缓存
func (r *UserRepository) InvalidateCache(ctx context.Context, keys ...string) error {
return r.cache.Delete(ctx, keys...)
}
// WarmupCache 预热缓存
func (r *UserRepository) WarmupCache(ctx context.Context) error {
// 预热热门用户数据
// 这里可以实现具体的预热逻辑
return nil
}
// GetCacheKey 获取缓存键
func (r *UserRepository) GetCacheKey(id string) string {
return fmt.Sprintf("user:%s", id)
}
// FindByUsername 根据用户名查找用户
func (r *UserRepository) FindByUsername(ctx context.Context, username string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).
Where("username = ? AND is_deleted = false", username).
First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("user not found")
}
return nil, err
// deleteCacheByPhone 根据手机号删除缓存
func (r *UserRepository) deleteCacheByPhone(ctx context.Context, phone string) {
cacheKey := fmt.Sprintf("user:phone:%s", phone)
if err := r.cache.Delete(ctx, cacheKey); err != nil {
r.logger.Warn("删除用户手机号缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
}
return &user, nil
}
// FindByEmail 根据邮箱查找用户
func (r *UserRepository) FindByEmail(ctx context.Context, email string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).
Where("email = ? AND is_deleted = false", email).
First(&user).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("user not found")
}
return nil, err
}
return &user, nil
}
// invalidateUserCaches 清除用户相关缓存
func (r *UserRepository) invalidateUserCaches(ctx context.Context, userID string) {
keys := []string{
r.GetCacheKey(userID),
}
r.cache.Delete(ctx, keys...)
r.cache.DeletePattern(ctx, "users:list:*")
}

View File

@@ -7,127 +7,23 @@ import (
"github.com/gin-gonic/gin"
)
// UserRoutes 用户路由注册器
type UserRoutes struct {
handler *handlers.UserHandler
jwtAuth *middleware.JWTAuthMiddleware
optionalAuth *middleware.OptionalAuthMiddleware
}
// NewUserRoutes 创建用户路由注册器
func NewUserRoutes(
handler *handlers.UserHandler,
jwtAuth *middleware.JWTAuthMiddleware,
optionalAuth *middleware.OptionalAuthMiddleware,
) *UserRoutes {
return &UserRoutes{
handler: handler,
jwtAuth: jwtAuth,
optionalAuth: optionalAuth,
}
}
// RegisterRoutes 注册用户路由
func (r *UserRoutes) RegisterRoutes(router *gin.Engine) {
// API版本组
v1 := router.Group("/api/v1")
// 公开路由(不需要认证)
public := v1.Group("/auth")
// UserRoutes 注册用户相关路由
func UserRoutes(router *gin.Engine, handler *handlers.UserHandler, authMiddleware *middleware.JWTAuthMiddleware) {
// 用户域路由组
usersGroup := router.Group("/api/v1/users")
{
public.POST("/login", r.handler.Login)
public.POST("/register", r.handler.Create)
}
// 公开路由(不需要认证)
usersGroup.POST("/send-code", handler.SendCode) // 发送验证码
usersGroup.POST("/register", handler.Register) // 用户注册
usersGroup.POST("/login-password", handler.LoginWithPassword) // 密码登录
usersGroup.POST("/login-sms", handler.LoginWithSMS) // 短信验证码登录
// 需要认证的路由
protected := v1.Group("/users")
protected.Use(r.jwtAuth.Handle())
{
// 用户管理(管理员)
protected.GET("", r.handler.List)
protected.POST("", r.handler.Create)
protected.GET("/:id", r.handler.GetByID)
protected.PUT("/:id", r.handler.Update)
protected.DELETE("/:id", r.handler.Delete)
// 用户搜索
protected.GET("/search", r.handler.Search)
// 用户统计
protected.GET("/stats", r.handler.GetStats)
}
// 用户个人操作路由
profile := v1.Group("/profile")
profile.Use(r.jwtAuth.Handle())
{
profile.GET("", r.handler.GetProfile)
profile.PUT("", r.handler.UpdateProfile)
profile.POST("/change-password", r.handler.ChangePassword)
profile.POST("/logout", r.handler.Logout)
}
}
// RegisterPublicRoutes 注册公开路由
func (r *UserRoutes) RegisterPublicRoutes(router *gin.Engine) {
v1 := router.Group("/api/v1")
// 公开的用户相关路由
public := v1.Group("/public")
{
// 可选认证的路由(用户可能登录也可能未登录)
public.Use(r.optionalAuth.Handle())
// 这里可以添加一些公开的用户信息查询接口
// 比如根据用户名查看公开信息(如果用户设置为公开)
}
}
// RegisterAdminRoutes 注册管理员路由
func (r *UserRoutes) RegisterAdminRoutes(router *gin.Engine) {
admin := router.Group("/admin/v1")
admin.Use(r.jwtAuth.Handle())
// 这里可以添加管理员权限检查中间件
// 管理员用户管理
users := admin.Group("/users")
{
users.GET("", r.handler.List)
users.GET("/:id", r.handler.GetByID)
users.PUT("/:id", r.handler.Update)
users.DELETE("/:id", r.handler.Delete)
users.GET("/stats", r.handler.GetStats)
users.GET("/search", r.handler.Search)
// 批量操作
users.POST("/batch-delete", r.handleBatchDelete)
users.POST("/batch-update", r.handleBatchUpdate)
}
}
// 批量删除处理器
func (r *UserRoutes) handleBatchDelete(c *gin.Context) {
// 实现批量删除逻辑
// 这里可以接收用户ID列表并调用服务进行批量删除
c.JSON(200, gin.H{"message": "Batch delete not implemented yet"})
}
// 批量更新处理器
func (r *UserRoutes) handleBatchUpdate(c *gin.Context) {
// 实现批量更新逻辑
c.JSON(200, gin.H{"message": "Batch update not implemented yet"})
}
// RegisterHealthRoutes 注册健康检查路由
func (r *UserRoutes) RegisterHealthRoutes(router *gin.Engine) {
health := router.Group("/health")
{
health.GET("/users", func(c *gin.Context) {
// 用户服务健康检查
c.JSON(200, gin.H{
"service": "users",
"status": "healthy",
})
})
// 需要认证的路由
authenticated := usersGroup.Group("")
authenticated.Use(authMiddleware.Handle())
{
authenticated.GET("/me", handler.GetProfile) // 获取当前用户信息
authenticated.PUT("/me/password", handler.ChangePassword) // 修改密码
}
}
}

View File

@@ -0,0 +1,187 @@
package services
import (
"context"
"fmt"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"tyapi-server/internal/config"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/shared/interfaces"
"tyapi-server/internal/shared/sms"
)
// SMSCodeService 短信验证码服务
type SMSCodeService struct {
repo *repositories.SMSCodeRepository
smsClient sms.Service
cache interfaces.CacheService
config config.SMSConfig
logger *zap.Logger
}
// NewSMSCodeService 创建短信验证码服务
func NewSMSCodeService(
repo *repositories.SMSCodeRepository,
smsClient sms.Service,
cache interfaces.CacheService,
config config.SMSConfig,
logger *zap.Logger,
) *SMSCodeService {
return &SMSCodeService{
repo: repo,
smsClient: smsClient,
cache: cache,
config: config,
logger: logger,
}
}
// SendCode 发送验证码
func (s *SMSCodeService) SendCode(ctx context.Context, phone string, scene entities.SMSScene, clientIP, userAgent string) error {
// 检查频率限制
if err := s.checkRateLimit(ctx, phone); err != nil {
return err
}
// 生成验证码
code := s.smsClient.GenerateCode(s.config.CodeLength)
// 创建SMS验证码记录
smsCode := &entities.SMSCode{
ID: uuid.New().String(),
Phone: phone,
Code: code,
Scene: scene,
IP: clientIP,
UserAgent: userAgent,
Used: false,
ExpiresAt: time.Now().Add(s.config.ExpireTime),
}
// 保存验证码
if err := s.repo.Create(ctx, smsCode); err != nil {
s.logger.Error("保存短信验证码失败",
zap.String("phone", phone),
zap.String("scene", string(scene)),
zap.Error(err))
return fmt.Errorf("保存验证码失败: %w", err)
}
// 发送短信
if err := s.smsClient.SendVerificationCode(ctx, phone, code); err != nil {
// 记录发送失败但不删除验证码记录,让其自然过期
s.logger.Error("发送短信验证码失败",
zap.String("phone", phone),
zap.String("code", code),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
// 更新发送记录缓存
s.updateSendRecord(ctx, phone)
s.logger.Info("短信验证码发送成功",
zap.String("phone", phone),
zap.String("scene", string(scene)))
return nil
}
// VerifyCode 验证验证码
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, code string, scene entities.SMSScene) error {
// 根据手机号和场景获取有效的验证码记录
smsCode, err := s.repo.GetValidCode(ctx, phone, scene)
if err != nil {
return fmt.Errorf("验证码无效或已过期")
}
// 验证验证码是否匹配
if smsCode.Code != code {
return fmt.Errorf("验证码无效或已过期")
}
// 标记验证码为已使用
if err := s.repo.MarkAsUsed(ctx, smsCode.ID); err != nil {
s.logger.Error("标记验证码为已使用失败",
zap.String("code_id", smsCode.ID),
zap.Error(err))
return fmt.Errorf("验证码状态更新失败")
}
s.logger.Info("短信验证码验证成功",
zap.String("phone", phone),
zap.String("scene", string(scene)))
return nil
}
// checkRateLimit 检查发送频率限制
func (s *SMSCodeService) checkRateLimit(ctx context.Context, phone string) error {
now := time.Now()
// 检查最小发送间隔
lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone)
var lastSent time.Time
if err := s.cache.Get(ctx, lastSentKey, &lastSent); err == nil {
if now.Sub(lastSent) < s.config.RateLimit.MinInterval {
return fmt.Errorf("请等待 %v 后再试", s.config.RateLimit.MinInterval)
}
}
// 检查每小时发送限制
hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215"))
var hourlyCount int
if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil {
if hourlyCount >= s.config.RateLimit.HourlyLimit {
return fmt.Errorf("每小时最多发送 %d 条短信", s.config.RateLimit.HourlyLimit)
}
}
// 检查每日发送限制
dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102"))
var dailyCount int
if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil {
if dailyCount >= s.config.RateLimit.DailyLimit {
return fmt.Errorf("每日最多发送 %d 条短信", s.config.RateLimit.DailyLimit)
}
}
return nil
}
// updateSendRecord 更新发送记录
func (s *SMSCodeService) updateSendRecord(ctx context.Context, phone string) {
now := time.Now()
// 更新最后发送时间
lastSentKey := fmt.Sprintf("sms:last_sent:%s", phone)
s.cache.Set(ctx, lastSentKey, now, s.config.RateLimit.MinInterval)
// 更新每小时计数
hourlyKey := fmt.Sprintf("sms:hourly:%s:%s", phone, now.Format("2006010215"))
var hourlyCount int
if err := s.cache.Get(ctx, hourlyKey, &hourlyCount); err == nil {
s.cache.Set(ctx, hourlyKey, hourlyCount+1, time.Hour)
} else {
s.cache.Set(ctx, hourlyKey, 1, time.Hour)
}
// 更新每日计数
dailyKey := fmt.Sprintf("sms:daily:%s:%s", phone, now.Format("20060102"))
var dailyCount int
if err := s.cache.Get(ctx, dailyKey, &dailyCount); err == nil {
s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour)
} else {
s.cache.Set(ctx, dailyKey, 1, 24*time.Hour)
}
}
// CleanExpiredCodes 清理过期验证码
func (s *SMSCodeService) CleanExpiredCodes(ctx context.Context) error {
return s.repo.CleanupExpired(ctx)
}

View File

@@ -3,7 +3,7 @@ package services
import (
"context"
"fmt"
"time"
"regexp"
"github.com/google/uuid"
"go.uber.org/zap"
@@ -18,21 +18,24 @@ import (
// UserService 用户服务实现
type UserService struct {
repo *repositories.UserRepository
eventBus interfaces.EventBus
logger *zap.Logger
repo *repositories.UserRepository
smsCodeService *SMSCodeService
eventBus interfaces.EventBus
logger *zap.Logger
}
// NewUserService 创建用户服务
func NewUserService(
repo *repositories.UserRepository,
smsCodeService *SMSCodeService,
eventBus interfaces.EventBus,
logger *zap.Logger,
) *UserService {
return &UserService{
repo: repo,
eventBus: eventBus,
logger: logger,
repo: repo,
smsCodeService: smsCodeService,
eventBus: eventBus,
logger: logger,
}
}
@@ -43,341 +46,209 @@ func (s *UserService) Name() string {
// Initialize 初始化服务
func (s *UserService) Initialize(ctx context.Context) error {
s.logger.Info("User service initialized")
s.logger.Info("用户服务已初始化")
return nil
}
// HealthCheck 健康检查
func (s *UserService) HealthCheck(ctx context.Context) error {
// 简单检查:尝试查询用户数量
_, err := s.repo.Count(ctx, interfaces.CountOptions{})
return err
// 简单的健康检查
return nil
}
// Shutdown 关闭服务
func (s *UserService) Shutdown(ctx context.Context) error {
s.logger.Info("User service shutdown")
s.logger.Info("用户服务已关闭")
return nil
}
// Create 创建用户
func (s *UserService) Create(ctx context.Context, createDTO interface{}) (*entities.User, error) {
req, ok := createDTO.(*dto.CreateUserRequest)
if !ok {
return nil, fmt.Errorf("invalid DTO type for user creation")
// Register 用户注册
func (s *UserService) Register(ctx context.Context, registerReq *dto.RegisterRequest) (*entities.User, error) {
// 验证手机号格式
if !s.isValidPhone(registerReq.Phone) {
return nil, fmt.Errorf("手机号格式无效")
}
// 验证业务规则
if err := s.ValidateCreate(ctx, req); err != nil {
return nil, err
// 验证密码确认
if registerReq.Password != registerReq.ConfirmPassword {
return nil, fmt.Errorf("密码和确认密码不匹配")
}
// 检查用户名和邮箱是否已存在
if err := s.checkDuplicates(ctx, req.Username, req.Email); err != nil {
// 验证短信验证码
if err := s.smsCodeService.VerifyCode(ctx, registerReq.Phone, registerReq.Code, entities.SMSSceneRegister); err != nil {
return nil, fmt.Errorf("验证码验证失败: %w", err)
}
// 检查手机号是否已存在
if err := s.checkPhoneDuplicate(ctx, registerReq.Phone); err != nil {
return nil, err
}
// 创建用户实体
user := req.ToEntity()
user := registerReq.ToEntity()
user.ID = uuid.New().String()
// 加密密码
hashedPassword, err := s.hashPassword(req.Password)
// 哈希密码
hashedPassword, err := s.hashPassword(registerReq.Password)
if err != nil {
return nil, fmt.Errorf("failed to hash password: %w", err)
return nil, fmt.Errorf("密码加密失败: %w", err)
}
user.Password = hashedPassword
// 保存用户
if err := s.repo.Create(ctx, user); err != nil {
s.logger.Error("Failed to create user", zap.Error(err))
return nil, fmt.Errorf("failed to create user: %w", err)
s.logger.Error("创建用户失败", zap.Error(err))
return nil, fmt.Errorf("创建用户失败: %w", err)
}
// 发布用户创建事件
event := events.NewUserCreatedEvent(user, s.getCorrelationID(ctx))
// 发布用户注册事件
event := events.NewUserRegisteredEvent(user, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user created event", zap.Error(err))
s.logger.Warn("发布用户注册事件失败", zap.Error(err))
}
s.logger.Info("User created successfully",
s.logger.Info("用户注册成功",
zap.String("user_id", user.ID),
zap.String("username", user.Username))
zap.String("phone", user.Phone))
return user, nil
}
// GetByID 根据ID获取用户
func (s *UserService) GetByID(ctx context.Context, id string) (*entities.User, error) {
if id == "" {
return nil, fmt.Errorf("user ID is required")
}
user, err := s.repo.GetByID(ctx, id)
// LoginWithPassword 密码登录
func (s *UserService) LoginWithPassword(ctx context.Context, loginReq *dto.LoginWithPasswordRequest) (*entities.User, error) {
// 根据手机号查找用户
user, err := s.repo.FindByPhone(ctx, loginReq.Phone)
if err != nil {
return nil, fmt.Errorf("user not found: %w", err)
}
return user, nil
}
// Update 更新用户
func (s *UserService) Update(ctx context.Context, id string, updateDTO interface{}) (*entities.User, error) {
req, ok := updateDTO.(*dto.UpdateUserRequest)
if !ok {
return nil, fmt.Errorf("invalid DTO type for user update")
}
// 验证业务规则
if err := s.ValidateUpdate(ctx, id, req); err != nil {
return nil, err
}
// 获取现有用户
user, err := s.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("user not found: %w", err)
}
// 记录变更前的值
oldValues := s.captureUserValues(user)
// 应用更新
s.applyUserUpdates(user, req)
// 保存更新
if err := s.repo.Update(ctx, user); err != nil {
s.logger.Error("Failed to update user", zap.Error(err))
return nil, fmt.Errorf("failed to update user: %w", err)
}
// 发布用户更新事件
newValues := s.captureUserValues(user)
changes := s.findChanges(oldValues, newValues)
if len(changes) > 0 {
event := events.NewUserUpdatedEvent(user.ID, changes, oldValues, newValues, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user updated event", zap.Error(err))
}
}
s.logger.Info("User updated successfully",
zap.String("user_id", user.ID),
zap.Int("changes", len(changes)))
return user, nil
}
// Delete 删除用户
func (s *UserService) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("user ID is required")
}
// 获取用户信息用于事件
user, err := s.repo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("user not found: %w", err)
}
// 软删除用户
if err := s.repo.SoftDelete(ctx, id); err != nil {
s.logger.Error("Failed to delete user", zap.Error(err))
return fmt.Errorf("failed to delete user: %w", err)
}
// 发布用户删除事件
event := events.NewUserDeletedEvent(user.ID, user.Username, user.Email, true, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user deleted event", zap.Error(err))
}
s.logger.Info("User deleted successfully", zap.String("user_id", id))
return nil
}
// List 获取用户列表
func (s *UserService) List(ctx context.Context, options interfaces.ListOptions) ([]*entities.User, error) {
return s.repo.List(ctx, options)
}
// Search 搜索用户
func (s *UserService) Search(ctx context.Context, query string, options interfaces.ListOptions) ([]*entities.User, error) {
// 设置搜索关键字
searchOptions := options
searchOptions.Search = query
return s.repo.List(ctx, searchOptions)
}
// Count 统计用户数量
func (s *UserService) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
return s.repo.Count(ctx, options)
}
// Validate 验证用户实体
func (s *UserService) Validate(ctx context.Context, entity *entities.User) error {
return entity.Validate()
}
// ValidateCreate 验证创建请求
func (s *UserService) ValidateCreate(ctx context.Context, createDTO interface{}) error {
req, ok := createDTO.(*dto.CreateUserRequest)
if !ok {
return fmt.Errorf("invalid DTO type")
}
// 基础验证已经由binding标签处理这里添加业务规则验证
if req.Username == "admin" || req.Username == "root" {
return fmt.Errorf("username '%s' is reserved", req.Username)
}
return nil
}
// ValidateUpdate 验证更新请求
func (s *UserService) ValidateUpdate(ctx context.Context, id string, updateDTO interface{}) error {
_, ok := updateDTO.(*dto.UpdateUserRequest)
if !ok {
return fmt.Errorf("invalid DTO type")
}
if id == "" {
return fmt.Errorf("user ID is required")
}
return nil
}
// 业务方法
// Login 用户登录
func (s *UserService) Login(ctx context.Context, loginReq *dto.LoginRequest) (*entities.User, error) {
// 根据用户名或邮箱查找用户
var user *entities.User
var err error
if s.isEmail(loginReq.Login) {
user, err = s.repo.FindByEmail(ctx, loginReq.Login)
} else {
user, err = s.repo.FindByUsername(ctx, loginReq.Login)
}
if err != nil {
return nil, fmt.Errorf("invalid credentials")
return nil, fmt.Errorf("用户名或密码错误")
}
// 验证密码
if !s.checkPassword(loginReq.Password, user.Password) {
return nil, fmt.Errorf("invalid credentials")
return nil, fmt.Errorf("用户名或密码错误")
}
// 检查用户状态
if !user.CanLogin() {
return nil, fmt.Errorf("account is disabled or suspended")
}
// 更新最后登录时间
user.UpdateLastLogin()
if err := s.repo.Update(ctx, user); err != nil {
s.logger.Warn("Failed to update last login time", zap.Error(err))
}
// 发布登录事件
// 发布用户登录事件
event := events.NewUserLoggedInEvent(
user.ID, user.Username,
user.ID, user.Phone,
s.getClientIP(ctx), s.getUserAgent(ctx),
s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish user logged in event", zap.Error(err))
s.logger.Warn("发布用户登录事件失败", zap.Error(err))
}
s.logger.Info("User logged in successfully",
s.logger.Info("用户密码登录成功",
zap.String("user_id", user.ID),
zap.String("username", user.Username))
zap.String("phone", user.Phone))
return user, nil
}
// LoginWithSMS 短信验证码登录
func (s *UserService) LoginWithSMS(ctx context.Context, loginReq *dto.LoginWithSMSRequest) (*entities.User, error) {
// 验证短信验证码
if err := s.smsCodeService.VerifyCode(ctx, loginReq.Phone, loginReq.Code, entities.SMSSceneLogin); err != nil {
return nil, fmt.Errorf("验证码验证失败: %w", err)
}
// 根据手机号查找用户
user, err := s.repo.FindByPhone(ctx, loginReq.Phone)
if err != nil {
return nil, fmt.Errorf("用户不存在")
}
// 发布用户登录事件
event := events.NewUserLoggedInEvent(
user.ID, user.Phone,
s.getClientIP(ctx), s.getUserAgent(ctx),
s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("发布用户登录事件失败", zap.Error(err))
}
s.logger.Info("用户短信登录成功",
zap.String("user_id", user.ID),
zap.String("phone", user.Phone))
return user, nil
}
// ChangePassword 修改密码
func (s *UserService) ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error {
// 获取用户
// 验证新密码确认
if req.NewPassword != req.ConfirmNewPassword {
return fmt.Errorf("新密码和确认新密码不匹配")
}
// 获取用户信息
user, err := s.repo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("user not found: %w", err)
return fmt.Errorf("用户不存在: %w", err)
}
// 验证旧密
// 验证短信验证
if err := s.smsCodeService.VerifyCode(ctx, user.Phone, req.Code, entities.SMSSceneChangePassword); err != nil {
return fmt.Errorf("验证码验证失败: %w", err)
}
// 验证当前密码
if !s.checkPassword(req.OldPassword, user.Password) {
return fmt.Errorf("current password is incorrect")
return fmt.Errorf("当前密码错误")
}
// 加密新密码
// 哈希新密码
hashedPassword, err := s.hashPassword(req.NewPassword)
if err != nil {
return fmt.Errorf("failed to hash new password: %w", err)
return fmt.Errorf("新密码加密失败: %w", err)
}
// 更新密码
user.Password = hashedPassword
if err := s.repo.Update(ctx, user); err != nil {
return fmt.Errorf("failed to update password: %w", err)
return fmt.Errorf("密码更新失败: %w", err)
}
// 发布密码修改事件
event := events.NewUserPasswordChangedEvent(user.ID, user.Username, s.getCorrelationID(ctx))
event := events.NewUserPasswordChangedEvent(user.ID, user.Phone, s.getCorrelationID(ctx))
if err := s.eventBus.Publish(ctx, event); err != nil {
s.logger.Warn("Failed to publish password changed event", zap.Error(err))
s.logger.Warn("发布密码修改事件失败", zap.Error(err))
}
s.logger.Info("Password changed successfully", zap.String("user_id", userID))
s.logger.Info("密码修改成功", zap.String("user_id", userID))
return nil
}
// GetStats 获取用户统计
func (s *UserService) GetStats(ctx context.Context) (*dto.UserStatsResponse, error) {
total, err := s.repo.Count(ctx, interfaces.CountOptions{})
if err != nil {
return nil, err
// GetByID 根据ID获取用户
func (s *UserService) GetByID(ctx context.Context, id string) (*entities.User, error) {
if id == "" {
return nil, fmt.Errorf("用户ID不能为空")
}
// 这里可以并行查询不同状态的用户数量
// 简化实现,返回基础统计
return &dto.UserStatsResponse{
TotalUsers: total,
ActiveUsers: total, // 简化
InactiveUsers: 0,
SuspendedUsers: 0,
NewUsersToday: 0,
NewUsersWeek: 0,
NewUsersMonth: 0,
}, nil
user, err := s.repo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("用户不存在: %w", err)
}
return user, nil
}
// 私有方法
// 工具方法
// checkDuplicates 检查重复的用户名和邮箱
func (s *UserService) checkDuplicates(ctx context.Context, username, email string) error {
// 检查用户名
if existingUser, err := s.repo.FindByUsername(ctx, username); err == nil && existingUser != nil {
return fmt.Errorf("username already exists")
// checkPhoneDuplicate 检查手机号重复
func (s *UserService) checkPhoneDuplicate(ctx context.Context, phone string) error {
if _, err := s.repo.FindByPhone(ctx, phone); err == nil {
return fmt.Errorf("手机号已存在")
}
// 检查邮箱
if existingUser, err := s.repo.FindByEmail(ctx, email); err == nil && existingUser != nil {
return fmt.Errorf("email already exists")
}
return nil
}
// hashPassword 加密密码
func (s *UserService) hashPassword(password string) (string, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(hash), nil
return string(hashedBytes), nil
}
// checkPassword 验证密码
@@ -386,63 +257,24 @@ func (s *UserService) checkPassword(password, hash string) bool {
return err == nil
}
// isEmail 检查是否为邮箱格式
func (s *UserService) isEmail(str string) bool {
return len(str) > 0 && len(str) < 255 &&
len(str) > 5 &&
str[len(str)-4:] != ".." &&
(len(str) > 6 && str[len(str)-4:] == ".com") ||
(len(str) > 5 && str[len(str)-3:] == ".cn") ||
(len(str) > 6 && str[len(str)-4:] == ".org") ||
(len(str) > 6 && str[len(str)-4:] == ".net")
// 简化的邮箱检查,实际应该使用正则表达式
// isValidPhone 验证手机号格式
func (s *UserService) isValidPhone(phone string) bool {
// 简单的中国手机号验证11位数字以1开头
pattern := `^1[3-9]\d{9}$`
matched, _ := regexp.MatchString(pattern, phone)
return matched
}
// applyUserUpdates 应用用户更新
func (s *UserService) applyUserUpdates(user *entities.User, req *dto.UpdateUserRequest) {
if req.FirstName != nil {
user.FirstName = *req.FirstName
}
if req.LastName != nil {
user.LastName = *req.LastName
}
if req.Phone != nil {
user.Phone = *req.Phone
}
if req.Avatar != nil {
user.Avatar = *req.Avatar
}
user.UpdatedAt = time.Now()
}
// captureUserValues 捕获用户值用于变更比较
func (s *UserService) captureUserValues(user *entities.User) map[string]interface{} {
return map[string]interface{}{
"first_name": user.FirstName,
"last_name": user.LastName,
"phone": user.Phone,
"avatar": user.Avatar,
}
}
// findChanges 找出变更的字段
func (s *UserService) findChanges(oldValues, newValues map[string]interface{}) map[string]interface{} {
changes := make(map[string]interface{})
for key, newValue := range newValues {
if oldValue, exists := oldValues[key]; !exists || oldValue != newValue {
changes[key] = newValue
}
}
return changes
// generateUserID 生成用户ID
func (s *UserService) generateUserID() string {
return uuid.New().String()
}
// getCorrelationID 获取关联ID
func (s *UserService) getCorrelationID(ctx context.Context) string {
if id := ctx.Value("correlation_id"); id != nil {
if correlationID, ok := id.(string); ok {
return correlationID
if strID, ok := id.(string); ok {
return strID
}
}
return uuid.New().String()
@@ -451,19 +283,19 @@ func (s *UserService) getCorrelationID(ctx context.Context) string {
// getClientIP 获取客户端IP
func (s *UserService) getClientIP(ctx context.Context) string {
if ip := ctx.Value("client_ip"); ip != nil {
if clientIP, ok := ip.(string); ok {
return clientIP
if strIP, ok := ip.(string); ok {
return strIP
}
}
return "unknown"
return ""
}
// getUserAgent 获取用户代理
func (s *UserService) getUserAgent(ctx context.Context) string {
if ua := ctx.Value("user_agent"); ua != nil {
if userAgent, ok := ua.(string); ok {
return userAgent
if strUA, ok := ua.(string); ok {
return strUA
}
}
return "unknown"
return ""
}

View File

@@ -36,7 +36,7 @@ func (h *HealthChecker) RegisterService(service interfaces.Service) {
defer h.mutex.Unlock()
h.services[service.Name()] = service
h.logger.Info("Registered service for health check", zap.String("service", service.Name()))
h.logger.Info("服务已注册健康检查", zap.String("service", service.Name()))
}
// CheckHealth 检查单个服务健康状态
@@ -47,8 +47,8 @@ func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *in
h.mutex.RUnlock()
return &interfaces.HealthStatus{
Status: "DOWN",
Message: "Service not found",
Details: map[string]interface{}{"error": "service not registered"},
Message: "服务未找到",
Details: map[string]interface{}{"error": "服务未注册"},
CheckedAt: time.Now().Unix(),
ResponseTime: 0,
}
@@ -79,24 +79,24 @@ func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *in
if err != nil {
status.Status = "DOWN"
status.Message = "Health check failed"
status.Message = "健康检查失败"
status.Details = map[string]interface{}{
"error": err.Error(),
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Warn("Service health check failed",
h.logger.Warn("服务健康检查失败",
zap.String("service", serviceName),
zap.Error(err),
zap.Int64("response_time_ms", responseTime))
} else {
status.Status = "UP"
status.Message = "Service is healthy"
status.Message = "服务运行正常"
status.Details = map[string]interface{}{
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Debug("Service health check passed",
h.logger.Debug("服务健康检查通过",
zap.String("service", serviceName),
zap.Int64("response_time_ms", responseTime))
}
@@ -173,13 +173,13 @@ func (h *HealthChecker) GetOverallStatus(ctx context.Context) *interfaces.Health
// 确定整体状态
if healthyCount == totalCount {
overall.Status = "UP"
overall.Message = "All services are healthy"
overall.Message = "所有服务运行正常"
} else if healthyCount == 0 {
overall.Status = "DOWN"
overall.Message = "All services are down"
overall.Message = "所有服务均已下线"
} else {
overall.Status = "DEGRADED"
overall.Message = fmt.Sprintf("%d of %d services are healthy", healthyCount, totalCount)
overall.Message = fmt.Sprintf("%d/%d 个服务运行正常", healthyCount, totalCount)
}
return overall
@@ -205,7 +205,7 @@ func (h *HealthChecker) RemoveService(serviceName string) {
delete(h.services, serviceName)
delete(h.cache, serviceName)
h.logger.Info("Removed service from health check", zap.String("service", serviceName))
h.logger.Info("服务已从健康检查中移除", zap.String("service", serviceName))
}
// ClearCache 清除缓存
@@ -214,7 +214,7 @@ func (h *HealthChecker) ClearCache() {
defer h.mutex.Unlock()
h.cache = make(map[string]*interfaces.HealthStatus)
h.logger.Debug("Health check cache cleared")
h.logger.Debug("健康检查缓存已清除")
}
// GetCacheStats 获取缓存统计
@@ -243,7 +243,7 @@ func (h *HealthChecker) SetCacheTTL(ttl time.Duration) {
defer h.mutex.Unlock()
h.cacheTTL = ttl
h.logger.Info("Updated health check cache TTL", zap.Duration("ttl", ttl))
h.logger.Info("健康检查缓存TTL已更新", zap.Duration("ttl", ttl))
}
// StartPeriodicCheck 启动定期健康检查
@@ -251,12 +251,12 @@ func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Du
ticker := time.NewTicker(interval)
defer ticker.Stop()
h.logger.Info("Started periodic health check", zap.Duration("interval", interval))
h.logger.Info("已启动定期健康检查", zap.Duration("interval", interval))
for {
select {
case <-ctx.Done():
h.logger.Info("Stopped periodic health check")
h.logger.Info("已停止定期健康检查")
return
case <-ticker.C:
h.performPeriodicCheck(ctx)
@@ -268,14 +268,14 @@ func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Du
func (h *HealthChecker) performPeriodicCheck(ctx context.Context) {
overall := h.GetOverallStatus(ctx)
h.logger.Info("Periodic health check completed",
h.logger.Info("定期健康检查已完成",
zap.String("overall_status", overall.Status),
zap.String("message", overall.Message),
zap.Int64("response_time_ms", overall.ResponseTime))
// 如果有服务下线,记录警告
if overall.Status != "UP" {
h.logger.Warn("Some services are not healthy",
h.logger.Warn("部分服务不健康",
zap.String("status", overall.Status),
zap.Any("details", overall.Details))
}

View File

@@ -0,0 +1,587 @@
package hooks
import (
"context"
"fmt"
"reflect"
"sort"
"sync"
"time"
"go.uber.org/zap"
)
// HookPriority 钩子优先级
type HookPriority int
const (
// PriorityLowest 最低优先级
PriorityLowest HookPriority = 0
// PriorityLow 低优先级
PriorityLow HookPriority = 25
// PriorityNormal 普通优先级
PriorityNormal HookPriority = 50
// PriorityHigh 高优先级
PriorityHigh HookPriority = 75
// PriorityHighest 最高优先级
PriorityHighest HookPriority = 100
)
// HookFunc 钩子函数类型
type HookFunc func(ctx context.Context, data interface{}) error
// Hook 钩子定义
type Hook struct {
Name string
Func HookFunc
Priority HookPriority
Async bool
Timeout time.Duration
}
// HookResult 钩子执行结果
type HookResult struct {
HookName string `json:"hook_name"`
Success bool `json:"success"`
Duration time.Duration `json:"duration"`
Error string `json:"error,omitempty"`
}
// HookConfig 钩子配置
type HookConfig struct {
// 默认超时时间
DefaultTimeout time.Duration
// 是否记录执行时间
TrackDuration bool
// 错误处理策略
ErrorStrategy ErrorStrategy
}
// ErrorStrategy 错误处理策略
type ErrorStrategy int
const (
// ContinueOnError 遇到错误继续执行
ContinueOnError ErrorStrategy = iota
// StopOnError 遇到错误停止执行
StopOnError
// CollectErrors 收集所有错误
CollectErrors
)
// DefaultHookConfig 默认钩子配置
func DefaultHookConfig() HookConfig {
return HookConfig{
DefaultTimeout: 30 * time.Second,
TrackDuration: true,
ErrorStrategy: ContinueOnError,
}
}
// HookSystem 钩子系统
type HookSystem struct {
hooks map[string][]*Hook
config HookConfig
logger *zap.Logger
mutex sync.RWMutex
stats map[string]*HookStats
}
// HookStats 钩子统计
type HookStats struct {
TotalExecutions int `json:"total_executions"`
Successes int `json:"successes"`
Failures int `json:"failures"`
TotalDuration time.Duration `json:"total_duration"`
AverageDuration time.Duration `json:"average_duration"`
LastExecution time.Time `json:"last_execution"`
LastError string `json:"last_error,omitempty"`
}
// NewHookSystem 创建钩子系统
func NewHookSystem(config HookConfig, logger *zap.Logger) *HookSystem {
return &HookSystem{
hooks: make(map[string][]*Hook),
config: config,
logger: logger,
stats: make(map[string]*HookStats),
}
}
// Register 注册钩子
func (hs *HookSystem) Register(event string, hook *Hook) error {
if hook.Name == "" {
return fmt.Errorf("hook name cannot be empty")
}
if hook.Func == nil {
return fmt.Errorf("hook function cannot be nil")
}
if hook.Timeout == 0 {
hook.Timeout = hs.config.DefaultTimeout
}
hs.mutex.Lock()
defer hs.mutex.Unlock()
// 检查是否已经注册了同名钩子
for _, existingHook := range hs.hooks[event] {
if existingHook.Name == hook.Name {
return fmt.Errorf("hook %s already registered for event %s", hook.Name, event)
}
}
hs.hooks[event] = append(hs.hooks[event], hook)
// 按优先级排序
sort.Slice(hs.hooks[event], func(i, j int) bool {
return hs.hooks[event][i].Priority > hs.hooks[event][j].Priority
})
// 初始化统计
hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
hs.stats[hookKey] = &HookStats{}
hs.logger.Info("Registered hook",
zap.String("event", event),
zap.String("hook_name", hook.Name),
zap.Int("priority", int(hook.Priority)),
zap.Bool("async", hook.Async))
return nil
}
// RegisterFunc 注册钩子函数(简化版)
func (hs *HookSystem) RegisterFunc(event, name string, priority HookPriority, fn HookFunc) error {
hook := &Hook{
Name: name,
Func: fn,
Priority: priority,
Async: false,
Timeout: hs.config.DefaultTimeout,
}
return hs.Register(event, hook)
}
// RegisterAsyncFunc 注册异步钩子函数
func (hs *HookSystem) RegisterAsyncFunc(event, name string, priority HookPriority, fn HookFunc) error {
hook := &Hook{
Name: name,
Func: fn,
Priority: priority,
Async: true,
Timeout: hs.config.DefaultTimeout,
}
return hs.Register(event, hook)
}
// Unregister 取消注册钩子
func (hs *HookSystem) Unregister(event, hookName string) error {
hs.mutex.Lock()
defer hs.mutex.Unlock()
hooks := hs.hooks[event]
for i, hook := range hooks {
if hook.Name == hookName {
// 删除钩子
hs.hooks[event] = append(hooks[:i], hooks[i+1:]...)
// 删除统计
hookKey := fmt.Sprintf("%s.%s", event, hookName)
delete(hs.stats, hookKey)
hs.logger.Info("Unregistered hook",
zap.String("event", event),
zap.String("hook_name", hookName))
return nil
}
}
return fmt.Errorf("hook %s not found for event %s", hookName, event)
}
// Trigger 触发事件
func (hs *HookSystem) Trigger(ctx context.Context, event string, data interface{}) ([]HookResult, error) {
hs.mutex.RLock()
hooks := make([]*Hook, len(hs.hooks[event]))
copy(hooks, hs.hooks[event])
hs.mutex.RUnlock()
if len(hooks) == 0 {
hs.logger.Debug("No hooks registered for event", zap.String("event", event))
return nil, nil
}
hs.logger.Debug("Triggering event",
zap.String("event", event),
zap.Int("hook_count", len(hooks)))
results := make([]HookResult, 0, len(hooks))
var errors []error
for _, hook := range hooks {
result := hs.executeHook(ctx, event, hook, data)
results = append(results, result)
if !result.Success {
err := fmt.Errorf("hook %s failed: %s", hook.Name, result.Error)
errors = append(errors, err)
// 根据错误策略决定是否继续
if hs.config.ErrorStrategy == StopOnError {
break
}
}
}
// 处理错误
if len(errors) > 0 {
switch hs.config.ErrorStrategy {
case StopOnError:
return results, errors[0]
case CollectErrors:
return results, fmt.Errorf("multiple hook errors: %v", errors)
case ContinueOnError:
// 继续执行,但记录错误
hs.logger.Warn("Some hooks failed but continuing execution",
zap.String("event", event),
zap.Int("error_count", len(errors)))
}
}
return results, nil
}
// executeHook 执行单个钩子
func (hs *HookSystem) executeHook(ctx context.Context, event string, hook *Hook, data interface{}) HookResult {
hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
start := time.Now()
result := HookResult{
HookName: hook.Name,
Success: false,
}
// 更新统计
defer func() {
result.Duration = time.Since(start)
hs.updateStats(hookKey, result)
}()
if hook.Async {
// 异步执行
go func() {
hs.doExecuteHook(ctx, hook, data)
}()
result.Success = true // 异步执行总是认为成功
return result
}
// 同步执行
err := hs.doExecuteHook(ctx, hook, data)
if err != nil {
result.Error = err.Error()
hs.logger.Error("Hook execution failed",
zap.String("event", event),
zap.String("hook_name", hook.Name),
zap.Error(err))
} else {
result.Success = true
hs.logger.Debug("Hook executed successfully",
zap.String("event", event),
zap.String("hook_name", hook.Name))
}
return result
}
// doExecuteHook 实际执行钩子
func (hs *HookSystem) doExecuteHook(ctx context.Context, hook *Hook, data interface{}) error {
// 设置超时上下文
hookCtx, cancel := context.WithTimeout(ctx, hook.Timeout)
defer cancel()
// 在goroutine中执行以便处理超时
errChan := make(chan error, 1)
go func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("hook panicked: %v", r)
}
}()
errChan <- hook.Func(hookCtx, data)
}()
select {
case err := <-errChan:
return err
case <-hookCtx.Done():
return fmt.Errorf("hook execution timeout after %v", hook.Timeout)
}
}
// updateStats 更新统计信息
func (hs *HookSystem) updateStats(hookKey string, result HookResult) {
hs.mutex.Lock()
defer hs.mutex.Unlock()
stats, exists := hs.stats[hookKey]
if !exists {
stats = &HookStats{}
hs.stats[hookKey] = stats
}
stats.TotalExecutions++
stats.LastExecution = time.Now()
if result.Success {
stats.Successes++
} else {
stats.Failures++
stats.LastError = result.Error
}
if hs.config.TrackDuration {
stats.TotalDuration += result.Duration
stats.AverageDuration = stats.TotalDuration / time.Duration(stats.TotalExecutions)
}
}
// GetHooks 获取事件的所有钩子
func (hs *HookSystem) GetHooks(event string) []*Hook {
hs.mutex.RLock()
defer hs.mutex.RUnlock()
hooks := make([]*Hook, len(hs.hooks[event]))
copy(hooks, hs.hooks[event])
return hooks
}
// GetEvents 获取所有注册的事件
func (hs *HookSystem) GetEvents() []string {
hs.mutex.RLock()
defer hs.mutex.RUnlock()
events := make([]string, 0, len(hs.hooks))
for event := range hs.hooks {
events = append(events, event)
}
sort.Strings(events)
return events
}
// GetStats 获取钩子统计信息
func (hs *HookSystem) GetStats() map[string]*HookStats {
hs.mutex.RLock()
defer hs.mutex.RUnlock()
stats := make(map[string]*HookStats)
for key, stat := range hs.stats {
statCopy := *stat
stats[key] = &statCopy
}
return stats
}
// GetEventStats 获取特定事件的统计信息
func (hs *HookSystem) GetEventStats(event string) map[string]*HookStats {
allStats := hs.GetStats()
eventStats := make(map[string]*HookStats)
prefix := event + "."
for key, stat := range allStats {
if len(key) > len(prefix) && key[:len(prefix)] == prefix {
hookName := key[len(prefix):]
eventStats[hookName] = stat
}
}
return eventStats
}
// Clear 清除所有钩子
func (hs *HookSystem) Clear() {
hs.mutex.Lock()
defer hs.mutex.Unlock()
hs.hooks = make(map[string][]*Hook)
hs.stats = make(map[string]*HookStats)
hs.logger.Info("Cleared all hooks")
}
// ClearEvent 清除特定事件的所有钩子
func (hs *HookSystem) ClearEvent(event string) {
hs.mutex.Lock()
defer hs.mutex.Unlock()
// 删除钩子
delete(hs.hooks, event)
// 删除统计
prefix := event + "."
for key := range hs.stats {
if len(key) > len(prefix) && key[:len(prefix)] == prefix {
delete(hs.stats, key)
}
}
hs.logger.Info("Cleared hooks for event", zap.String("event", event))
}
// Count 获取钩子总数
func (hs *HookSystem) Count() int {
hs.mutex.RLock()
defer hs.mutex.RUnlock()
total := 0
for _, hooks := range hs.hooks {
total += len(hooks)
}
return total
}
// EventCount 获取特定事件的钩子数量
func (hs *HookSystem) EventCount(event string) int {
hs.mutex.RLock()
defer hs.mutex.RUnlock()
return len(hs.hooks[event])
}
// 实现Service接口
// Name 返回服务名称
func (hs *HookSystem) Name() string {
return "hook-system"
}
// Initialize 初始化钩子系统
func (hs *HookSystem) Initialize(ctx context.Context) error {
hs.logger.Info("Hook system initialized")
return nil
}
// Start 启动钩子系统
func (hs *HookSystem) Start(ctx context.Context) error {
hs.logger.Info("Hook system started")
return nil
}
// HealthCheck 健康检查
func (hs *HookSystem) HealthCheck(ctx context.Context) error {
return nil
}
// Shutdown 关闭钩子系统
func (hs *HookSystem) Shutdown(ctx context.Context) error {
hs.logger.Info("Hook system shutdown")
return nil
}
// 便捷方法
// OnUserCreated 用户创建事件钩子
func (hs *HookSystem) OnUserCreated(name string, priority HookPriority, fn HookFunc) error {
return hs.RegisterFunc("user.created", name, priority, fn)
}
// OnUserUpdated 用户更新事件钩子
func (hs *HookSystem) OnUserUpdated(name string, priority HookPriority, fn HookFunc) error {
return hs.RegisterFunc("user.updated", name, priority, fn)
}
// OnUserDeleted 用户删除事件钩子
func (hs *HookSystem) OnUserDeleted(name string, priority HookPriority, fn HookFunc) error {
return hs.RegisterFunc("user.deleted", name, priority, fn)
}
// OnOrderCreated 订单创建事件钩子
func (hs *HookSystem) OnOrderCreated(name string, priority HookPriority, fn HookFunc) error {
return hs.RegisterFunc("order.created", name, priority, fn)
}
// OnOrderCompleted 订单完成事件钩子
func (hs *HookSystem) OnOrderCompleted(name string, priority HookPriority, fn HookFunc) error {
return hs.RegisterFunc("order.completed", name, priority, fn)
}
// TriggerUserCreated 触发用户创建事件
func (hs *HookSystem) TriggerUserCreated(ctx context.Context, user interface{}) ([]HookResult, error) {
return hs.Trigger(ctx, "user.created", user)
}
// TriggerUserUpdated 触发用户更新事件
func (hs *HookSystem) TriggerUserUpdated(ctx context.Context, user interface{}) ([]HookResult, error) {
return hs.Trigger(ctx, "user.updated", user)
}
// TriggerUserDeleted 触发用户删除事件
func (hs *HookSystem) TriggerUserDeleted(ctx context.Context, user interface{}) ([]HookResult, error) {
return hs.Trigger(ctx, "user.deleted", user)
}
// HookBuilder 钩子构建器
type HookBuilder struct {
hook *Hook
}
// NewHookBuilder 创建钩子构建器
func NewHookBuilder(name string, fn HookFunc) *HookBuilder {
return &HookBuilder{
hook: &Hook{
Name: name,
Func: fn,
Priority: PriorityNormal,
Async: false,
Timeout: 30 * time.Second,
},
}
}
// WithPriority 设置优先级
func (hb *HookBuilder) WithPriority(priority HookPriority) *HookBuilder {
hb.hook.Priority = priority
return hb
}
// WithTimeout 设置超时时间
func (hb *HookBuilder) WithTimeout(timeout time.Duration) *HookBuilder {
hb.hook.Timeout = timeout
return hb
}
// Async 设置为异步执行
func (hb *HookBuilder) Async() *HookBuilder {
hb.hook.Async = true
return hb
}
// Build 构建钩子
func (hb *HookBuilder) Build() *Hook {
return hb.hook
}
// TypedHookFunc 类型化钩子函数
type TypedHookFunc[T any] func(ctx context.Context, data T) error
// RegisterTypedFunc 注册类型化钩子函数
func RegisterTypedFunc[T any](hs *HookSystem, event, name string, priority HookPriority, fn TypedHookFunc[T]) error {
hookFunc := func(ctx context.Context, data interface{}) error {
typedData, ok := data.(T)
if !ok {
return fmt.Errorf("invalid data type for hook %s, expected %s", name, reflect.TypeOf((*T)(nil)).Elem().Name())
}
return fn(ctx, typedData)
}
return hs.RegisterFunc(event, name, priority, hookFunc)
}

View File

@@ -20,7 +20,7 @@ func NewResponseBuilder() interfaces.ResponseBuilder {
// Success 成功响应
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
msg := "Success"
msg := "操作成功"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -38,7 +38,7 @@ func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...s
// Created 创建成功响应
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
msg := "Created successfully"
msg := "创建成功"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -58,7 +58,7 @@ func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...s
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
// 根据错误类型确定状态码
statusCode := http.StatusInternalServerError
message := "Internal server error"
message := "服务器内部错误"
errorDetail := err.Error()
// 这里可以根据不同的错误类型设置不同的状态码
@@ -93,7 +93,7 @@ func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...i
// Unauthorized 401错误响应
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
msg := "Unauthorized"
msg := "用户未登录或认证已过期"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -110,7 +110,7 @@ func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
// Forbidden 403错误响应
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
msg := "Forbidden"
msg := "权限不足,无法访问此资源"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -127,7 +127,7 @@ func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
// NotFound 404错误响应
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
msg := "Resource not found"
msg := "请求的资源不存在"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -156,7 +156,7 @@ func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
// InternalError 500错误响应
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
msg := "Internal server error"
msg := "服务器内部错误"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -175,7 +175,7 @@ func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
response := interfaces.APIResponse{
Success: true,
Message: "Success",
Message: "查询成功",
Data: data,
Pagination: &pagination,
RequestID: r.getRequestID(c),
@@ -215,9 +215,35 @@ func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta
// CustomResponse 自定义响应
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
var message string
switch statusCode {
case http.StatusOK:
message = "请求成功"
case http.StatusCreated:
message = "创建成功"
case http.StatusNoContent:
message = "无内容"
case http.StatusBadRequest:
message = "请求参数错误"
case http.StatusUnauthorized:
message = "认证失败"
case http.StatusForbidden:
message = "权限不足"
case http.StatusNotFound:
message = "资源不存在"
case http.StatusConflict:
message = "资源冲突"
case http.StatusTooManyRequests:
message = "请求过于频繁"
case http.StatusInternalServerError:
message = "服务器内部错误"
default:
message = "未知状态"
}
response := interfaces.APIResponse{
Success: statusCode >= 200 && statusCode < 300,
Message: http.StatusText(statusCode),
Message: message,
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
@@ -230,7 +256,7 @@ func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data in
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
response := interfaces.APIResponse{
Success: false,
Message: "Validation failed",
Message: "请求参数验证失败",
Errors: errors,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
@@ -241,7 +267,7 @@ func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
// TooManyRequests 限流错误响应
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
msg := "Too many requests"
msg := "请求过于频繁,请稍后再试"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}

View File

@@ -8,6 +8,8 @@ import (
"time"
"github.com/gin-gonic/gin"
swaggerFiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger"
"go.uber.org/zap"
"tyapi-server/internal/config"
@@ -51,7 +53,7 @@ func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
// 注册路由
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
r.logger.Info("Registered HTTP handler",
r.logger.Info("已注册HTTP处理器",
zap.String("method", handler.GetMethod()),
zap.String("path", handler.GetPath()))
@@ -62,7 +64,7 @@ func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
r.middlewares = append(r.middlewares, middleware)
r.logger.Info("Registered middleware",
r.logger.Info("已注册中间件",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
@@ -93,7 +95,7 @@ func (r *GinRouter) Start(addr string) error {
IdleTimeout: r.config.Server.IdleTimeout,
}
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
r.logger.Info("正在启动HTTP服务器", zap.String("addr", addr))
// 启动服务器
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
@@ -109,15 +111,15 @@ func (r *GinRouter) Stop(ctx context.Context) error {
return nil
}
r.logger.Info("Stopping HTTP server...")
r.logger.Info("正在关闭HTTP服务器...")
// 优雅关闭服务器
if err := r.server.Shutdown(ctx); err != nil {
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
r.logger.Error("优雅关闭服务器失败", zap.Error(err))
return err
}
r.logger.Info("HTTP server stopped")
r.logger.Info("HTTP服务器已关闭")
return nil
}
@@ -137,7 +139,7 @@ func (r *GinRouter) applyMiddlewares() {
for _, middleware := range r.middlewares {
if middleware.IsGlobal() {
r.engine.Use(middleware.Handle())
r.logger.Debug("Applied global middleware",
r.logger.Debug("已应用全局中间件",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
}
@@ -156,6 +158,18 @@ func (r *GinRouter) SetupDefaultRoutes() {
})
})
// 详细健康检查
r.engine.GET("/health/detailed", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"timestamp": time.Now().Unix(),
"service": r.config.App.Name,
"version": r.config.App.Version,
"uptime": time.Now().Unix(),
"environment": r.config.App.Env,
})
})
// API信息
r.engine.GET("/info", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
@@ -166,11 +180,37 @@ func (r *GinRouter) SetupDefaultRoutes() {
})
})
// Swagger文档路由 (仅在开发环境启用)
if !r.config.App.IsProduction() {
// Swagger UI
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
// API文档重定向
r.engine.GET("/docs", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/swagger/index.html")
})
// API文档信息
r.engine.GET("/api/docs", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"swagger_ui": fmt.Sprintf("http://%s/swagger/index.html", c.Request.Host),
"openapi_json": fmt.Sprintf("http://%s/swagger/doc.json", c.Request.Host),
"redoc": fmt.Sprintf("http://%s/redoc", c.Request.Host),
"message": "API文档已可用",
})
})
r.logger.Info("Swagger documentation enabled",
zap.String("swagger_url", "/swagger/index.html"),
zap.String("docs_url", "/docs"),
zap.String("api_docs_url", "/api/docs"))
}
// 404处理
r.engine.NoRoute(func(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Route not found",
"message": "路由未找到",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),
@@ -181,7 +221,7 @@ func (r *GinRouter) SetupDefaultRoutes() {
r.engine.NoMethod(func(c *gin.Context) {
c.JSON(http.StatusMethodNotAllowed, gin.H{
"success": false,
"message": "Method not allowed",
"message": "请求方法不允许",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),

View File

@@ -42,13 +42,13 @@ func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
// ValidateQuery 验证查询参数
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindQuery(dto); err != nil {
v.response.BadRequest(c, "Invalid query parameters", err.Error())
v.response.BadRequest(c, "查询参数格式错误", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
v.response.ValidationError(c, validationErrors)
return err
}
return nil
@@ -57,13 +57,13 @@ func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error
// ValidateParam 验证路径参数
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindUri(dto); err != nil {
v.response.BadRequest(c, "Invalid path parameters", err.Error())
v.response.BadRequest(c, "路径参数格式错误", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
v.response.ValidationError(c, validationErrors)
return err
}
return nil
@@ -73,7 +73,7 @@ func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
// 绑定请求体
if err := c.ShouldBindJSON(dto); err != nil {
v.response.BadRequest(c, "Invalid request body", err.Error())
v.response.BadRequest(c, "请求体格式错误", err.Error())
return err
}
@@ -115,44 +115,74 @@ func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) stri
tag := fieldError.Tag()
param := fieldError.Param()
fieldDisplayName := v.getFieldDisplayName(field)
switch tag {
case "required":
return fmt.Sprintf("%s is required", field)
return fmt.Sprintf("%s 不能为空", fieldDisplayName)
case "email":
return fmt.Sprintf("%s must be a valid email address", field)
return fmt.Sprintf("%s 必须是有效的邮箱地址", fieldDisplayName)
case "min":
return fmt.Sprintf("%s must be at least %s characters", field, param)
return fmt.Sprintf("%s 长度不能少于 %s 位", fieldDisplayName, param)
case "max":
return fmt.Sprintf("%s must be at most %s characters", field, param)
return fmt.Sprintf("%s 长度不能超过 %s 位", fieldDisplayName, param)
case "len":
return fmt.Sprintf("%s must be exactly %s characters", field, param)
return fmt.Sprintf("%s 长度必须为 %s 位", fieldDisplayName, param)
case "gt":
return fmt.Sprintf("%s must be greater than %s", field, param)
return fmt.Sprintf("%s 必须大于 %s", fieldDisplayName, param)
case "gte":
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
return fmt.Sprintf("%s 必须大于等于 %s", fieldDisplayName, param)
case "lt":
return fmt.Sprintf("%s must be less than %s", field, param)
return fmt.Sprintf("%s 必须小于 %s", fieldDisplayName, param)
case "lte":
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
return fmt.Sprintf("%s 必须小于等于 %s", fieldDisplayName, param)
case "oneof":
return fmt.Sprintf("%s must be one of [%s]", field, param)
return fmt.Sprintf("%s 必须是以下值之一:[%s]", fieldDisplayName, param)
case "url":
return fmt.Sprintf("%s must be a valid URL", field)
return fmt.Sprintf("%s 必须是有效的URL地址", fieldDisplayName)
case "alpha":
return fmt.Sprintf("%s must contain only alphabetic characters", field)
return fmt.Sprintf("%s 只能包含字母", fieldDisplayName)
case "alphanum":
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
return fmt.Sprintf("%s 只能包含字母和数字", fieldDisplayName)
case "numeric":
return fmt.Sprintf("%s must be numeric", field)
return fmt.Sprintf("%s 必须是数字", fieldDisplayName)
case "phone":
return fmt.Sprintf("%s must be a valid phone number", field)
return fmt.Sprintf("%s 必须是有效的手机号", fieldDisplayName)
case "username":
return fmt.Sprintf("%s must be a valid username", field)
return fmt.Sprintf("%s 格式不正确,只能包含字母、数字、下划线,且不能以数字开头", fieldDisplayName)
case "strong_password":
return fmt.Sprintf("%s 强度不足必须包含大小写字母和数字且不少于8位", fieldDisplayName)
case "eqfield":
return fmt.Sprintf("%s 必须与 %s 一致", fieldDisplayName, v.getFieldDisplayName(param))
default:
return fmt.Sprintf("%s is invalid", field)
return fmt.Sprintf("%s 格式不正确", fieldDisplayName)
}
}
// getFieldDisplayName 获取字段显示名称(中文)
func (v *RequestValidator) getFieldDisplayName(field string) string {
fieldNames := map[string]string{
"phone": "手机号",
"password": "密码",
"confirm_password": "确认密码",
"old_password": "原密码",
"new_password": "新密码",
"confirm_new_password": "确认新密码",
"code": "验证码",
"username": "用户名",
"email": "邮箱",
"display_name": "显示名称",
"scene": "使用场景",
"Password": "密码",
"NewPassword": "新密码",
}
if displayName, exists := fieldNames[field]; exists {
return displayName
}
return field
}
// toSnakeCase 转换为snake_case
func (v *RequestValidator) toSnakeCase(str string) string {
var result strings.Builder

View File

@@ -0,0 +1,294 @@
package http
import (
"strings"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
zh_translations "github.com/go-playground/validator/v10/translations/zh"
)
// RequestValidatorZh 中文验证器实现
type RequestValidatorZh struct {
validator *validator.Validate
translator ut.Translator
response interfaces.ResponseBuilder
}
// NewRequestValidatorZh 创建支持中文翻译的请求验证器
func NewRequestValidatorZh(response interfaces.ResponseBuilder) interfaces.RequestValidator {
// 创建验证器实例
validate := validator.New()
// 创建中文locale
zhLocale := zh.New()
uni := ut.New(zhLocale, zhLocale)
// 获取中文翻译器
trans, _ := uni.GetTranslator("zh")
// 注册中文翻译
zh_translations.RegisterDefaultTranslations(validate, trans)
// 注册自定义验证器
registerCustomValidatorsZh(validate, trans)
return &RequestValidatorZh{
validator: validate,
translator: trans,
response: response,
}
}
// Validate 验证请求体
func (v *RequestValidatorZh) Validate(c *gin.Context, dto interface{}) error {
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrorsZh(err)
v.response.ValidationError(c, validationErrors)
return err
}
return nil
}
// ValidateQuery 验证查询参数
func (v *RequestValidatorZh) ValidateQuery(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindQuery(dto); err != nil {
v.response.BadRequest(c, "查询参数格式错误", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrorsZh(err)
v.response.ValidationError(c, validationErrors)
return err
}
return nil
}
// ValidateParam 验证路径参数
func (v *RequestValidatorZh) ValidateParam(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindUri(dto); err != nil {
v.response.BadRequest(c, "路径参数格式错误", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrorsZh(err)
v.response.ValidationError(c, validationErrors)
return err
}
return nil
}
// BindAndValidate 绑定并验证请求
func (v *RequestValidatorZh) BindAndValidate(c *gin.Context, dto interface{}) error {
// 绑定请求体
if err := c.ShouldBindJSON(dto); err != nil {
v.response.BadRequest(c, "请求体格式错误", err.Error())
return err
}
// 验证数据
return v.Validate(c, dto)
}
// formatValidationErrorsZh 格式化验证错误(中文翻译版)
func (v *RequestValidatorZh) formatValidationErrorsZh(err error) map[string][]string {
errors := make(map[string][]string)
if validationErrors, ok := err.(validator.ValidationErrors); ok {
for _, fieldError := range validationErrors {
fieldName := v.getFieldNameZh(fieldError)
// 首先尝试使用翻译器获取翻译后的错误消息
errorMessage := fieldError.Translate(v.translator)
// 如果翻译后的消息包含英文字段名,则替换为中文字段名
fieldDisplayName := v.getFieldDisplayName(fieldError.Field())
if fieldDisplayName != fieldError.Field() {
// 替换字段名为中文
errorMessage = strings.ReplaceAll(errorMessage, fieldError.Field(), fieldDisplayName)
}
if _, exists := errors[fieldName]; !exists {
errors[fieldName] = []string{}
}
errors[fieldName] = append(errors[fieldName], errorMessage)
}
}
return errors
}
// getFieldNameZh 获取字段名JSON标签优先
func (v *RequestValidatorZh) getFieldNameZh(fieldError validator.FieldError) string {
fieldName := fieldError.Field()
return v.toSnakeCase(fieldName)
}
// getFieldDisplayName 获取字段显示名称(中文)
func (v *RequestValidatorZh) getFieldDisplayName(field string) string {
fieldNames := map[string]string{
"phone": "手机号",
"password": "密码",
"confirm_password": "确认密码",
"old_password": "原密码",
"new_password": "新密码",
"confirm_new_password": "确认新密码",
"code": "验证码",
"username": "用户名",
"email": "邮箱",
"display_name": "显示名称",
"scene": "使用场景",
"Password": "密码",
"NewPassword": "新密码",
"ConfirmPassword": "确认密码",
}
if displayName, exists := fieldNames[field]; exists {
return displayName
}
return field
}
// toSnakeCase 转换为snake_case
func (v *RequestValidatorZh) toSnakeCase(str string) string {
var result strings.Builder
for i, r := range str {
if i > 0 && (r >= 'A' && r <= 'Z') {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// registerCustomValidatorsZh 注册自定义验证器和中文翻译
func registerCustomValidatorsZh(v *validator.Validate, trans ut.Translator) {
// 注册手机号验证器
v.RegisterValidation("phone", validatePhoneZh)
v.RegisterTranslation("phone", trans, func(ut ut.Translator) error {
return ut.Add("phone", "{0}必须是有效的手机号", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("phone", fe.Field())
return t
})
// 注册用户名验证器
v.RegisterValidation("username", validateUsernameZh)
v.RegisterTranslation("username", trans, func(ut ut.Translator) error {
return ut.Add("username", "{0}格式不正确,只能包含字母、数字、下划线,且不能以数字开头", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("username", fe.Field())
return t
})
// 注册密码强度验证器
v.RegisterValidation("strong_password", validateStrongPasswordZh)
v.RegisterTranslation("strong_password", trans, func(ut ut.Translator) error {
return ut.Add("strong_password", "{0}强度不足必须包含大小写字母和数字且不少于8位", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("strong_password", fe.Field())
return t
})
// 自定义eqfield翻译
v.RegisterTranslation("eqfield", trans, func(ut ut.Translator) error {
return ut.Add("eqfield", "{0}必须等于{1}", true)
}, func(ut ut.Translator, fe validator.FieldError) string {
t, _ := ut.T("eqfield", fe.Field(), fe.Param())
return t
})
}
// validatePhoneZh 验证手机号
func validatePhoneZh(fl validator.FieldLevel) bool {
phone := fl.Field().String()
if phone == "" {
return true // 空值由required标签处理
}
// 中国手机号验证11位以1开头
if len(phone) != 11 {
return false
}
if !strings.HasPrefix(phone, "1") {
return false
}
// 检查是否全是数字
for _, r := range phone {
if r < '0' || r > '9' {
return false
}
}
return true
}
// validateUsernameZh 验证用户名
func validateUsernameZh(fl validator.FieldLevel) bool {
username := fl.Field().String()
if username == "" {
return true // 空值由required标签处理
}
// 用户名规则3-30个字符只能包含字母、数字、下划线不能以数字开头
if len(username) < 3 || len(username) > 30 {
return false
}
// 不能以数字开头
if username[0] >= '0' && username[0] <= '9' {
return false
}
// 只能包含字母、数字、下划线
for _, r := range username {
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
return false
}
}
return true
}
// validateStrongPasswordZh 验证密码强度
func validateStrongPasswordZh(fl validator.FieldLevel) bool {
password := fl.Field().String()
if password == "" {
return true // 空值由required标签处理
}
// 密码强度规则至少8个字符包含大小写字母、数字
if len(password) < 8 {
return false
}
hasUpper := false
hasLower := false
hasDigit := false
for _, r := range password {
switch {
case r >= 'A' && r <= 'Z':
hasUpper = true
case r >= 'a' && r <= 'z':
hasLower = true
case r >= '0' && r <= '9':
hasDigit = true
}
}
return hasUpper && hasLower && hasDigit
}
// ValidateStruct 直接验证结构体
func (v *RequestValidatorZh) ValidateStruct(dto interface{}) error {
return v.validator.Struct(dto)
}

View File

@@ -76,9 +76,14 @@ type ResponseBuilder interface {
NotFound(c *gin.Context, message ...string)
Conflict(c *gin.Context, message string)
InternalError(c *gin.Context, message ...string)
ValidationError(c *gin.Context, errors interface{})
TooManyRequests(c *gin.Context, message ...string)
// 分页响应
Paginated(c *gin.Context, data interface{}, pagination PaginationMeta)
// 自定义响应
CustomResponse(c *gin.Context, statusCode int, data interface{})
}
// RequestValidator 请求验证器接口
@@ -90,6 +95,9 @@ type RequestValidator interface {
// 绑定和验证
BindAndValidate(c *gin.Context, dto interface{}) error
// 直接验证结构体
ValidateStruct(dto interface{}) error
}
// PaginationMeta 分页元数据

View File

@@ -2,6 +2,15 @@ package interfaces
import (
"context"
"errors"
"tyapi-server/internal/domains/user/dto"
"tyapi-server/internal/domains/user/entities"
)
// 常见错误定义
var (
ErrCacheMiss = errors.New("cache miss")
)
// Service 通用服务接口
@@ -16,6 +25,22 @@ type Service interface {
Shutdown(ctx context.Context) error
}
// UserService 用户服务接口
type UserService interface {
Service
// 用户注册
Register(ctx context.Context, req *dto.RegisterRequest) (*entities.User, error)
// 密码登录
LoginWithPassword(ctx context.Context, req *dto.LoginWithPasswordRequest) (*entities.User, error)
// 短信验证码登录
LoginWithSMS(ctx context.Context, req *dto.LoginWithSMSRequest) (*entities.User, error)
// 修改密码
ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error
// 根据ID获取用户
GetByID(ctx context.Context, id string) (*entities.User, error)
}
// DomainService 领域服务接口,支持泛型
type DomainService[T Entity] interface {
Service

View File

@@ -0,0 +1,214 @@
package logger
import (
"context"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// LogLevel 日志级别
type LogLevel string
const (
DebugLevel LogLevel = "debug"
InfoLevel LogLevel = "info"
WarnLevel LogLevel = "warn"
ErrorLevel LogLevel = "error"
)
// LogContext 日志上下文
type LogContext struct {
RequestID string
UserID string
TraceID string
OperationName string
Layer string // repository/service/handler
Component string
}
// ContextualLogger 上下文感知的日志器
type ContextualLogger struct {
logger *zap.Logger
ctx LogContext
}
// NewContextualLogger 创建上下文日志器
func NewContextualLogger(logger *zap.Logger) *ContextualLogger {
return &ContextualLogger{
logger: logger,
}
}
// WithContext 添加上下文信息
func (l *ContextualLogger) WithContext(ctx context.Context) *ContextualLogger {
logCtx := LogContext{}
// 从context中提取常用字段
if requestID := getStringFromContext(ctx, "request_id"); requestID != "" {
logCtx.RequestID = requestID
}
if userID := getStringFromContext(ctx, "user_id"); userID != "" {
logCtx.UserID = userID
}
if traceID := getStringFromContext(ctx, "trace_id"); traceID != "" {
logCtx.TraceID = traceID
}
return &ContextualLogger{
logger: l.logger,
ctx: logCtx,
}
}
// WithLayer 设置层级信息
func (l *ContextualLogger) WithLayer(layer string) *ContextualLogger {
newCtx := l.ctx
newCtx.Layer = layer
return &ContextualLogger{
logger: l.logger,
ctx: newCtx,
}
}
// WithComponent 设置组件信息
func (l *ContextualLogger) WithComponent(component string) *ContextualLogger {
newCtx := l.ctx
newCtx.Component = component
return &ContextualLogger{
logger: l.logger,
ctx: newCtx,
}
}
// WithOperation 设置操作名称
func (l *ContextualLogger) WithOperation(operation string) *ContextualLogger {
newCtx := l.ctx
newCtx.OperationName = operation
return &ContextualLogger{
logger: l.logger,
ctx: newCtx,
}
}
// 构建基础字段
func (l *ContextualLogger) buildBaseFields() []zapcore.Field {
fields := []zapcore.Field{}
if l.ctx.RequestID != "" {
fields = append(fields, zap.String("request_id", l.ctx.RequestID))
}
if l.ctx.UserID != "" {
fields = append(fields, zap.String("user_id", l.ctx.UserID))
}
if l.ctx.TraceID != "" {
fields = append(fields, zap.String("trace_id", l.ctx.TraceID))
}
if l.ctx.Layer != "" {
fields = append(fields, zap.String("layer", l.ctx.Layer))
}
if l.ctx.Component != "" {
fields = append(fields, zap.String("component", l.ctx.Component))
}
if l.ctx.OperationName != "" {
fields = append(fields, zap.String("operation", l.ctx.OperationName))
}
return fields
}
// LogTechnicalError 记录技术性错误Repository层
func (l *ContextualLogger) LogTechnicalError(msg string, err error, fields ...zapcore.Field) {
allFields := l.buildBaseFields()
allFields = append(allFields, zap.Error(err))
allFields = append(allFields, zap.String("error_type", "technical"))
allFields = append(allFields, fields...)
l.logger.Error(msg, allFields...)
}
// LogBusinessWarn 记录业务警告Service层
func (l *ContextualLogger) LogBusinessWarn(msg string, fields ...zapcore.Field) {
allFields := l.buildBaseFields()
allFields = append(allFields, zap.String("log_type", "business"))
allFields = append(allFields, fields...)
l.logger.Warn(msg, allFields...)
}
// LogBusinessInfo 记录业务信息Service层
func (l *ContextualLogger) LogBusinessInfo(msg string, fields ...zapcore.Field) {
allFields := l.buildBaseFields()
allFields = append(allFields, zap.String("log_type", "business"))
allFields = append(allFields, fields...)
l.logger.Info(msg, allFields...)
}
// LogUserAction 记录用户行为Handler层
func (l *ContextualLogger) LogUserAction(msg string, fields ...zapcore.Field) {
allFields := l.buildBaseFields()
allFields = append(allFields, zap.String("log_type", "user_action"))
allFields = append(allFields, fields...)
l.logger.Info(msg, allFields...)
}
// LogRequestFailed 记录请求失败Handler层
func (l *ContextualLogger) LogRequestFailed(msg string, errorType string, fields ...zapcore.Field) {
allFields := l.buildBaseFields()
allFields = append(allFields, zap.String("log_type", "request_failed"))
allFields = append(allFields, zap.String("error_category", errorType))
allFields = append(allFields, fields...)
l.logger.Info(msg, allFields...)
}
// getStringFromContext 从上下文获取字符串值
func getStringFromContext(ctx context.Context, key string) string {
if value := ctx.Value(key); value != nil {
if str, ok := value.(string); ok {
return str
}
}
return ""
}
// ErrorCategory 错误分类
type ErrorCategory string
const (
DatabaseError ErrorCategory = "database"
NetworkError ErrorCategory = "network"
ValidationError ErrorCategory = "validation"
BusinessError ErrorCategory = "business"
AuthError ErrorCategory = "auth"
ExternalAPIError ErrorCategory = "external_api"
)
// CategorizeError 错误分类
func CategorizeError(err error) ErrorCategory {
errMsg := strings.ToLower(err.Error())
switch {
case strings.Contains(errMsg, "database") ||
strings.Contains(errMsg, "sql") ||
strings.Contains(errMsg, "gorm"):
return DatabaseError
case strings.Contains(errMsg, "network") ||
strings.Contains(errMsg, "connection") ||
strings.Contains(errMsg, "timeout"):
return NetworkError
case strings.Contains(errMsg, "validation") ||
strings.Contains(errMsg, "invalid") ||
strings.Contains(errMsg, "format"):
return ValidationError
case strings.Contains(errMsg, "unauthorized") ||
strings.Contains(errMsg, "forbidden") ||
strings.Contains(errMsg, "token"):
return AuthError
default:
return BusinessError
}
}

View File

@@ -0,0 +1,263 @@
package metrics
import (
"context"
"sync"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// BusinessMetrics 业务指标收集器
type BusinessMetrics struct {
metrics interfaces.MetricsCollector
logger *zap.Logger
mutex sync.RWMutex
// 业务指标缓存
userMetrics map[string]int64
orderMetrics map[string]int64
}
// NewBusinessMetrics 创建业务指标收集器
func NewBusinessMetrics(metrics interfaces.MetricsCollector, logger *zap.Logger) *BusinessMetrics {
bm := &BusinessMetrics{
metrics: metrics,
logger: logger,
userMetrics: make(map[string]int64),
orderMetrics: make(map[string]int64),
}
// 注册业务指标
bm.registerBusinessMetrics()
return bm
}
// registerBusinessMetrics 注册业务指标
func (bm *BusinessMetrics) registerBusinessMetrics() {
// 用户相关指标
bm.metrics.RegisterCounter("users_created_total", "Total number of users created", []string{"source"})
bm.metrics.RegisterCounter("users_login_total", "Total number of user logins", []string{"method", "status"})
bm.metrics.RegisterGauge("users_active_sessions", "Current number of active user sessions", nil)
// 订单相关指标
bm.metrics.RegisterCounter("orders_created_total", "Total number of orders created", []string{"status"})
bm.metrics.RegisterCounter("orders_amount_total", "Total order amount in cents", []string{"currency"})
bm.metrics.RegisterHistogram("orders_processing_duration_seconds", "Order processing duration", []string{"status"}, []float64{0.1, 0.5, 1, 2, 5, 10, 30})
// API相关指标
bm.metrics.RegisterCounter("api_errors_total", "Total number of API errors", []string{"endpoint", "error_type"})
bm.metrics.RegisterHistogram("api_response_size_bytes", "API response size in bytes", []string{"endpoint"}, []float64{100, 1000, 10000, 100000})
// 缓存相关指标
bm.metrics.RegisterCounter("cache_operations_total", "Total number of cache operations", []string{"operation", "result"})
bm.metrics.RegisterGauge("cache_memory_usage_bytes", "Cache memory usage in bytes", []string{"cache_type"})
// 数据库相关指标
bm.metrics.RegisterHistogram("database_query_duration_seconds", "Database query duration", []string{"operation", "table"}, []float64{0.001, 0.01, 0.1, 1, 10})
bm.metrics.RegisterCounter("database_errors_total", "Total number of database errors", []string{"operation", "error_type"})
bm.logger.Info("Business metrics registered successfully")
}
// User相关指标
// RecordUserCreated 记录用户创建
func (bm *BusinessMetrics) RecordUserCreated(source string) {
bm.metrics.IncrementCounter("users_created_total", map[string]string{
"source": source,
})
bm.mutex.Lock()
bm.userMetrics["created"]++
bm.mutex.Unlock()
bm.logger.Debug("Recorded user created", zap.String("source", source))
}
// RecordUserLogin 记录用户登录
func (bm *BusinessMetrics) RecordUserLogin(method, status string) {
bm.metrics.IncrementCounter("users_login_total", map[string]string{
"method": method,
"status": status,
})
bm.logger.Debug("Recorded user login", zap.String("method", method), zap.String("status", status))
}
// UpdateActiveUserSessions 更新活跃用户会话数
func (bm *BusinessMetrics) UpdateActiveUserSessions(count float64) {
bm.metrics.RecordGauge("users_active_sessions", count, nil)
}
// Order相关指标
// RecordOrderCreated 记录订单创建
func (bm *BusinessMetrics) RecordOrderCreated(status string, amount float64, currency string) {
bm.metrics.IncrementCounter("orders_created_total", map[string]string{
"status": status,
})
// 记录订单金额(以分为单位)
amountCents := int64(amount * 100)
bm.metrics.IncrementCounter("orders_amount_total", map[string]string{
"currency": currency,
})
bm.mutex.Lock()
bm.orderMetrics["created"]++
bm.orderMetrics["amount"] += amountCents
bm.mutex.Unlock()
bm.logger.Debug("Recorded order created",
zap.String("status", status),
zap.Float64("amount", amount),
zap.String("currency", currency))
}
// RecordOrderProcessingDuration 记录订单处理时长
func (bm *BusinessMetrics) RecordOrderProcessingDuration(status string, duration float64) {
bm.metrics.RecordHistogram("orders_processing_duration_seconds", duration, map[string]string{
"status": status,
})
}
// API相关指标
// RecordAPIError 记录API错误
func (bm *BusinessMetrics) RecordAPIError(endpoint, errorType string) {
bm.metrics.IncrementCounter("api_errors_total", map[string]string{
"endpoint": endpoint,
"error_type": errorType,
})
bm.logger.Debug("Recorded API error",
zap.String("endpoint", endpoint),
zap.String("error_type", errorType))
}
// RecordAPIResponseSize 记录API响应大小
func (bm *BusinessMetrics) RecordAPIResponseSize(endpoint string, sizeBytes float64) {
bm.metrics.RecordHistogram("api_response_size_bytes", sizeBytes, map[string]string{
"endpoint": endpoint,
})
}
// Cache相关指标
// RecordCacheOperation 记录缓存操作
func (bm *BusinessMetrics) RecordCacheOperation(operation, result string) {
bm.metrics.IncrementCounter("cache_operations_total", map[string]string{
"operation": operation,
"result": result,
})
}
// UpdateCacheMemoryUsage 更新缓存内存使用量
func (bm *BusinessMetrics) UpdateCacheMemoryUsage(cacheType string, usageBytes float64) {
bm.metrics.RecordGauge("cache_memory_usage_bytes", usageBytes, map[string]string{
"cache_type": cacheType,
})
}
// Database相关指标
// RecordDatabaseQuery 记录数据库查询
func (bm *BusinessMetrics) RecordDatabaseQuery(operation, table string, duration float64) {
bm.metrics.RecordHistogram("database_query_duration_seconds", duration, map[string]string{
"operation": operation,
"table": table,
})
}
// RecordDatabaseError 记录数据库错误
func (bm *BusinessMetrics) RecordDatabaseError(operation, errorType string) {
bm.metrics.IncrementCounter("database_errors_total", map[string]string{
"operation": operation,
"error_type": errorType,
})
bm.logger.Debug("Recorded database error",
zap.String("operation", operation),
zap.String("error_type", errorType))
}
// 获取统计信息
// GetUserStats 获取用户统计
func (bm *BusinessMetrics) GetUserStats() map[string]int64 {
bm.mutex.RLock()
defer bm.mutex.RUnlock()
stats := make(map[string]int64)
for k, v := range bm.userMetrics {
stats[k] = v
}
return stats
}
// GetOrderStats 获取订单统计
func (bm *BusinessMetrics) GetOrderStats() map[string]int64 {
bm.mutex.RLock()
defer bm.mutex.RUnlock()
stats := make(map[string]int64)
for k, v := range bm.orderMetrics {
stats[k] = v
}
return stats
}
// GetOverallStats 获取整体统计
func (bm *BusinessMetrics) GetOverallStats() map[string]interface{} {
return map[string]interface{}{
"user_stats": bm.GetUserStats(),
"order_stats": bm.GetOrderStats(),
}
}
// Reset 重置统计数据
func (bm *BusinessMetrics) Reset() {
bm.mutex.Lock()
defer bm.mutex.Unlock()
bm.userMetrics = make(map[string]int64)
bm.orderMetrics = make(map[string]int64)
bm.logger.Info("Business metrics reset")
}
// Context相关方法
// WithContext 创建带上下文的业务指标收集器
func (bm *BusinessMetrics) WithContext(ctx context.Context) *BusinessMetrics {
// 这里可以从context中提取追踪信息关联指标
return bm
}
// 实现Service接口如果需要
// Name 返回服务名称
func (bm *BusinessMetrics) Name() string {
return "business-metrics"
}
// Initialize 初始化服务
func (bm *BusinessMetrics) Initialize(ctx context.Context) error {
bm.logger.Info("Business metrics service initialized")
return nil
}
// HealthCheck 健康检查
func (bm *BusinessMetrics) HealthCheck(ctx context.Context) error {
// 检查指标收集器是否正常
return nil
}
// Shutdown 关闭服务
func (bm *BusinessMetrics) Shutdown(ctx context.Context) error {
bm.logger.Info("Business metrics service shutdown")
return nil
}

View File

@@ -0,0 +1,353 @@
package metrics
import (
"net/http"
"strconv"
"sync"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
)
// PrometheusMetrics Prometheus指标收集器
type PrometheusMetrics struct {
logger *zap.Logger
registry *prometheus.Registry
mutex sync.RWMutex
// 预定义指标
httpRequests *prometheus.CounterVec
httpDuration *prometheus.HistogramVec
activeUsers prometheus.Gauge
dbConnections prometheus.Gauge
cacheHits *prometheus.CounterVec
businessMetrics map[string]prometheus.Collector
}
// NewPrometheusMetrics 创建Prometheus指标收集器
func NewPrometheusMetrics(logger *zap.Logger) *PrometheusMetrics {
registry := prometheus.NewRegistry()
// HTTP请求计数器
httpRequests := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "http_requests_total",
Help: "Total number of HTTP requests",
},
[]string{"method", "path", "status"},
)
// HTTP请求耗时直方图
httpDuration := prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "http_request_duration_seconds",
Help: "HTTP request duration in seconds",
Buckets: prometheus.DefBuckets,
},
[]string{"method", "path"},
)
// 活跃用户数
activeUsers := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "active_users_total",
Help: "Current number of active users",
},
)
// 数据库连接数
dbConnections := prometheus.NewGauge(
prometheus.GaugeOpts{
Name: "database_connections_active",
Help: "Current number of active database connections",
},
)
// 缓存命中率
cacheHits := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "cache_operations_total",
Help: "Total number of cache operations",
},
[]string{"operation", "result"},
)
// 注册指标
registry.MustRegister(httpRequests)
registry.MustRegister(httpDuration)
registry.MustRegister(activeUsers)
registry.MustRegister(dbConnections)
registry.MustRegister(cacheHits)
return &PrometheusMetrics{
logger: logger,
registry: registry,
httpRequests: httpRequests,
httpDuration: httpDuration,
activeUsers: activeUsers,
dbConnections: dbConnections,
cacheHits: cacheHits,
businessMetrics: make(map[string]prometheus.Collector),
}
}
// RecordHTTPRequest 记录HTTP请求指标
func (m *PrometheusMetrics) RecordHTTPRequest(method, path string, statusCode int, duration float64) {
status := strconv.Itoa(statusCode)
m.httpRequests.WithLabelValues(method, path, status).Inc()
m.httpDuration.WithLabelValues(method, path).Observe(duration)
m.logger.Debug("Recorded HTTP request metric",
zap.String("method", method),
zap.String("path", path),
zap.String("status", status),
zap.Float64("duration", duration))
}
// RecordHTTPDuration 记录HTTP请求耗时
func (m *PrometheusMetrics) RecordHTTPDuration(method, path string, duration float64) {
m.httpDuration.WithLabelValues(method, path).Observe(duration)
m.logger.Debug("Recorded HTTP duration metric",
zap.String("method", method),
zap.String("path", path),
zap.Float64("duration", duration))
}
// IncrementCounter 增加计数器
func (m *PrometheusMetrics) IncrementCounter(name string, labels map[string]string) {
if counter, exists := m.getOrCreateCounter(name, labels); exists {
if vec, ok := counter.(*prometheus.CounterVec); ok {
vec.With(labels).Inc()
}
}
}
// RecordGauge 记录仪表盘值
func (m *PrometheusMetrics) RecordGauge(name string, value float64, labels map[string]string) {
if gauge, exists := m.getOrCreateGauge(name, labels); exists {
if vec, ok := gauge.(*prometheus.GaugeVec); ok {
vec.With(labels).Set(value)
} else if g, ok := gauge.(prometheus.Gauge); ok {
g.Set(value)
}
}
}
// RecordHistogram 记录直方图值
func (m *PrometheusMetrics) RecordHistogram(name string, value float64, labels map[string]string) {
if histogram, exists := m.getOrCreateHistogram(name, labels); exists {
if vec, ok := histogram.(*prometheus.HistogramVec); ok {
vec.With(labels).Observe(value)
}
}
}
// RegisterCounter 注册计数器
func (m *PrometheusMetrics) RegisterCounter(name, help string, labels []string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.businessMetrics[name]; exists {
return nil // 已存在
}
var counter prometheus.Collector
if len(labels) > 0 {
counter = prometheus.NewCounterVec(
prometheus.CounterOpts{Name: name, Help: help},
labels,
)
} else {
counter = prometheus.NewCounter(
prometheus.CounterOpts{Name: name, Help: help},
)
}
if err := m.registry.Register(counter); err != nil {
return err
}
m.businessMetrics[name] = counter
m.logger.Info("Registered counter metric", zap.String("name", name))
return nil
}
// RegisterGauge 注册仪表盘
func (m *PrometheusMetrics) RegisterGauge(name, help string, labels []string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.businessMetrics[name]; exists {
return nil
}
var gauge prometheus.Collector
if len(labels) > 0 {
gauge = prometheus.NewGaugeVec(
prometheus.GaugeOpts{Name: name, Help: help},
labels,
)
} else {
gauge = prometheus.NewGauge(
prometheus.GaugeOpts{Name: name, Help: help},
)
}
if err := m.registry.Register(gauge); err != nil {
return err
}
m.businessMetrics[name] = gauge
m.logger.Info("Registered gauge metric", zap.String("name", name))
return nil
}
// RegisterHistogram 注册直方图
func (m *PrometheusMetrics) RegisterHistogram(name, help string, labels []string, buckets []float64) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if _, exists := m.businessMetrics[name]; exists {
return nil
}
if buckets == nil {
buckets = prometheus.DefBuckets
}
var histogram prometheus.Collector
if len(labels) > 0 {
histogram = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: name,
Help: help,
Buckets: buckets,
},
labels,
)
} else {
histogram = prometheus.NewHistogram(
prometheus.HistogramOpts{
Name: name,
Help: help,
Buckets: buckets,
},
)
}
if err := m.registry.Register(histogram); err != nil {
return err
}
m.businessMetrics[name] = histogram
m.logger.Info("Registered histogram metric", zap.String("name", name))
return nil
}
// GetHandler 获取HTTP处理器
func (m *PrometheusMetrics) GetHandler() http.Handler {
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
}
// 内部辅助方法
func (m *PrometheusMetrics) getOrCreateCounter(name string, labels map[string]string) (prometheus.Collector, bool) {
m.mutex.RLock()
counter, exists := m.businessMetrics[name]
m.mutex.RUnlock()
if !exists {
// 自动创建计数器
labelNames := make([]string, 0, len(labels))
for k := range labels {
labelNames = append(labelNames, k)
}
if err := m.RegisterCounter(name, "Auto-created counter", labelNames); err != nil {
m.logger.Error("Failed to auto-create counter", zap.String("name", name), zap.Error(err))
return nil, false
}
m.mutex.RLock()
counter, exists = m.businessMetrics[name]
m.mutex.RUnlock()
}
return counter, exists
}
func (m *PrometheusMetrics) getOrCreateGauge(name string, labels map[string]string) (prometheus.Collector, bool) {
m.mutex.RLock()
gauge, exists := m.businessMetrics[name]
m.mutex.RUnlock()
if !exists {
labelNames := make([]string, 0, len(labels))
for k := range labels {
labelNames = append(labelNames, k)
}
if err := m.RegisterGauge(name, "Auto-created gauge", labelNames); err != nil {
m.logger.Error("Failed to auto-create gauge", zap.String("name", name), zap.Error(err))
return nil, false
}
m.mutex.RLock()
gauge, exists = m.businessMetrics[name]
m.mutex.RUnlock()
}
return gauge, exists
}
func (m *PrometheusMetrics) getOrCreateHistogram(name string, labels map[string]string) (prometheus.Collector, bool) {
m.mutex.RLock()
histogram, exists := m.businessMetrics[name]
m.mutex.RUnlock()
if !exists {
labelNames := make([]string, 0, len(labels))
for k := range labels {
labelNames = append(labelNames, k)
}
if err := m.RegisterHistogram(name, "Auto-created histogram", labelNames, nil); err != nil {
m.logger.Error("Failed to auto-create histogram", zap.String("name", name), zap.Error(err))
return nil, false
}
m.mutex.RLock()
histogram, exists = m.businessMetrics[name]
m.mutex.RUnlock()
}
return histogram, exists
}
// UpdateActiveUsers 更新活跃用户数
func (m *PrometheusMetrics) UpdateActiveUsers(count float64) {
m.activeUsers.Set(count)
}
// UpdateDBConnections 更新数据库连接数
func (m *PrometheusMetrics) UpdateDBConnections(count float64) {
m.dbConnections.Set(count)
}
// RecordCacheOperation 记录缓存操作
func (m *PrometheusMetrics) RecordCacheOperation(operation, result string) {
m.cacheHits.WithLabelValues(operation, result).Inc()
}
// GetStats 获取指标统计
func (m *PrometheusMetrics) GetStats() map[string]interface{} {
m.mutex.RLock()
defer m.mutex.RUnlock()
return map[string]interface{}{
"registered_metrics": len(m.businessMetrics),
}
}

View File

@@ -42,31 +42,31 @@ func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
m.respondUnauthorized(c, "Missing authorization header")
m.respondUnauthorized(c, "缺少认证头部")
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
m.respondUnauthorized(c, "Invalid authorization header format")
m.respondUnauthorized(c, "认证头部格式无效")
return
}
// 提取token
tokenString := authHeader[len(bearerPrefix):]
if tokenString == "" {
m.respondUnauthorized(c, "Missing token")
m.respondUnauthorized(c, "缺少认证令牌")
return
}
// 验证token
claims, err := m.validateToken(tokenString)
if err != nil {
m.logger.Warn("Invalid token",
m.logger.Warn("无效的认证令牌",
zap.Error(err),
zap.String("request_id", c.GetString("request_id")))
m.respondUnauthorized(c, "Invalid token")
m.respondUnauthorized(c, "认证令牌无效")
return
}
@@ -119,7 +119,7 @@ func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "Unauthorized",
"message": "认证失败",
"error": message,
"request_id": c.GetString("request_id"),
"timestamp": time.Now().Unix(),

View File

@@ -2,11 +2,11 @@ package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"tyapi-server/internal/config"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
@@ -15,14 +15,16 @@ import (
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
config *config.Config
response interfaces.ResponseBuilder
limiters map[string]*rate.Limiter
mutex sync.RWMutex
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
return &RateLimitMiddleware{
config: cfg,
response: response,
limiters: make(map[string]*rate.Limiter),
}
}
@@ -48,15 +50,13 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
// 检查是否允许请求
if !limiter.Allow() {
// 添加限流头部信息
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Header("Retry-After", "60")
c.JSON(http.StatusTooManyRequests, gin.H{
"success": false,
"message": "Rate limit exceeded",
"error": "Too many requests",
})
// 使用统一的响应格式
m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
c.Abort()
return
}

View File

@@ -2,23 +2,35 @@ package middleware
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"tyapi-server/internal/shared/tracing"
)
// RequestLoggerMiddleware 请求日志中间件
type RequestLoggerMiddleware struct {
logger *zap.Logger
logger *zap.Logger
useColoredLog bool
isDevelopment bool
tracer *tracing.Tracer
}
// NewRequestLoggerMiddleware 创建请求日志中间件
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
return &RequestLoggerMiddleware{
logger: logger,
logger: logger,
useColoredLog: isDevelopment, // 开发环境使用彩色日志
isDevelopment: isDevelopment,
tracer: tracer,
}
}
@@ -34,24 +46,110 @@ func (m *RequestLoggerMiddleware) GetPriority() int {
// Handle 返回中间件处理函数
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// 使用zap logger记录请求信息
m.logger.Info("HTTP Request",
zap.String("client_ip", param.ClientIP),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.String("protocol", param.Request.Proto),
zap.Int("status_code", param.StatusCode),
zap.Duration("latency", param.Latency),
zap.String("user_agent", param.Request.UserAgent()),
zap.Int("body_size", param.BodySize),
zap.String("referer", param.Request.Referer()),
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
)
if m.useColoredLog {
// 开发环境使用Gin默认的彩色日志格式
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var statusColor, methodColor, resetColor string
if param.IsOutputColor() {
statusColor = param.StatusCodeColor()
methodColor = param.MethodColor()
resetColor = param.ResetColor()
}
// 返回空字符串因为我们已经用zap记录了
return ""
})
if param.Latency > time.Minute {
param.Latency = param.Latency.Truncate(time.Second)
}
// 获取TraceID
traceID := param.Request.Header.Get("X-Trace-ID")
if traceID == "" && m.tracer != nil {
traceID = m.tracer.GetTraceID(param.Request.Context())
}
// 检查是否为错误响应
if param.StatusCode >= 400 && m.tracer != nil {
span := trace.SpanFromContext(param.Request.Context())
if span.IsRecording() {
// 标记为错误操作确保100%采样
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
}
}
traceInfo := ""
if traceID != "" {
traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
}
return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
statusColor, param.StatusCode, resetColor,
param.Latency,
param.ClientIP,
methodColor, param.Method, resetColor,
param.Path,
traceInfo,
param.ErrorMessage,
)
})
} else {
// 生产环境使用结构化JSON日志
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// 获取TraceID
traceID := param.Request.Header.Get("X-Trace-ID")
if traceID == "" && m.tracer != nil {
traceID = m.tracer.GetTraceID(param.Request.Context())
}
// 检查是否为错误响应
if param.StatusCode >= 400 && m.tracer != nil {
span := trace.SpanFromContext(param.Request.Context())
if span.IsRecording() {
// 标记为错误操作确保100%采样
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
// 对于服务器错误,记录更详细的日志
if param.StatusCode >= 500 {
m.logger.Error("服务器错误",
zap.Int("status_code", param.StatusCode),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.Duration("latency", param.Latency),
zap.String("client_ip", param.ClientIP),
zap.String("trace_id", traceID),
)
}
}
}
// 记录请求日志
logFields := []zap.Field{
zap.String("client_ip", param.ClientIP),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.String("protocol", param.Request.Proto),
zap.Int("status_code", param.StatusCode),
zap.Duration("latency", param.Latency),
zap.String("user_agent", param.Request.UserAgent()),
zap.Int("body_size", param.BodySize),
zap.String("referer", param.Request.Referer()),
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
}
// 添加TraceID
if traceID != "" {
logFields = append(logFields, zap.String("trace_id", traceID))
}
m.logger.Info("HTTP请求", logFields...)
return ""
})
}
}
// IsGlobal 是否为全局中间件
@@ -102,6 +200,70 @@ func (m *RequestIDMiddleware) IsGlobal() bool {
return true
}
// TraceIDMiddleware 追踪ID中间件
type TraceIDMiddleware struct {
tracer *tracing.Tracer
}
// NewTraceIDMiddleware 创建追踪ID中间件
func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
return &TraceIDMiddleware{
tracer: tracer,
}
}
// GetName 返回中间件名称
func (m *TraceIDMiddleware) GetName() string {
return "trace_id"
}
// GetPriority 返回中间件优先级
func (m *TraceIDMiddleware) GetPriority() int {
return 94 // 仅次于请求ID中间件
}
// Handle 返回中间件处理函数
func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取或生成追踪ID
traceID := m.tracer.GetTraceID(c.Request.Context())
if traceID != "" {
// 设置追踪ID到响应头
c.Header("X-Trace-ID", traceID)
// 添加到上下文
c.Set("trace_id", traceID)
}
// 检查是否为错误请求例如URL不存在
c.Next()
// 请求完成后检查状态码
if c.Writer.Status() >= 400 {
// 获取当前span
span := trace.SpanFromContext(c.Request.Context())
if span.IsRecording() {
// 标记为错误操作确保100%采样
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
// 设置错误上下文以便后续span可以识别
c.Request = c.Request.WithContext(context.WithValue(
c.Request.Context(),
"otel_error_request",
true,
))
}
}
}
}
// IsGlobal 是否为全局中间件
func (m *TraceIDMiddleware) IsGlobal() bool {
return true
}
// SecurityHeadersMiddleware 安全头部中间件
type SecurityHeadersMiddleware struct{}
@@ -183,13 +345,15 @@ func (m *ResponseTimeMiddleware) IsGlobal() bool {
type RequestBodyLoggerMiddleware struct {
logger *zap.Logger
enable bool
tracer *tracing.Tracer
}
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
return &RequestBodyLoggerMiddleware{
logger: logger,
enable: enable,
tracer: tracer,
}
}
@@ -220,13 +384,26 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
// 重新设置body供后续处理使用
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 获取追踪ID
traceID := ""
if m.tracer != nil {
traceID = m.tracer.GetTraceID(c.Request.Context())
}
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
m.logger.Debug("Request Body",
logFields := []zap.Field{
zap.String("method", c.Request.Method),
zap.String("path", c.Request.URL.Path),
zap.String("body", string(bodyBytes)),
zap.String("request_id", c.GetString("request_id")),
)
}
// 添加追踪ID
if traceID != "" {
logFields = append(logFields, zap.String("trace_id", traceID))
}
m.logger.Debug("请求体", logFields...)
}
}
}
@@ -239,3 +416,83 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
return false // 可选中间件,不是全局的
}
// ErrorTrackingMiddleware 错误追踪中间件
type ErrorTrackingMiddleware struct {
logger *zap.Logger
tracer *tracing.Tracer
}
// NewErrorTrackingMiddleware 创建错误追踪中间件
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
return &ErrorTrackingMiddleware{
logger: logger,
tracer: tracer,
}
}
// GetName 返回中间件名称
func (m *ErrorTrackingMiddleware) GetName() string {
return "error_tracking"
}
// GetPriority 返回中间件优先级
func (m *ErrorTrackingMiddleware) GetPriority() int {
return 60 // 低优先级,在大多数中间件之后执行
}
// Handle 返回中间件处理函数
func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
c.Next()
// 检查是否有错误
if len(c.Errors) > 0 || c.Writer.Status() >= 400 {
// 获取当前span
span := trace.SpanFromContext(c.Request.Context())
if span.IsRecording() {
// 标记为错误操作确保100%采样
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
// 记录错误日志
traceID := m.tracer.GetTraceID(c.Request.Context())
spanID := m.tracer.GetSpanID(c.Request.Context())
logFields := []zap.Field{
zap.Int("status_code", c.Writer.Status()),
zap.String("method", c.Request.Method),
zap.String("path", c.FullPath()),
zap.String("client_ip", c.ClientIP()),
}
// 添加追踪信息
if traceID != "" {
logFields = append(logFields, zap.String("trace_id", traceID))
}
if spanID != "" {
logFields = append(logFields, zap.String("span_id", spanID))
}
// 添加错误信息
if len(c.Errors) > 0 {
logFields = append(logFields, zap.String("errors", c.Errors.String()))
}
// 根据状态码决定日志级别
if c.Writer.Status() >= 500 {
m.logger.Error("服务器错误", logFields...)
} else {
m.logger.Warn("客户端错误", logFields...)
}
}
}
}
}
// IsGlobal 是否为全局中间件
func (m *ErrorTrackingMiddleware) IsGlobal() bool {
return true
}

View File

@@ -0,0 +1,389 @@
package resilience
import (
"context"
"errors"
"sync"
"time"
"go.uber.org/zap"
)
// CircuitState 熔断器状态
type CircuitState int
const (
// StateClosed 关闭状态(正常)
StateClosed CircuitState = iota
// StateOpen 开启状态(熔断)
StateOpen
// StateHalfOpen 半开状态(测试)
StateHalfOpen
)
func (s CircuitState) String() string {
switch s {
case StateClosed:
return "CLOSED"
case StateOpen:
return "OPEN"
case StateHalfOpen:
return "HALF_OPEN"
default:
return "UNKNOWN"
}
}
// CircuitBreakerConfig 熔断器配置
type CircuitBreakerConfig struct {
// 故障阈值
FailureThreshold int
// 重置超时时间
ResetTimeout time.Duration
// 检测窗口大小
WindowSize int
// 半开状态允许的请求数
HalfOpenMaxRequests int
// 成功阈值(半开->关闭)
SuccessThreshold int
}
// DefaultCircuitBreakerConfig 默认熔断器配置
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
return CircuitBreakerConfig{
FailureThreshold: 5,
ResetTimeout: 60 * time.Second,
WindowSize: 10,
HalfOpenMaxRequests: 3,
SuccessThreshold: 2,
}
}
// CircuitBreaker 熔断器
type CircuitBreaker struct {
config CircuitBreakerConfig
logger *zap.Logger
mutex sync.RWMutex
// 状态
state CircuitState
// 计数器
failures int
successes int
requests int
consecutiveFailures int
// 时间记录
lastFailTime time.Time
lastStateChange time.Time
// 统计窗口
window []bool // true=success, false=failure
windowIndex int
windowFull bool
// 事件回调
onStateChange func(from, to CircuitState)
}
// NewCircuitBreaker 创建熔断器
func NewCircuitBreaker(config CircuitBreakerConfig, logger *zap.Logger) *CircuitBreaker {
cb := &CircuitBreaker{
config: config,
logger: logger,
state: StateClosed,
window: make([]bool, config.WindowSize),
lastStateChange: time.Now(),
}
return cb
}
// Execute 执行函数,如果熔断器开启则快速失败
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
// 检查是否允许执行
if !cb.allowRequest() {
return ErrCircuitBreakerOpen
}
// 执行函数
start := time.Now()
err := fn()
duration := time.Since(start)
// 记录结果
cb.recordResult(err == nil, duration)
return err
}
// allowRequest 检查是否允许请求
func (cb *CircuitBreaker) allowRequest() bool {
cb.mutex.Lock()
defer cb.mutex.Unlock()
now := time.Now()
switch cb.state {
case StateClosed:
return true
case StateOpen:
// 检查是否到了重置时间
if now.Sub(cb.lastStateChange) > cb.config.ResetTimeout {
cb.setState(StateHalfOpen)
return true
}
return false
case StateHalfOpen:
// 半开状态下限制请求数
return cb.requests < cb.config.HalfOpenMaxRequests
default:
return false
}
}
// recordResult 记录执行结果
func (cb *CircuitBreaker) recordResult(success bool, duration time.Duration) {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.requests++
// 更新滑动窗口
cb.updateWindow(success)
if success {
cb.successes++
cb.consecutiveFailures = 0
cb.onSuccess()
} else {
cb.failures++
cb.consecutiveFailures++
cb.lastFailTime = time.Now()
cb.onFailure()
}
cb.logger.Debug("Circuit breaker recorded result",
zap.Bool("success", success),
zap.Duration("duration", duration),
zap.String("state", cb.state.String()),
zap.Int("failures", cb.failures),
zap.Int("successes", cb.successes))
}
// updateWindow 更新滑动窗口
func (cb *CircuitBreaker) updateWindow(success bool) {
cb.window[cb.windowIndex] = success
cb.windowIndex = (cb.windowIndex + 1) % cb.config.WindowSize
if cb.windowIndex == 0 {
cb.windowFull = true
}
}
// onSuccess 成功时的处理
func (cb *CircuitBreaker) onSuccess() {
if cb.state == StateHalfOpen {
// 半开状态下,如果成功次数达到阈值,则关闭熔断器
if cb.successes >= cb.config.SuccessThreshold {
cb.setState(StateClosed)
}
}
}
// onFailure 失败时的处理
func (cb *CircuitBreaker) onFailure() {
if cb.state == StateClosed {
// 关闭状态下,检查是否需要开启熔断器
if cb.shouldTrip() {
cb.setState(StateOpen)
}
} else if cb.state == StateHalfOpen {
// 半开状态下,如果失败则立即开启熔断器
cb.setState(StateOpen)
}
}
// shouldTrip 检查是否应该触发熔断
func (cb *CircuitBreaker) shouldTrip() bool {
// 基于连续失败次数
if cb.consecutiveFailures >= cb.config.FailureThreshold {
return true
}
// 基于滑动窗口的失败率
if cb.windowFull {
failures := 0
for _, success := range cb.window {
if !success {
failures++
}
}
failureRate := float64(failures) / float64(cb.config.WindowSize)
return failureRate >= 0.5 // 50%失败率
}
return false
}
// setState 设置状态
func (cb *CircuitBreaker) setState(newState CircuitState) {
if cb.state == newState {
return
}
oldState := cb.state
cb.state = newState
cb.lastStateChange = time.Now()
// 重置计数器
if newState == StateClosed {
cb.requests = 0
cb.failures = 0
cb.successes = 0
cb.consecutiveFailures = 0
} else if newState == StateHalfOpen {
cb.requests = 0
cb.successes = 0
}
cb.logger.Info("Circuit breaker state changed",
zap.String("from", oldState.String()),
zap.String("to", newState.String()),
zap.Int("failures", cb.failures),
zap.Int("successes", cb.successes))
// 触发状态变更回调
if cb.onStateChange != nil {
cb.onStateChange(oldState, newState)
}
}
// GetState 获取当前状态
func (cb *CircuitBreaker) GetState() CircuitState {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return cb.state
}
// GetStats 获取统计信息
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
cb.mutex.RLock()
defer cb.mutex.RUnlock()
return CircuitBreakerStats{
State: cb.state.String(),
Failures: cb.failures,
Successes: cb.successes,
Requests: cb.requests,
ConsecutiveFailures: cb.consecutiveFailures,
LastFailTime: cb.lastFailTime,
LastStateChange: cb.lastStateChange,
FailureThreshold: cb.config.FailureThreshold,
ResetTimeout: cb.config.ResetTimeout,
}
}
// Reset 重置熔断器
func (cb *CircuitBreaker) Reset() {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.setState(StateClosed)
cb.window = make([]bool, cb.config.WindowSize)
cb.windowIndex = 0
cb.windowFull = false
cb.logger.Info("Circuit breaker reset")
}
// SetStateChangeCallback 设置状态变更回调
func (cb *CircuitBreaker) SetStateChangeCallback(callback func(from, to CircuitState)) {
cb.mutex.Lock()
defer cb.mutex.Unlock()
cb.onStateChange = callback
}
// CircuitBreakerStats 熔断器统计信息
type CircuitBreakerStats struct {
State string `json:"state"`
Failures int `json:"failures"`
Successes int `json:"successes"`
Requests int `json:"requests"`
ConsecutiveFailures int `json:"consecutive_failures"`
LastFailTime time.Time `json:"last_fail_time"`
LastStateChange time.Time `json:"last_state_change"`
FailureThreshold int `json:"failure_threshold"`
ResetTimeout time.Duration `json:"reset_timeout"`
}
// 预定义错误
var (
ErrCircuitBreakerOpen = errors.New("circuit breaker is open")
)
// Wrapper 熔断器包装器
type Wrapper struct {
breakers map[string]*CircuitBreaker
logger *zap.Logger
mutex sync.RWMutex
}
// NewWrapper 创建熔断器包装器
func NewWrapper(logger *zap.Logger) *Wrapper {
return &Wrapper{
breakers: make(map[string]*CircuitBreaker),
logger: logger,
}
}
// GetOrCreate 获取或创建熔断器
func (w *Wrapper) GetOrCreate(name string, config CircuitBreakerConfig) *CircuitBreaker {
w.mutex.Lock()
defer w.mutex.Unlock()
if cb, exists := w.breakers[name]; exists {
return cb
}
cb := NewCircuitBreaker(config, w.logger.Named(name))
w.breakers[name] = cb
w.logger.Info("Created circuit breaker", zap.String("name", name))
return cb
}
// Execute 执行带熔断器的函数
func (w *Wrapper) Execute(ctx context.Context, name string, fn func() error) error {
cb := w.GetOrCreate(name, DefaultCircuitBreakerConfig())
return cb.Execute(ctx, fn)
}
// GetStats 获取所有熔断器统计
func (w *Wrapper) GetStats() map[string]CircuitBreakerStats {
w.mutex.RLock()
defer w.mutex.RUnlock()
stats := make(map[string]CircuitBreakerStats)
for name, cb := range w.breakers {
stats[name] = cb.GetStats()
}
return stats
}
// ResetAll 重置所有熔断器
func (w *Wrapper) ResetAll() {
w.mutex.RLock()
defer w.mutex.RUnlock()
for name, cb := range w.breakers {
cb.Reset()
w.logger.Info("Reset circuit breaker", zap.String("name", name))
}
}

View File

@@ -0,0 +1,467 @@
package resilience
import (
"context"
"fmt"
"math/rand"
"sync"
"time"
"go.uber.org/zap"
)
// RetryConfig 重试配置
type RetryConfig struct {
// 最大重试次数
MaxAttempts int
// 初始延迟
InitialDelay time.Duration
// 最大延迟
MaxDelay time.Duration
// 退避倍数
BackoffMultiplier float64
// 抖动系数
JitterFactor float64
// 重试条件
RetryCondition func(error) bool
// 延迟函数
DelayFunc func(attempt int, config RetryConfig) time.Duration
}
// DefaultRetryConfig 默认重试配置
func DefaultRetryConfig() RetryConfig {
return RetryConfig{
MaxAttempts: 3,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 5 * time.Second,
BackoffMultiplier: 2.0,
JitterFactor: 0.1,
RetryCondition: DefaultRetryCondition,
DelayFunc: ExponentialBackoffWithJitter,
}
}
// RetryableError 可重试错误接口
type RetryableError interface {
error
IsRetryable() bool
}
// DefaultRetryCondition 默认重试条件
func DefaultRetryCondition(err error) bool {
if err == nil {
return false
}
// 检查是否实现了RetryableError接口
if retryable, ok := err.(RetryableError); ok {
return retryable.IsRetryable()
}
// 默认所有错误都重试
return true
}
// IsRetryableHTTPError HTTP错误重试条件
func IsRetryableHTTPError(statusCode int) bool {
// 5xx错误通常可以重试
// 429Too Many Requests也可以重试
return statusCode >= 500 || statusCode == 429
}
// DelayFunction 延迟函数类型
type DelayFunction func(attempt int, config RetryConfig) time.Duration
// FixedDelay 固定延迟
func FixedDelay(attempt int, config RetryConfig) time.Duration {
return config.InitialDelay
}
// LinearBackoff 线性退避
func LinearBackoff(attempt int, config RetryConfig) time.Duration {
delay := time.Duration(attempt) * config.InitialDelay
if delay > config.MaxDelay {
delay = config.MaxDelay
}
return delay
}
// ExponentialBackoff 指数退避
func ExponentialBackoff(attempt int, config RetryConfig) time.Duration {
delay := config.InitialDelay
for i := 0; i < attempt; i++ {
delay = time.Duration(float64(delay) * config.BackoffMultiplier)
}
if delay > config.MaxDelay {
delay = config.MaxDelay
}
return delay
}
// ExponentialBackoffWithJitter 带抖动的指数退避
func ExponentialBackoffWithJitter(attempt int, config RetryConfig) time.Duration {
delay := ExponentialBackoff(attempt, config)
// 添加抖动
jitter := config.JitterFactor
if jitter > 0 {
jitterRange := float64(delay) * jitter
jitterOffset := (rand.Float64() - 0.5) * 2 * jitterRange
delay = time.Duration(float64(delay) + jitterOffset)
}
if delay < 0 {
delay = config.InitialDelay
}
return delay
}
// RetryStats 重试统计
type RetryStats struct {
TotalAttempts int `json:"total_attempts"`
Successes int `json:"successes"`
Failures int `json:"failures"`
TotalRetries int `json:"total_retries"`
AverageAttempts float64 `json:"average_attempts"`
TotalDelay time.Duration `json:"total_delay"`
LastError string `json:"last_error,omitempty"`
}
// Retryer 重试器
type Retryer struct {
config RetryConfig
logger *zap.Logger
stats RetryStats
}
// NewRetryer 创建重试器
func NewRetryer(config RetryConfig, logger *zap.Logger) *Retryer {
if config.DelayFunc == nil {
config.DelayFunc = ExponentialBackoffWithJitter
}
if config.RetryCondition == nil {
config.RetryCondition = DefaultRetryCondition
}
return &Retryer{
config: config,
logger: logger,
}
}
// Execute 执行带重试的函数
func (r *Retryer) Execute(ctx context.Context, operation func() error) error {
return r.ExecuteWithResult(ctx, func() (interface{}, error) {
return nil, operation()
})
}
// ExecuteWithResult 执行带重试和返回值的函数
func (r *Retryer) ExecuteWithResult(ctx context.Context, operation func() (interface{}, error)) error {
var lastErr error
startTime := time.Now()
for attempt := 0; attempt < r.config.MaxAttempts; attempt++ {
// 检查上下文是否被取消
select {
case <-ctx.Done():
return ctx.Err()
default:
}
// 执行操作
attemptStart := time.Now()
_, err := operation()
attemptDuration := time.Since(attemptStart)
// 更新统计
r.stats.TotalAttempts++
if err == nil {
r.stats.Successes++
r.logger.Debug("Operation succeeded",
zap.Int("attempt", attempt+1),
zap.Duration("duration", attemptDuration))
return nil
}
lastErr = err
r.stats.Failures++
if attempt > 0 {
r.stats.TotalRetries++
}
// 检查是否应该重试
if !r.config.RetryCondition(err) {
r.logger.Debug("Error is not retryable",
zap.Error(err),
zap.Int("attempt", attempt+1))
break
}
// 如果这是最后一次尝试,不需要延迟
if attempt == r.config.MaxAttempts-1 {
r.logger.Debug("Reached max attempts",
zap.Error(err),
zap.Int("max_attempts", r.config.MaxAttempts))
break
}
// 计算延迟
delay := r.config.DelayFunc(attempt, r.config)
r.stats.TotalDelay += delay
r.logger.Debug("Operation failed, retrying",
zap.Error(err),
zap.Int("attempt", attempt+1),
zap.Duration("delay", delay),
zap.Duration("attempt_duration", attemptDuration))
// 等待重试
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
}
}
// 更新最终统计
totalDuration := time.Since(startTime)
if r.stats.TotalAttempts > 0 {
r.stats.AverageAttempts = float64(r.stats.TotalRetries) / float64(r.stats.Successes+r.stats.Failures)
}
if lastErr != nil {
r.stats.LastError = lastErr.Error()
}
r.logger.Warn("Operation failed after all retries",
zap.Error(lastErr),
zap.Int("total_attempts", r.stats.TotalAttempts),
zap.Duration("total_duration", totalDuration))
return fmt.Errorf("operation failed after %d attempts: %w", r.config.MaxAttempts, lastErr)
}
// GetStats 获取重试统计
func (r *Retryer) GetStats() RetryStats {
return r.stats
}
// Reset 重置统计
func (r *Retryer) Reset() {
r.stats = RetryStats{}
r.logger.Debug("Retry stats reset")
}
// Retry 简单重试函数
func Retry(ctx context.Context, config RetryConfig, operation func() error) error {
retryer := NewRetryer(config, zap.NewNop())
return retryer.Execute(ctx, operation)
}
// RetryWithResult 带返回值的重试函数
func RetryWithResult[T any](ctx context.Context, config RetryConfig, operation func() (T, error)) (T, error) {
var result T
var finalErr error
retryer := NewRetryer(config, zap.NewNop())
err := retryer.ExecuteWithResult(ctx, func() (interface{}, error) {
r, e := operation()
result = r
return r, e
})
if err != nil {
finalErr = err
}
return result, finalErr
}
// 预定义的重试配置
// QuickRetry 快速重试(适用于轻量级操作)
func QuickRetry() RetryConfig {
return RetryConfig{
MaxAttempts: 3,
InitialDelay: 50 * time.Millisecond,
MaxDelay: 500 * time.Millisecond,
BackoffMultiplier: 2.0,
JitterFactor: 0.1,
RetryCondition: DefaultRetryCondition,
DelayFunc: ExponentialBackoffWithJitter,
}
}
// StandardRetry 标准重试(适用于一般操作)
func StandardRetry() RetryConfig {
return DefaultRetryConfig()
}
// PatientRetry 耐心重试(适用于重要操作)
func PatientRetry() RetryConfig {
return RetryConfig{
MaxAttempts: 5,
InitialDelay: 200 * time.Millisecond,
MaxDelay: 10 * time.Second,
BackoffMultiplier: 2.0,
JitterFactor: 0.2,
RetryCondition: DefaultRetryCondition,
DelayFunc: ExponentialBackoffWithJitter,
}
}
// DatabaseRetry 数据库重试配置
func DatabaseRetry() RetryConfig {
return RetryConfig{
MaxAttempts: 3,
InitialDelay: 100 * time.Millisecond,
MaxDelay: 2 * time.Second,
BackoffMultiplier: 1.5,
JitterFactor: 0.1,
RetryCondition: func(err error) bool {
// 这里可以根据具体的数据库错误类型判断
// 例如:连接超时、临时网络错误等
return DefaultRetryCondition(err)
},
DelayFunc: ExponentialBackoffWithJitter,
}
}
// HTTPRetry HTTP重试配置
func HTTPRetry() RetryConfig {
return RetryConfig{
MaxAttempts: 3,
InitialDelay: 200 * time.Millisecond,
MaxDelay: 5 * time.Second,
BackoffMultiplier: 2.0,
JitterFactor: 0.15,
RetryCondition: func(err error) bool {
// HTTP相关的重试条件
return DefaultRetryCondition(err)
},
DelayFunc: ExponentialBackoffWithJitter,
}
}
// RetryManager 重试管理器
type RetryManager struct {
retryers map[string]*Retryer
logger *zap.Logger
mutex sync.RWMutex
}
// NewRetryManager 创建重试管理器
func NewRetryManager(logger *zap.Logger) *RetryManager {
return &RetryManager{
retryers: make(map[string]*Retryer),
logger: logger,
}
}
// GetOrCreate 获取或创建重试器
func (rm *RetryManager) GetOrCreate(name string, config RetryConfig) *Retryer {
rm.mutex.Lock()
defer rm.mutex.Unlock()
if retryer, exists := rm.retryers[name]; exists {
return retryer
}
retryer := NewRetryer(config, rm.logger.Named(name))
rm.retryers[name] = retryer
rm.logger.Info("Created retryer", zap.String("name", name))
return retryer
}
// Execute 执行带重试的操作
func (rm *RetryManager) Execute(ctx context.Context, name string, operation func() error) error {
retryer := rm.GetOrCreate(name, DefaultRetryConfig())
return retryer.Execute(ctx, operation)
}
// GetStats 获取所有重试器统计
func (rm *RetryManager) GetStats() map[string]RetryStats {
rm.mutex.RLock()
defer rm.mutex.RUnlock()
stats := make(map[string]RetryStats)
for name, retryer := range rm.retryers {
stats[name] = retryer.GetStats()
}
return stats
}
// ResetAll 重置所有重试器统计
func (rm *RetryManager) ResetAll() {
rm.mutex.RLock()
defer rm.mutex.RUnlock()
for name, retryer := range rm.retryers {
retryer.Reset()
rm.logger.Info("Reset retryer stats", zap.String("name", name))
}
}
// RetryerWrapper 重试器包装器
type RetryerWrapper struct {
manager *RetryManager
logger *zap.Logger
}
// NewRetryerWrapper 创建重试器包装器
func NewRetryerWrapper(logger *zap.Logger) *RetryerWrapper {
return &RetryerWrapper{
manager: NewRetryManager(logger),
logger: logger,
}
}
// ExecuteWithQuickRetry 执行快速重试
func (rw *RetryerWrapper) ExecuteWithQuickRetry(ctx context.Context, name string, operation func() error) error {
retryer := rw.manager.GetOrCreate(name+".quick", QuickRetry())
return retryer.Execute(ctx, operation)
}
// ExecuteWithStandardRetry 执行标准重试
func (rw *RetryerWrapper) ExecuteWithStandardRetry(ctx context.Context, name string, operation func() error) error {
retryer := rw.manager.GetOrCreate(name+".standard", StandardRetry())
return retryer.Execute(ctx, operation)
}
// ExecuteWithDatabaseRetry 执行数据库重试
func (rw *RetryerWrapper) ExecuteWithDatabaseRetry(ctx context.Context, name string, operation func() error) error {
retryer := rw.manager.GetOrCreate(name+".database", DatabaseRetry())
return retryer.Execute(ctx, operation)
}
// ExecuteWithHTTPRetry 执行HTTP重试
func (rw *RetryerWrapper) ExecuteWithHTTPRetry(ctx context.Context, name string, operation func() error) error {
retryer := rw.manager.GetOrCreate(name+".http", HTTPRetry())
return retryer.Execute(ctx, operation)
}
// ExecuteWithCustomRetry 执行自定义重试
func (rw *RetryerWrapper) ExecuteWithCustomRetry(ctx context.Context, name string, config RetryConfig, operation func() error) error {
retryer := rw.manager.GetOrCreate(name+".custom", config)
return retryer.Execute(ctx, operation)
}
// GetManager 获取重试管理器
func (rw *RetryerWrapper) GetManager() *RetryManager {
return rw.manager
}
// GetAllStats 获取所有统计信息
func (rw *RetryerWrapper) GetAllStats() map[string]RetryStats {
return rw.manager.GetStats()
}
// ResetAllStats 重置所有统计信息
func (rw *RetryerWrapper) ResetAllStats() {
rw.manager.ResetAll()
}

View File

@@ -0,0 +1,612 @@
package saga
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// SagaStatus Saga状态
type SagaStatus int
const (
// StatusPending 等待中
StatusPending SagaStatus = iota
// StatusRunning 执行中
StatusRunning
// StatusCompleted 已完成
StatusCompleted
// StatusFailed 失败
StatusFailed
// StatusCompensating 补偿中
StatusCompensating
// StatusCompensated 已补偿
StatusCompensated
// StatusAborted 已中止
StatusAborted
)
func (s SagaStatus) String() string {
switch s {
case StatusPending:
return "PENDING"
case StatusRunning:
return "RUNNING"
case StatusCompleted:
return "COMPLETED"
case StatusFailed:
return "FAILED"
case StatusCompensating:
return "COMPENSATING"
case StatusCompensated:
return "COMPENSATED"
case StatusAborted:
return "ABORTED"
default:
return "UNKNOWN"
}
}
// StepStatus 步骤状态
type StepStatus int
const (
// StepPending 等待执行
StepPending StepStatus = iota
// StepRunning 执行中
StepRunning
// StepCompleted 完成
StepCompleted
// StepFailed 失败
StepFailed
// StepCompensated 已补偿
StepCompensated
// StepSkipped 跳过
StepSkipped
)
func (s StepStatus) String() string {
switch s {
case StepPending:
return "PENDING"
case StepRunning:
return "RUNNING"
case StepCompleted:
return "COMPLETED"
case StepFailed:
return "FAILED"
case StepCompensated:
return "COMPENSATED"
case StepSkipped:
return "SKIPPED"
default:
return "UNKNOWN"
}
}
// SagaStep Saga步骤
type SagaStep struct {
Name string
Action func(ctx context.Context, data interface{}) error
Compensate func(ctx context.Context, data interface{}) error
Status StepStatus
Error error
StartTime time.Time
EndTime time.Time
RetryCount int
MaxRetries int
Timeout time.Duration
}
// SagaConfig Saga配置
type SagaConfig struct {
// 默认超时时间
DefaultTimeout time.Duration
// 默认重试次数
DefaultMaxRetries int
// 是否并行执行(当前只支持串行)
Parallel bool
// 事件发布器
EventBus interfaces.EventBus
}
// DefaultSagaConfig 默认Saga配置
func DefaultSagaConfig() SagaConfig {
return SagaConfig{
DefaultTimeout: 30 * time.Second,
DefaultMaxRetries: 3,
Parallel: false,
}
}
// Saga 分布式事务
type Saga struct {
ID string
Name string
Steps []*SagaStep
Status SagaStatus
Data interface{}
StartTime time.Time
EndTime time.Time
Error error
Config SagaConfig
logger *zap.Logger
mutex sync.RWMutex
currentStep int
result interface{}
}
// NewSaga 创建新的Saga
func NewSaga(id, name string, config SagaConfig, logger *zap.Logger) *Saga {
return &Saga{
ID: id,
Name: name,
Steps: make([]*SagaStep, 0),
Status: StatusPending,
Config: config,
logger: logger,
currentStep: -1,
}
}
// AddStep 添加步骤
func (s *Saga) AddStep(name string, action, compensate func(ctx context.Context, data interface{}) error) *Saga {
step := &SagaStep{
Name: name,
Action: action,
Compensate: compensate,
Status: StepPending,
MaxRetries: s.Config.DefaultMaxRetries,
Timeout: s.Config.DefaultTimeout,
}
s.mutex.Lock()
s.Steps = append(s.Steps, step)
s.mutex.Unlock()
s.logger.Debug("Added step to saga",
zap.String("saga_id", s.ID),
zap.String("step_name", name))
return s
}
// AddStepWithConfig 添加带配置的步骤
func (s *Saga) AddStepWithConfig(name string, action, compensate func(ctx context.Context, data interface{}) error, maxRetries int, timeout time.Duration) *Saga {
step := &SagaStep{
Name: name,
Action: action,
Compensate: compensate,
Status: StepPending,
MaxRetries: maxRetries,
Timeout: timeout,
}
s.mutex.Lock()
s.Steps = append(s.Steps, step)
s.mutex.Unlock()
s.logger.Debug("Added step with config to saga",
zap.String("saga_id", s.ID),
zap.String("step_name", name),
zap.Int("max_retries", maxRetries),
zap.Duration("timeout", timeout))
return s
}
// Execute 执行Saga
func (s *Saga) Execute(ctx context.Context, data interface{}) error {
s.mutex.Lock()
if s.Status != StatusPending {
s.mutex.Unlock()
return fmt.Errorf("saga %s is not in pending status", s.ID)
}
s.Status = StatusRunning
s.Data = data
s.StartTime = time.Now()
s.mutex.Unlock()
s.logger.Info("Starting saga execution",
zap.String("saga_id", s.ID),
zap.String("saga_name", s.Name),
zap.Int("total_steps", len(s.Steps)))
// 发布Saga开始事件
s.publishEvent(ctx, "saga.started")
// 执行所有步骤
for i, step := range s.Steps {
s.mutex.Lock()
s.currentStep = i
s.mutex.Unlock()
if err := s.executeStep(ctx, step, data); err != nil {
s.logger.Error("Step execution failed",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name),
zap.Error(err))
// 执行补偿
if compensateErr := s.compensate(ctx, i-1); compensateErr != nil {
s.logger.Error("Compensation failed",
zap.String("saga_id", s.ID),
zap.Error(compensateErr))
s.setStatus(StatusAborted)
s.publishEvent(ctx, "saga.aborted")
return fmt.Errorf("saga execution failed and compensation failed: %w", compensateErr)
}
s.setStatus(StatusCompensated)
s.publishEvent(ctx, "saga.compensated")
return fmt.Errorf("saga execution failed: %w", err)
}
}
// 所有步骤成功完成
s.setStatus(StatusCompleted)
s.EndTime = time.Now()
s.logger.Info("Saga completed successfully",
zap.String("saga_id", s.ID),
zap.Duration("duration", s.EndTime.Sub(s.StartTime)))
s.publishEvent(ctx, "saga.completed")
return nil
}
// executeStep 执行单个步骤
func (s *Saga) executeStep(ctx context.Context, step *SagaStep, data interface{}) error {
step.Status = StepRunning
step.StartTime = time.Now()
s.logger.Debug("Executing step",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name))
// 设置超时上下文
stepCtx, cancel := context.WithTimeout(ctx, step.Timeout)
defer cancel()
// 重试逻辑
var lastErr error
for attempt := 0; attempt <= step.MaxRetries; attempt++ {
if attempt > 0 {
s.logger.Debug("Retrying step",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name),
zap.Int("attempt", attempt))
}
err := step.Action(stepCtx, data)
if err == nil {
step.Status = StepCompleted
step.EndTime = time.Now()
s.logger.Debug("Step completed successfully",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name),
zap.Duration("duration", step.EndTime.Sub(step.StartTime)))
return nil
}
lastErr = err
step.RetryCount = attempt
// 检查是否应该重试
if attempt < step.MaxRetries {
select {
case <-stepCtx.Done():
// 上下文被取消,停止重试
break
case <-time.After(time.Duration(attempt+1) * 100 * time.Millisecond):
// 等待一段时间后重试
}
}
}
// 所有重试都失败了
step.Status = StepFailed
step.Error = lastErr
step.EndTime = time.Now()
return lastErr
}
// compensate 执行补偿
func (s *Saga) compensate(ctx context.Context, fromStep int) error {
s.setStatus(StatusCompensating)
s.logger.Info("Starting compensation",
zap.String("saga_id", s.ID),
zap.Int("from_step", fromStep))
// 逆序执行补偿
for i := fromStep; i >= 0; i-- {
step := s.Steps[i]
// 只补偿已完成的步骤
if step.Status != StepCompleted {
step.Status = StepSkipped
continue
}
if step.Compensate == nil {
s.logger.Warn("No compensation function for step",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name))
continue
}
s.logger.Debug("Compensating step",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name))
// 设置超时上下文
compensateCtx, cancel := context.WithTimeout(ctx, step.Timeout)
err := step.Compensate(compensateCtx, s.Data)
cancel()
if err != nil {
s.logger.Error("Compensation failed for step",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name),
zap.Error(err))
return err
}
step.Status = StepCompensated
s.logger.Debug("Step compensated successfully",
zap.String("saga_id", s.ID),
zap.String("step_name", step.Name))
}
s.logger.Info("Compensation completed",
zap.String("saga_id", s.ID))
return nil
}
// setStatus 设置状态
func (s *Saga) setStatus(status SagaStatus) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.Status = status
}
// GetStatus 获取状态
func (s *Saga) GetStatus() SagaStatus {
s.mutex.RLock()
defer s.mutex.RUnlock()
return s.Status
}
// GetProgress 获取进度
func (s *Saga) GetProgress() SagaProgress {
s.mutex.RLock()
defer s.mutex.RUnlock()
completed := 0
for _, step := range s.Steps {
if step.Status == StepCompleted {
completed++
}
}
var percentage float64
if len(s.Steps) > 0 {
percentage = float64(completed) / float64(len(s.Steps)) * 100
}
return SagaProgress{
SagaID: s.ID,
Status: s.Status.String(),
TotalSteps: len(s.Steps),
CompletedSteps: completed,
CurrentStep: s.currentStep + 1,
PercentComplete: percentage,
StartTime: s.StartTime,
Duration: time.Since(s.StartTime),
}
}
// GetStepStatus 获取所有步骤状态
func (s *Saga) GetStepStatus() []StepProgress {
s.mutex.RLock()
defer s.mutex.RUnlock()
progress := make([]StepProgress, len(s.Steps))
for i, step := range s.Steps {
progress[i] = StepProgress{
Name: step.Name,
Status: step.Status.String(),
RetryCount: step.RetryCount,
StartTime: step.StartTime,
EndTime: step.EndTime,
Duration: step.EndTime.Sub(step.StartTime),
Error: "",
}
if step.Error != nil {
progress[i].Error = step.Error.Error()
}
}
return progress
}
// publishEvent 发布事件
func (s *Saga) publishEvent(ctx context.Context, eventType string) {
if s.Config.EventBus == nil {
return
}
event := &SagaEvent{
SagaID: s.ID,
SagaName: s.Name,
EventType: eventType,
Status: s.Status.String(),
Timestamp: time.Now(),
Data: s.Data,
}
// 这里应该实现Event接口简化处理
_ = event
}
// SagaProgress Saga进度
type SagaProgress struct {
SagaID string `json:"saga_id"`
Status string `json:"status"`
TotalSteps int `json:"total_steps"`
CompletedSteps int `json:"completed_steps"`
CurrentStep int `json:"current_step"`
PercentComplete float64 `json:"percent_complete"`
StartTime time.Time `json:"start_time"`
Duration time.Duration `json:"duration"`
}
// StepProgress 步骤进度
type StepProgress struct {
Name string `json:"name"`
Status string `json:"status"`
RetryCount int `json:"retry_count"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
Duration time.Duration `json:"duration"`
Error string `json:"error,omitempty"`
}
// SagaEvent Saga事件
type SagaEvent struct {
SagaID string `json:"saga_id"`
SagaName string `json:"saga_name"`
EventType string `json:"event_type"`
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
Data interface{} `json:"data,omitempty"`
}
// SagaManager Saga管理器
type SagaManager struct {
sagas map[string]*Saga
logger *zap.Logger
mutex sync.RWMutex
config SagaConfig
}
// NewSagaManager 创建Saga管理器
func NewSagaManager(config SagaConfig, logger *zap.Logger) *SagaManager {
return &SagaManager{
sagas: make(map[string]*Saga),
logger: logger,
config: config,
}
}
// CreateSaga 创建Saga
func (sm *SagaManager) CreateSaga(id, name string) *Saga {
saga := NewSaga(id, name, sm.config, sm.logger.Named("saga"))
sm.mutex.Lock()
sm.sagas[id] = saga
sm.mutex.Unlock()
sm.logger.Info("Created saga",
zap.String("saga_id", id),
zap.String("saga_name", name))
return saga
}
// GetSaga 获取Saga
func (sm *SagaManager) GetSaga(id string) (*Saga, bool) {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
saga, exists := sm.sagas[id]
return saga, exists
}
// ListSagas 列出所有Saga
func (sm *SagaManager) ListSagas() []*Saga {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
sagas := make([]*Saga, 0, len(sm.sagas))
for _, saga := range sm.sagas {
sagas = append(sagas, saga)
}
return sagas
}
// GetSagaProgress 获取Saga进度
func (sm *SagaManager) GetSagaProgress(id string) (SagaProgress, bool) {
saga, exists := sm.GetSaga(id)
if !exists {
return SagaProgress{}, false
}
return saga.GetProgress(), true
}
// RemoveSaga 移除Saga
func (sm *SagaManager) RemoveSaga(id string) {
sm.mutex.Lock()
defer sm.mutex.Unlock()
delete(sm.sagas, id)
sm.logger.Debug("Removed saga", zap.String("saga_id", id))
}
// GetStats 获取统计信息
func (sm *SagaManager) GetStats() map[string]interface{} {
sm.mutex.RLock()
defer sm.mutex.RUnlock()
statusCount := make(map[string]int)
for _, saga := range sm.sagas {
status := saga.GetStatus().String()
statusCount[status]++
}
return map[string]interface{}{
"total_sagas": len(sm.sagas),
"status_count": statusCount,
}
}
// 实现Service接口
// Name 返回服务名称
func (sm *SagaManager) Name() string {
return "saga-manager"
}
// Initialize 初始化服务
func (sm *SagaManager) Initialize(ctx context.Context) error {
sm.logger.Info("Saga manager service initialized")
return nil
}
// HealthCheck 健康检查
func (sm *SagaManager) HealthCheck(ctx context.Context) error {
return nil
}
// Shutdown 关闭服务
func (sm *SagaManager) Shutdown(ctx context.Context) error {
sm.logger.Info("Saga manager service shutdown")
return nil
}

View File

@@ -0,0 +1,130 @@
package sms
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
"go.uber.org/zap"
"tyapi-server/internal/config"
)
// Service 短信服务接口
type Service interface {
SendVerificationCode(ctx context.Context, phone string, code string) error
GenerateCode(length int) string
}
// AliSMSService 阿里云短信服务实现
type AliSMSService struct {
client *dysmsapi.Client
config config.SMSConfig
logger *zap.Logger
}
// NewAliSMSService 创建阿里云短信服务
func NewAliSMSService(cfg config.SMSConfig, logger *zap.Logger) (*AliSMSService, error) {
client, err := dysmsapi.NewClientWithAccessKey("cn-hangzhou", cfg.AccessKeyID, cfg.AccessKeySecret)
if err != nil {
return nil, fmt.Errorf("创建短信客户端失败: %w", err)
}
return &AliSMSService{
client: client,
config: cfg,
logger: logger,
}, nil
}
// SendVerificationCode 发送验证码
func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.PhoneNumbers = phone
request.SignName = s.config.SignName
request.TemplateCode = s.config.TemplateCode
request.TemplateParam = fmt.Sprintf(`{"code":"%s"}`, code)
response, err := s.client.SendSms(request)
if err != nil {
s.logger.Error("Failed to send SMS",
zap.String("phone", phone),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
if response.Code != "OK" {
s.logger.Error("SMS send failed",
zap.String("phone", phone),
zap.String("code", response.Code),
zap.String("message", response.Message))
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
}
s.logger.Info("SMS sent successfully",
zap.String("phone", phone),
zap.String("bizId", response.BizId))
return nil
}
// GenerateCode 生成验证码
func (s *AliSMSService) GenerateCode(length int) string {
if length <= 0 {
length = 6
}
// 生成指定长度的数字验证码
max := big.NewInt(int64(pow10(length)))
n, _ := rand.Int(rand.Reader, max)
// 格式化为指定长度不足时前面补0
format := fmt.Sprintf("%%0%dd", length)
return fmt.Sprintf(format, n.Int64())
}
// pow10 计算10的n次方
func pow10(n int) int {
result := 1
for i := 0; i < n; i++ {
result *= 10
}
return result
}
// MockSMSService 模拟短信服务(用于开发和测试)
type MockSMSService struct {
logger *zap.Logger
}
// NewMockSMSService 创建模拟短信服务
func NewMockSMSService(logger *zap.Logger) *MockSMSService {
return &MockSMSService{
logger: logger,
}
}
// SendVerificationCode 模拟发送验证码
func (s *MockSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
s.logger.Info("Mock SMS sent",
zap.String("phone", phone),
zap.String("code", code))
return nil
}
// GenerateCode 生成验证码
func (s *MockSMSService) GenerateCode(length int) string {
if length <= 0 {
length = 6
}
// 开发环境使用固定验证码便于测试
result := ""
for i := 0; i < length; i++ {
result += "1"
}
return result
}

View File

@@ -0,0 +1,292 @@
package tracing
import (
"context"
"fmt"
"reflect"
"runtime"
"strings"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
// TracableService 可追踪的服务接口
type TracableService interface {
Name() string
}
// ServiceDecorator 服务装饰器
type ServiceDecorator struct {
tracer *Tracer
logger *zap.Logger
config DecoratorConfig
}
// DecoratorConfig 装饰器配置
type DecoratorConfig struct {
EnableMethodTracing bool
ExcludePatterns []string
IncludeArguments bool
IncludeResults bool
SlowMethodThreshold time.Duration
}
// DefaultDecoratorConfig 默认装饰器配置
func DefaultDecoratorConfig() DecoratorConfig {
return DecoratorConfig{
EnableMethodTracing: true,
ExcludePatterns: []string{"Health", "Ping", "Name"},
IncludeArguments: true,
IncludeResults: false,
SlowMethodThreshold: 100 * time.Millisecond,
}
}
// NewServiceDecorator 创建服务装饰器
func NewServiceDecorator(tracer *Tracer, logger *zap.Logger) *ServiceDecorator {
return &ServiceDecorator{
tracer: tracer,
logger: logger,
config: DefaultDecoratorConfig(),
}
}
// WrapService 自动包装服务,为所有方法添加链路追踪
func (d *ServiceDecorator) WrapService(service interface{}) interface{} {
serviceValue := reflect.ValueOf(service)
serviceType := reflect.TypeOf(service)
if serviceType.Kind() == reflect.Ptr {
serviceType = serviceType.Elem()
serviceValue = serviceValue.Elem()
}
// 创建代理结构
proxyType := d.createProxyType(serviceType)
proxyValue := reflect.New(proxyType).Elem()
// 设置原始服务字段
proxyValue.FieldByName("target").Set(reflect.ValueOf(service))
proxyValue.FieldByName("decorator").Set(reflect.ValueOf(d))
return proxyValue.Addr().Interface()
}
// createProxyType 创建代理类型
func (d *ServiceDecorator) createProxyType(serviceType reflect.Type) reflect.Type {
// 获取服务名称
serviceName := d.getServiceName(serviceType)
// 创建代理结构字段
fields := []reflect.StructField{
{
Name: "target",
Type: reflect.PtrTo(serviceType),
},
{
Name: "decorator",
Type: reflect.TypeOf(d),
},
}
// 为每个方法创建包装器方法
for i := 0; i < serviceType.NumMethod(); i++ {
method := serviceType.Method(i)
if d.shouldTraceMethod(method.Name) {
// 创建方法字段(用于存储方法实现)
fields = append(fields, reflect.StructField{
Name: method.Name,
Type: method.Type,
})
}
}
// 创建新的结构类型
proxyType := reflect.StructOf(fields)
// 实现接口方法
d.implementMethods(proxyType, serviceType, serviceName)
return proxyType
}
// shouldTraceMethod 判断是否应该追踪方法
func (d *ServiceDecorator) shouldTraceMethod(methodName string) bool {
if !d.config.EnableMethodTracing {
return false
}
for _, pattern := range d.config.ExcludePatterns {
if strings.Contains(methodName, pattern) {
return false
}
}
return true
}
// getServiceName 获取服务名称
func (d *ServiceDecorator) getServiceName(serviceType reflect.Type) string {
serviceName := serviceType.Name()
// 移除Service后缀
if strings.HasSuffix(serviceName, "Service") {
serviceName = strings.TrimSuffix(serviceName, "Service")
}
return strings.ToLower(serviceName)
}
// TraceMethodCall 追踪方法调用
func (d *ServiceDecorator) TraceMethodCall(
ctx context.Context,
serviceName, methodName string,
fn func(context.Context) ([]reflect.Value, error),
args []reflect.Value,
) ([]reflect.Value, error) {
// 创建span名称
spanName := fmt.Sprintf("%s.%s", serviceName, methodName)
// 开始追踪
ctx, span := d.tracer.StartSpan(ctx, spanName)
defer span.End()
// 添加基础属性
d.tracer.AddSpanAttributes(span,
attribute.String("service.name", serviceName),
attribute.String("service.method", methodName),
attribute.String("service.type", "business"),
)
// 添加参数信息(如果启用)
if d.config.IncludeArguments {
d.addArgumentAttributes(span, args)
}
// 记录开始时间
startTime := time.Now()
// 执行原始方法
results, err := fn(ctx)
// 计算执行时间
duration := time.Since(startTime)
d.tracer.AddSpanAttributes(span,
attribute.Int64("service.duration_ms", duration.Milliseconds()),
)
// 标记慢方法
if duration > d.config.SlowMethodThreshold {
d.tracer.AddSpanAttributes(span,
attribute.Bool("service.slow_method", true),
)
d.logger.Warn("慢方法检测",
zap.String("service", serviceName),
zap.String("method", methodName),
zap.Duration("duration", duration),
zap.String("trace_id", d.tracer.GetTraceID(ctx)),
)
}
// 处理错误
if err != nil {
d.tracer.SetSpanError(span, err)
d.logger.Error("服务方法执行失败",
zap.String("service", serviceName),
zap.String("method", methodName),
zap.Error(err),
zap.String("trace_id", d.tracer.GetTraceID(ctx)),
)
} else {
d.tracer.SetSpanSuccess(span)
// 添加结果信息(如果启用)
if d.config.IncludeResults {
d.addResultAttributes(span, results)
}
}
return results, err
}
// addArgumentAttributes 添加参数属性
func (d *ServiceDecorator) addArgumentAttributes(span trace.Span, args []reflect.Value) {
for i, arg := range args {
if i == 0 && arg.Type().String() == "context.Context" {
continue // 跳过context参数
}
argName := fmt.Sprintf("service.arg_%d", i)
argValue := d.extractValue(arg)
if argValue != "" && len(argValue) < 1000 { // 限制长度避免性能问题
d.tracer.AddSpanAttributes(span,
attribute.String(argName, argValue),
)
}
}
}
// addResultAttributes 添加结果属性
func (d *ServiceDecorator) addResultAttributes(span trace.Span, results []reflect.Value) {
for i, result := range results {
if result.Type().String() == "error" {
continue // 错误在其他地方处理
}
resultName := fmt.Sprintf("service.result_%d", i)
resultValue := d.extractValue(result)
if resultValue != "" && len(resultValue) < 1000 {
d.tracer.AddSpanAttributes(span,
attribute.String(resultName, resultValue),
)
}
}
}
// extractValue 提取值的字符串表示
func (d *ServiceDecorator) extractValue(value reflect.Value) string {
if !value.IsValid() {
return ""
}
switch value.Kind() {
case reflect.String:
return value.String()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return fmt.Sprintf("%d", value.Int())
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return fmt.Sprintf("%d", value.Uint())
case reflect.Float32, reflect.Float64:
return fmt.Sprintf("%.2f", value.Float())
case reflect.Bool:
return fmt.Sprintf("%t", value.Bool())
case reflect.Ptr:
if value.IsNil() {
return "nil"
}
return d.extractValue(value.Elem())
case reflect.Struct:
// 对于结构体,只返回类型名
return value.Type().Name()
case reflect.Slice, reflect.Array:
return fmt.Sprintf("[%d items]", value.Len())
default:
return value.Type().Name()
}
}
// implementMethods 实现接口方法(占位符,实际需要运行时代理)
func (d *ServiceDecorator) implementMethods(proxyType, serviceType reflect.Type, serviceName string) {
// 这里是运行时方法实现的占位符
// 实际实现需要使用reflect.MakeFunc或其他运行时代理技术
}
// GetFunctionName 获取函数名称
func GetFunctionName(fn interface{}) string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
parts := strings.Split(name, ".")
return parts[len(parts)-1]
}

View File

@@ -0,0 +1,320 @@
package tracing
import (
"context"
"fmt"
"strings"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
gormSpanKey = "otel:span"
gormOperationKey = "otel:operation"
gormTableNameKey = "otel:table_name"
gormStartTimeKey = "otel:start_time"
)
// GormTracingPlugin GORM链路追踪插件
type GormTracingPlugin struct {
tracer *Tracer
logger *zap.Logger
config GormPluginConfig
}
// GormPluginConfig GORM插件配置
type GormPluginConfig struct {
IncludeSQL bool
IncludeValues bool
SlowThreshold time.Duration
ExcludeTables []string
SanitizeSQL bool
}
// DefaultGormPluginConfig 默认GORM插件配置
func DefaultGormPluginConfig() GormPluginConfig {
return GormPluginConfig{
IncludeSQL: true,
IncludeValues: false, // 生产环境建议设为false避免记录敏感数据
SlowThreshold: 200 * time.Millisecond,
ExcludeTables: []string{"migrations", "schema_migrations"},
SanitizeSQL: true,
}
}
// NewGormTracingPlugin 创建GORM追踪插件
func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin {
return &GormTracingPlugin{
tracer: tracer,
logger: logger,
config: DefaultGormPluginConfig(),
}
}
// Name 返回插件名称
func (p *GormTracingPlugin) Name() string {
return "gorm-otel-tracing"
}
// Initialize 初始化插件
func (p *GormTracingPlugin) Initialize(db *gorm.DB) error {
// 注册各种操作的回调
callbacks := []string{"create", "query", "update", "delete", "raw"}
for _, operation := range callbacks {
switch operation {
case "create":
err := db.Callback().Create().Before("gorm:create").
Register(p.Name()+":before_create", p.beforeOperation)
if err != nil {
return fmt.Errorf("failed to register before create callback: %w", err)
}
err = db.Callback().Create().After("gorm:create").
Register(p.Name()+":after_create", p.afterOperation)
if err != nil {
return fmt.Errorf("failed to register after create callback: %w", err)
}
case "query":
err := db.Callback().Query().Before("gorm:query").
Register(p.Name()+":before_query", p.beforeOperation)
if err != nil {
return fmt.Errorf("failed to register before query callback: %w", err)
}
err = db.Callback().Query().After("gorm:query").
Register(p.Name()+":after_query", p.afterOperation)
if err != nil {
return fmt.Errorf("failed to register after query callback: %w", err)
}
case "update":
err := db.Callback().Update().Before("gorm:update").
Register(p.Name()+":before_update", p.beforeOperation)
if err != nil {
return fmt.Errorf("failed to register before update callback: %w", err)
}
err = db.Callback().Update().After("gorm:update").
Register(p.Name()+":after_update", p.afterOperation)
if err != nil {
return fmt.Errorf("failed to register after update callback: %w", err)
}
case "delete":
err := db.Callback().Delete().Before("gorm:delete").
Register(p.Name()+":before_delete", p.beforeOperation)
if err != nil {
return fmt.Errorf("failed to register before delete callback: %w", err)
}
err = db.Callback().Delete().After("gorm:delete").
Register(p.Name()+":after_delete", p.afterOperation)
if err != nil {
return fmt.Errorf("failed to register after delete callback: %w", err)
}
case "raw":
err := db.Callback().Raw().Before("gorm:raw").
Register(p.Name()+":before_raw", p.beforeOperation)
if err != nil {
return fmt.Errorf("failed to register before raw callback: %w", err)
}
err = db.Callback().Raw().After("gorm:raw").
Register(p.Name()+":after_raw", p.afterOperation)
if err != nil {
return fmt.Errorf("failed to register after raw callback: %w", err)
}
}
}
p.logger.Info("GORM追踪插件已初始化")
return nil
}
// beforeOperation 操作前回调
func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) {
// 检查是否应该跳过追踪
if p.shouldSkipTracing(db) {
return
}
ctx := db.Statement.Context
if ctx == nil {
ctx = context.Background()
}
// 获取操作信息
operation := p.getOperationType(db)
tableName := p.getTableName(db)
// 检查是否应该排除此表
if p.isExcludedTable(tableName) {
return
}
// 开始追踪
ctx, span := p.tracer.StartDBSpan(ctx, operation, tableName)
// 添加基础属性
p.tracer.AddSpanAttributes(span,
attribute.String("db.system", "postgresql"),
attribute.String("db.operation", operation),
)
if tableName != "" {
p.tracer.AddSpanAttributes(span, attribute.String("db.table", tableName))
}
// 保存追踪信息到GORM context
db.Set(gormSpanKey, span)
db.Set(gormOperationKey, operation)
db.Set(gormTableNameKey, tableName)
db.Set(gormStartTimeKey, time.Now())
// 更新statement context
db.Statement.Context = ctx
}
// afterOperation 操作后回调
func (p *GormTracingPlugin) afterOperation(db *gorm.DB) {
// 获取span
spanValue, exists := db.Get(gormSpanKey)
if !exists {
return
}
span, ok := spanValue.(trace.Span)
if !ok {
return
}
defer span.End()
// 获取操作信息
operation, _ := db.Get(gormOperationKey)
tableName, _ := db.Get(gormTableNameKey)
startTime, _ := db.Get(gormStartTimeKey)
// 计算执行时间
var duration time.Duration
if st, ok := startTime.(time.Time); ok {
duration = time.Since(st)
p.tracer.AddSpanAttributes(span,
attribute.Int64("db.duration_ms", duration.Milliseconds()),
)
}
// 添加SQL信息
if p.config.IncludeSQL && db.Statement.SQL.String() != "" {
sql := db.Statement.SQL.String()
if p.config.SanitizeSQL {
sql = p.sanitizeSQL(sql)
}
p.tracer.AddSpanAttributes(span, attribute.String("db.statement", sql))
}
// 添加影响行数
if db.Statement.RowsAffected >= 0 {
p.tracer.AddSpanAttributes(span,
attribute.Int64("db.rows_affected", db.Statement.RowsAffected),
)
}
// 处理错误
if db.Error != nil {
p.tracer.SetSpanError(span, db.Error)
span.SetStatus(codes.Error, db.Error.Error())
p.logger.Error("数据库操作失败",
zap.String("operation", fmt.Sprintf("%v", operation)),
zap.String("table", fmt.Sprintf("%v", tableName)),
zap.Error(db.Error),
zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
)
} else {
p.tracer.SetSpanSuccess(span)
span.SetStatus(codes.Ok, "success")
// 检查慢查询
if duration > p.config.SlowThreshold {
p.tracer.AddSpanAttributes(span,
attribute.Bool("db.slow_query", true),
)
p.logger.Warn("慢SQL查询检测",
zap.String("operation", fmt.Sprintf("%v", operation)),
zap.String("table", fmt.Sprintf("%v", tableName)),
zap.Duration("duration", duration),
zap.String("sql", db.Statement.SQL.String()),
zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
)
}
}
}
// shouldSkipTracing 检查是否应该跳过追踪
func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool {
// 检查是否已有span避免重复追踪
if _, exists := db.Get(gormSpanKey); exists {
return true
}
return false
}
// getOperationType 获取操作类型
func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string {
switch db.Statement.ReflectValue.Kind() {
default:
sql := strings.ToUpper(strings.TrimSpace(db.Statement.SQL.String()))
if sql == "" {
return "unknown"
}
if strings.HasPrefix(sql, "SELECT") {
return "select"
} else if strings.HasPrefix(sql, "INSERT") {
return "insert"
} else if strings.HasPrefix(sql, "UPDATE") {
return "update"
} else if strings.HasPrefix(sql, "DELETE") {
return "delete"
} else if strings.HasPrefix(sql, "CREATE") {
return "create"
} else if strings.HasPrefix(sql, "DROP") {
return "drop"
} else if strings.HasPrefix(sql, "ALTER") {
return "alter"
}
return "query"
}
}
// getTableName 获取表名
func (p *GormTracingPlugin) getTableName(db *gorm.DB) string {
if db.Statement.Table != "" {
return db.Statement.Table
}
if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
return db.Statement.Schema.Table
}
return ""
}
// isExcludedTable 检查是否为排除的表
func (p *GormTracingPlugin) isExcludedTable(tableName string) bool {
for _, excluded := range p.config.ExcludeTables {
if tableName == excluded {
return true
}
}
return false
}
// sanitizeSQL 清理SQL语句移除敏感信息
func (p *GormTracingPlugin) sanitizeSQL(sql string) string {
// 简单的SQL清理将参数替换为占位符
// 在生产环境中,您可能需要更复杂的清理逻辑
return strings.ReplaceAll(sql, "'", "?")
}

View File

@@ -0,0 +1,407 @@
package tracing
import (
"context"
"fmt"
"strings"
"time"
"github.com/redis/go-redis/v9"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// TracedRedisCache Redis缓存自动追踪包装器
type TracedRedisCache struct {
client redis.UniversalClient
tracer *Tracer
logger *zap.Logger
prefix string
config RedisTracingConfig
}
// RedisTracingConfig Redis追踪配置
type RedisTracingConfig struct {
IncludeKeys bool
IncludeValues bool
MaxKeyLength int
MaxValueLength int
SlowThreshold time.Duration
SanitizeValues bool
}
// DefaultRedisTracingConfig 默认Redis追踪配置
func DefaultRedisTracingConfig() RedisTracingConfig {
return RedisTracingConfig{
IncludeKeys: true,
IncludeValues: false, // 生产环境建议设为false保护敏感数据
MaxKeyLength: 100,
MaxValueLength: 1000,
SlowThreshold: 50 * time.Millisecond,
SanitizeValues: true,
}
}
// NewTracedRedisCache 创建带追踪的Redis缓存
func NewTracedRedisCache(client redis.UniversalClient, tracer *Tracer, logger *zap.Logger, prefix string) interfaces.CacheService {
return &TracedRedisCache{
client: client,
tracer: tracer,
logger: logger,
prefix: prefix,
config: DefaultRedisTracingConfig(),
}
}
// Name 返回服务名称
func (c *TracedRedisCache) Name() string {
return "redis-cache"
}
// Initialize 初始化服务
func (c *TracedRedisCache) Initialize(ctx context.Context) error {
c.logger.Info("Redis缓存服务已初始化")
return nil
}
// HealthCheck 健康检查
func (c *TracedRedisCache) HealthCheck(ctx context.Context) error {
_, err := c.client.Ping(ctx).Result()
return err
}
// Shutdown 关闭服务
func (c *TracedRedisCache) Shutdown(ctx context.Context) error {
c.logger.Info("Redis缓存服务已关闭")
return c.client.Close()
}
// Get 获取缓存值
func (c *TracedRedisCache) Get(ctx context.Context, key string, dest interface{}) error {
// 开始追踪
ctx, span := c.tracer.StartCacheSpan(ctx, "get", key)
defer span.End()
// 添加基础属性
c.addBaseAttributes(span, "get", key)
// 记录开始时间
startTime := time.Now()
// 构建完整键名
fullKey := c.buildKey(key)
// 执行Redis操作
result, err := c.client.Get(ctx, fullKey).Result()
// 计算执行时间
duration := time.Since(startTime)
c.tracer.AddSpanAttributes(span,
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
)
// 检查慢操作
if duration > c.config.SlowThreshold {
c.tracer.AddSpanAttributes(span,
attribute.Bool("redis.slow_operation", true),
)
c.logger.Warn("Redis慢操作检测",
zap.String("operation", "get"),
zap.String("key", c.sanitizeKey(key)),
zap.Duration("duration", duration),
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
)
}
// 处理结果
if err != nil {
if err == redis.Nil {
// 缓存未命中
c.tracer.AddSpanAttributes(span,
attribute.Bool("redis.hit", false),
attribute.String("redis.result", "miss"),
)
c.tracer.SetSpanSuccess(span)
return interfaces.ErrCacheMiss
} else {
// Redis错误
c.tracer.SetSpanError(span, err)
c.logger.Error("Redis GET操作失败",
zap.String("key", c.sanitizeKey(key)),
zap.Error(err),
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
)
return err
}
}
// 缓存命中
c.tracer.AddSpanAttributes(span,
attribute.Bool("redis.hit", true),
attribute.String("redis.result", "hit"),
attribute.Int("redis.value_size", len(result)),
)
// 反序列化
if err := c.deserialize(result, dest); err != nil {
c.tracer.SetSpanError(span, err)
return err
}
c.tracer.SetSpanSuccess(span)
return nil
}
// Set 设置缓存值
func (c *TracedRedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
// 开始追踪
ctx, span := c.tracer.StartCacheSpan(ctx, "set", key)
defer span.End()
// 添加基础属性
c.addBaseAttributes(span, "set", key)
// 处理TTL
var expiration time.Duration
if len(ttl) > 0 {
if duration, ok := ttl[0].(time.Duration); ok {
expiration = duration
c.tracer.AddSpanAttributes(span,
attribute.Int64("redis.ttl_seconds", int64(expiration.Seconds())),
)
}
}
// 记录开始时间
startTime := time.Now()
// 序列化值
serialized, err := c.serialize(value)
if err != nil {
c.tracer.SetSpanError(span, err)
return err
}
// 构建完整键名
fullKey := c.buildKey(key)
// 执行Redis操作
err = c.client.Set(ctx, fullKey, serialized, expiration).Err()
// 计算执行时间
duration := time.Since(startTime)
c.tracer.AddSpanAttributes(span,
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
attribute.Int("redis.value_size", len(serialized)),
)
// 检查慢操作
if duration > c.config.SlowThreshold {
c.tracer.AddSpanAttributes(span,
attribute.Bool("redis.slow_operation", true),
)
c.logger.Warn("Redis慢操作检测",
zap.String("operation", "set"),
zap.String("key", c.sanitizeKey(key)),
zap.Duration("duration", duration),
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
)
}
// 处理错误
if err != nil {
c.tracer.SetSpanError(span, err)
c.logger.Error("Redis SET操作失败",
zap.String("key", c.sanitizeKey(key)),
zap.Error(err),
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
)
return err
}
c.tracer.SetSpanSuccess(span)
return nil
}
// Delete 删除缓存
func (c *TracedRedisCache) Delete(ctx context.Context, keys ...string) error {
// 开始追踪
ctx, span := c.tracer.StartCacheSpan(ctx, "delete", strings.Join(keys, ","))
defer span.End()
// 添加基础属性
c.tracer.AddSpanAttributes(span,
attribute.String("redis.operation", "delete"),
attribute.Int("redis.key_count", len(keys)),
)
// 记录开始时间
startTime := time.Now()
// 构建完整键名
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = c.buildKey(key)
}
// 执行Redis操作
deleted, err := c.client.Del(ctx, fullKeys...).Result()
// 计算执行时间
duration := time.Since(startTime)
c.tracer.AddSpanAttributes(span,
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
attribute.Int64("redis.deleted_count", deleted),
)
// 处理错误
if err != nil {
c.tracer.SetSpanError(span, err)
c.logger.Error("Redis DELETE操作失败",
zap.Strings("keys", c.sanitizeKeys(keys)),
zap.Error(err),
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
)
return err
}
c.tracer.SetSpanSuccess(span)
return nil
}
// Exists 检查键是否存在
func (c *TracedRedisCache) Exists(ctx context.Context, key string) (bool, error) {
// 开始追踪
ctx, span := c.tracer.StartCacheSpan(ctx, "exists", key)
defer span.End()
// 添加基础属性
c.addBaseAttributes(span, "exists", key)
// 记录开始时间
startTime := time.Now()
// 构建完整键名
fullKey := c.buildKey(key)
// 执行Redis操作
count, err := c.client.Exists(ctx, fullKey).Result()
// 计算执行时间
duration := time.Since(startTime)
c.tracer.AddSpanAttributes(span,
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
attribute.Bool("redis.exists", count > 0),
)
// 处理错误
if err != nil {
c.tracer.SetSpanError(span, err)
return false, err
}
c.tracer.SetSpanSuccess(span)
return count > 0, nil
}
// GetMultiple 批量获取(基础实现)
func (c *TracedRedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
result := make(map[string]interface{})
// 简单实现逐个获取实际应用中可以使用MGET优化
for _, key := range keys {
var value interface{}
if err := c.Get(ctx, key, &value); err == nil {
result[key] = value
}
}
return result, nil
}
// SetMultiple 批量设置(基础实现)
func (c *TracedRedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
// 简单实现逐个设置实际应用中可以使用pipeline优化
for key, value := range data {
if err := c.Set(ctx, key, value, ttl...); err != nil {
return err
}
}
return nil
}
// DeletePattern 按模式删除(基础实现)
func (c *TracedRedisCache) DeletePattern(ctx context.Context, pattern string) error {
// 这里需要实现模式删除逻辑
return fmt.Errorf("DeletePattern not implemented")
}
// Keys 获取匹配的键(基础实现)
func (c *TracedRedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
// 这里需要实现键匹配逻辑
return nil, fmt.Errorf("Keys not implemented")
}
// Stats 获取缓存统计(基础实现)
func (c *TracedRedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
return interfaces.CacheStats{}, fmt.Errorf("Stats not implemented")
}
// 辅助方法
// addBaseAttributes 添加基础属性
func (c *TracedRedisCache) addBaseAttributes(span trace.Span, operation, key string) {
c.tracer.AddSpanAttributes(span,
attribute.String("redis.operation", operation),
attribute.String("db.system", "redis"),
)
if c.config.IncludeKeys {
sanitizedKey := c.sanitizeKey(key)
if len(sanitizedKey) <= c.config.MaxKeyLength {
c.tracer.AddSpanAttributes(span,
attribute.String("redis.key", sanitizedKey),
)
}
}
}
// buildKey 构建完整的Redis键名
func (c *TracedRedisCache) buildKey(key string) string {
if c.prefix == "" {
return key
}
return fmt.Sprintf("%s:%s", c.prefix, key)
}
// sanitizeKey 清理键名用于日志记录
func (c *TracedRedisCache) sanitizeKey(key string) string {
if len(key) <= c.config.MaxKeyLength {
return key
}
return key[:c.config.MaxKeyLength] + "..."
}
// sanitizeKeys 批量清理键名
func (c *TracedRedisCache) sanitizeKeys(keys []string) []string {
result := make([]string, len(keys))
for i, key := range keys {
result[i] = c.sanitizeKey(key)
}
return result
}
// serialize 序列化值(简单实现)
func (c *TracedRedisCache) serialize(value interface{}) (string, error) {
// 这里应该使用JSON或其他序列化方法
return fmt.Sprintf("%v", value), nil
}
// deserialize 反序列化值(简单实现)
func (c *TracedRedisCache) deserialize(data string, dest interface{}) error {
// 这里应该实现真正的反序列化逻辑
return fmt.Errorf("deserialize not fully implemented")
}

View File

@@ -0,0 +1,189 @@
package tracing
import (
"context"
"fmt"
"time"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
"tyapi-server/internal/domains/user/dto"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/shared/interfaces"
)
// ServiceWrapper 服务包装器,提供自动追踪能力
type ServiceWrapper struct {
tracer *Tracer
logger *zap.Logger
}
// NewServiceWrapper 创建服务包装器
func NewServiceWrapper(tracer *Tracer, logger *zap.Logger) *ServiceWrapper {
return &ServiceWrapper{
tracer: tracer,
logger: logger,
}
}
// TraceServiceCall 追踪服务调用的通用方法
func (w *ServiceWrapper) TraceServiceCall(
ctx context.Context,
serviceName, methodName string,
fn func(context.Context) error,
) error {
// 创建span名称
spanName := fmt.Sprintf("%s.%s", serviceName, methodName)
// 开始追踪
ctx, span := w.tracer.StartSpan(ctx, spanName)
defer span.End()
// 添加基础属性
w.tracer.AddSpanAttributes(span,
attribute.String("service.name", serviceName),
attribute.String("service.method", methodName),
attribute.String("service.type", "business"),
)
// 记录开始时间
startTime := time.Now()
// 执行原始方法
err := fn(ctx)
// 计算执行时间
duration := time.Since(startTime)
w.tracer.AddSpanAttributes(span,
attribute.Int64("service.duration_ms", duration.Milliseconds()),
)
// 标记慢方法
if duration > 100*time.Millisecond {
w.tracer.AddSpanAttributes(span,
attribute.Bool("service.slow_method", true),
)
w.logger.Warn("慢方法检测",
zap.String("service", serviceName),
zap.String("method", methodName),
zap.Duration("duration", duration),
zap.String("trace_id", w.tracer.GetTraceID(ctx)),
)
}
// 处理错误
if err != nil {
w.tracer.SetSpanError(span, err)
w.logger.Error("服务方法执行失败",
zap.String("service", serviceName),
zap.String("method", methodName),
zap.Error(err),
zap.String("trace_id", w.tracer.GetTraceID(ctx)),
)
} else {
w.tracer.SetSpanSuccess(span)
}
return err
}
// TracedUserService 自动追踪的用户服务包装器
type TracedUserService struct {
service interfaces.UserService
wrapper *ServiceWrapper
}
// NewTracedUserService 创建带追踪的用户服务
func NewTracedUserService(service interfaces.UserService, wrapper *ServiceWrapper) interfaces.UserService {
return &TracedUserService{
service: service,
wrapper: wrapper,
}
}
func (t *TracedUserService) Name() string {
return "user-service"
}
func (t *TracedUserService) Initialize(ctx context.Context) error {
return t.wrapper.TraceServiceCall(ctx, "user", "initialize", t.service.Initialize)
}
func (t *TracedUserService) HealthCheck(ctx context.Context) error {
return t.service.HealthCheck(ctx) // 不追踪健康检查
}
func (t *TracedUserService) Shutdown(ctx context.Context) error {
return t.wrapper.TraceServiceCall(ctx, "user", "shutdown", t.service.Shutdown)
}
func (t *TracedUserService) Register(ctx context.Context, req *dto.RegisterRequest) (*entities.User, error) {
var result *entities.User
var err error
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "register", func(ctx context.Context) error {
result, err = t.service.Register(ctx, req)
return err
})
if traceErr != nil {
return nil, traceErr
}
return result, err
}
func (t *TracedUserService) LoginWithPassword(ctx context.Context, req *dto.LoginWithPasswordRequest) (*entities.User, error) {
var result *entities.User
var err error
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "login_password", func(ctx context.Context) error {
result, err = t.service.LoginWithPassword(ctx, req)
return err
})
if traceErr != nil {
return nil, traceErr
}
return result, err
}
func (t *TracedUserService) LoginWithSMS(ctx context.Context, req *dto.LoginWithSMSRequest) (*entities.User, error) {
var result *entities.User
var err error
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "login_sms", func(ctx context.Context) error {
result, err = t.service.LoginWithSMS(ctx, req)
return err
})
if traceErr != nil {
return nil, traceErr
}
return result, err
}
func (t *TracedUserService) ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error {
return t.wrapper.TraceServiceCall(ctx, "user", "change_password", func(ctx context.Context) error {
return t.service.ChangePassword(ctx, userID, req)
})
}
func (t *TracedUserService) GetByID(ctx context.Context, id string) (*entities.User, error) {
var result *entities.User
var err error
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "get_by_id", func(ctx context.Context) error {
result, err = t.service.GetByID(ctx, id)
return err
})
if traceErr != nil {
return nil, traceErr
}
return result, err
}

View File

@@ -0,0 +1,474 @@
package tracing
import (
"context"
"fmt"
"sync"
"time"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
// TracerConfig 追踪器配置
type TracerConfig struct {
ServiceName string
ServiceVersion string
Environment string
Endpoint string
SampleRate float64
Enabled bool
}
// DefaultTracerConfig 默认追踪器配置
func DefaultTracerConfig() TracerConfig {
return TracerConfig{
ServiceName: "tyapi-server",
ServiceVersion: "1.0.0",
Environment: "development",
Endpoint: "http://localhost:4317",
SampleRate: 0.1,
Enabled: true,
}
}
// Tracer 链路追踪器
type Tracer struct {
config TracerConfig
logger *zap.Logger
provider *sdktrace.TracerProvider
tracer trace.Tracer
mutex sync.RWMutex
initialized bool
shutdown func(context.Context) error
}
// NewTracer 创建链路追踪器
func NewTracer(config TracerConfig, logger *zap.Logger) *Tracer {
return &Tracer{
config: config,
logger: logger,
}
}
// Initialize 初始化追踪器
func (t *Tracer) Initialize(ctx context.Context) error {
t.mutex.Lock()
defer t.mutex.Unlock()
if t.initialized {
return nil
}
if !t.config.Enabled {
t.logger.Info("Tracing is disabled")
return nil
}
// 创建资源
res, err := resource.New(ctx,
resource.WithAttributes(
attribute.String("service.name", t.config.ServiceName),
attribute.String("service.version", t.config.ServiceVersion),
attribute.String("environment", t.config.Environment),
),
)
if err != nil {
return fmt.Errorf("failed to create resource: %w", err)
}
// 创建采样器
sampler := sdktrace.TraceIDRatioBased(t.config.SampleRate)
// 创建导出器
var spanProcessor sdktrace.SpanProcessor
if t.config.Endpoint != "" {
// 使用OTLP gRPC导出器支持Jaeger、Tempo等
exporter, err := otlptracegrpc.New(ctx,
otlptracegrpc.WithEndpoint(t.config.Endpoint),
otlptracegrpc.WithInsecure(), // 开发环境使用生产环境应配置TLS
otlptracegrpc.WithTimeout(time.Second*10),
otlptracegrpc.WithRetry(otlptracegrpc.RetryConfig{
Enabled: true,
InitialInterval: time.Millisecond * 100,
MaxInterval: time.Second * 5,
MaxElapsedTime: time.Second * 30,
}),
)
if err != nil {
t.logger.Warn("Failed to create OTLP exporter, using noop exporter",
zap.Error(err),
zap.String("endpoint", t.config.Endpoint))
spanProcessor = sdktrace.NewSimpleSpanProcessor(&noopExporter{})
} else {
// 在生产环境中使用批处理器以提高性能
spanProcessor = sdktrace.NewBatchSpanProcessor(exporter,
sdktrace.WithBatchTimeout(time.Second*5),
sdktrace.WithMaxExportBatchSize(512),
sdktrace.WithMaxQueueSize(2048),
sdktrace.WithExportTimeout(time.Second*30),
)
t.logger.Info("OTLP exporter initialized successfully",
zap.String("endpoint", t.config.Endpoint))
}
} else {
// 如果没有配置端点,使用空导出器
spanProcessor = sdktrace.NewSimpleSpanProcessor(&noopExporter{})
t.logger.Info("Using noop exporter (no endpoint configured)")
}
// 创建TracerProvider
provider := sdktrace.NewTracerProvider(
sdktrace.WithResource(res),
sdktrace.WithSampler(sampler),
sdktrace.WithSpanProcessor(spanProcessor),
)
// 设置全局TracerProvider
otel.SetTracerProvider(provider)
// 创建Tracer
tracer := provider.Tracer(t.config.ServiceName)
t.provider = provider
t.tracer = tracer
t.shutdown = func(ctx context.Context) error {
return provider.Shutdown(ctx)
}
t.initialized = true
t.logger.Info("Tracing initialized successfully",
zap.String("service", t.config.ServiceName),
zap.Float64("sample_rate", t.config.SampleRate))
return nil
}
// StartSpan 开始一个新的span
func (t *Tracer) StartSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
if !t.initialized || !t.config.Enabled {
return ctx, trace.SpanFromContext(ctx)
}
return t.tracer.Start(ctx, name, opts...)
}
// StartHTTPSpan 开始一个HTTP span
func (t *Tracer) StartHTTPSpan(ctx context.Context, method, path string) (context.Context, trace.Span) {
spanName := fmt.Sprintf("%s %s", method, path)
// 检查是否已有错误标记,如果有则使用"error"作为操作名
// 这样可以匹配Jaeger采样配置中的错误操作策略
if ctx.Value("otel_error_request") != nil {
spanName = "error"
}
ctx, span := t.StartSpan(ctx, spanName,
trace.WithSpanKind(trace.SpanKindServer),
trace.WithAttributes(
attribute.String("http.method", method),
attribute.String("http.route", path),
),
)
// 保存原始操作名,以便在错误发生时可以更新
if ctx.Value("otel_error_request") == nil {
ctx = context.WithValue(ctx, "otel_original_operation", spanName)
}
return ctx, span
}
// StartDBSpan 开始一个数据库span
func (t *Tracer) StartDBSpan(ctx context.Context, operation, table string) (context.Context, trace.Span) {
spanName := fmt.Sprintf("db.%s.%s", operation, table)
return t.StartSpan(ctx, spanName,
trace.WithSpanKind(trace.SpanKindClient),
trace.WithAttributes(
attribute.String("db.operation", operation),
attribute.String("db.table", table),
attribute.String("db.system", "postgresql"),
),
)
}
// StartCacheSpan 开始一个缓存span
func (t *Tracer) StartCacheSpan(ctx context.Context, operation, key string) (context.Context, trace.Span) {
spanName := fmt.Sprintf("cache.%s", operation)
return t.StartSpan(ctx, spanName,
trace.WithSpanKind(trace.SpanKindClient),
trace.WithAttributes(
attribute.String("cache.operation", operation),
attribute.String("cache.system", "redis"),
),
)
}
// StartExternalAPISpan 开始一个外部API调用span
func (t *Tracer) StartExternalAPISpan(ctx context.Context, service, operation string) (context.Context, trace.Span) {
spanName := fmt.Sprintf("api.%s.%s", service, operation)
return t.StartSpan(ctx, spanName,
trace.WithSpanKind(trace.SpanKindClient),
trace.WithAttributes(
attribute.String("api.service", service),
attribute.String("api.operation", operation),
),
)
}
// AddSpanAttributes 添加span属性
func (t *Tracer) AddSpanAttributes(span trace.Span, attrs ...attribute.KeyValue) {
if span.IsRecording() {
span.SetAttributes(attrs...)
}
}
// SetSpanError 设置span错误
func (t *Tracer) SetSpanError(span trace.Span, err error) {
if span.IsRecording() {
span.SetStatus(codes.Error, err.Error())
span.RecordError(err)
// 将span操作名更新为"error"以匹配Jaeger采样配置
// 注意这是一种变通方法因为OpenTelemetry不支持直接更改span名称
// 我们通过添加特殊属性来标识这是一个错误span
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
// 记录错误日志包含trace ID便于关联
if t.logger != nil {
ctx := trace.ContextWithSpan(context.Background(), span)
t.logger.Error("操作发生错误",
zap.Error(err),
zap.String("trace_id", t.GetTraceID(ctx)),
zap.String("span_id", t.GetSpanID(ctx)),
)
}
}
}
// SetSpanSuccess 设置span成功
func (t *Tracer) SetSpanSuccess(span trace.Span) {
if span.IsRecording() {
span.SetStatus(codes.Ok, "success")
}
}
// SetHTTPStatus 根据HTTP状态码设置span状态
func (t *Tracer) SetHTTPStatus(span trace.Span, statusCode int) {
if !span.IsRecording() {
return
}
// 添加HTTP状态码属性
span.SetAttributes(attribute.Int("http.status_code", statusCode))
// 对于4xx和5xx错误标记为错误并应用错误采样策略
if statusCode >= 400 {
errorMsg := fmt.Sprintf("HTTP %d", statusCode)
span.SetStatus(codes.Error, errorMsg)
// 添加错误操作标记以匹配Jaeger采样配置
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
// 记录HTTP错误
if t.logger != nil {
ctx := trace.ContextWithSpan(context.Background(), span)
t.logger.Warn("HTTP请求错误",
zap.Int("status_code", statusCode),
zap.String("trace_id", t.GetTraceID(ctx)),
zap.String("span_id", t.GetSpanID(ctx)),
)
}
} else {
span.SetStatus(codes.Ok, "success")
}
}
// GetTraceID 获取当前上下文的trace ID
func (t *Tracer) GetTraceID(ctx context.Context) string {
span := trace.SpanFromContext(ctx)
if span.SpanContext().IsValid() {
return span.SpanContext().TraceID().String()
}
return ""
}
// GetSpanID 获取当前上下文的span ID
func (t *Tracer) GetSpanID(ctx context.Context) string {
span := trace.SpanFromContext(ctx)
if span.SpanContext().IsValid() {
return span.SpanContext().SpanID().String()
}
return ""
}
// IsTracing 检查是否正在追踪
func (t *Tracer) IsTracing(ctx context.Context) bool {
span := trace.SpanFromContext(ctx)
return span.SpanContext().IsValid() && span.IsRecording()
}
// Shutdown 关闭追踪器
func (t *Tracer) Shutdown(ctx context.Context) error {
t.mutex.Lock()
defer t.mutex.Unlock()
if !t.initialized || t.shutdown == nil {
return nil
}
err := t.shutdown(ctx)
if err != nil {
t.logger.Error("Failed to shutdown tracer", zap.Error(err))
return err
}
t.initialized = false
t.logger.Info("Tracer shutdown successfully")
return nil
}
// GetStats 获取追踪统计信息
func (t *Tracer) GetStats() map[string]interface{} {
t.mutex.RLock()
defer t.mutex.RUnlock()
return map[string]interface{}{
"initialized": t.initialized,
"enabled": t.config.Enabled,
"service_name": t.config.ServiceName,
"service_version": t.config.ServiceVersion,
"environment": t.config.Environment,
"sample_rate": t.config.SampleRate,
"endpoint": t.config.Endpoint,
}
}
// 实现Service接口
// Name 返回服务名称
func (t *Tracer) Name() string {
return "tracer"
}
// HealthCheck 健康检查
func (t *Tracer) HealthCheck(ctx context.Context) error {
if !t.config.Enabled {
return nil
}
if !t.initialized {
return fmt.Errorf("tracer not initialized")
}
return nil
}
// noopExporter 简单的无操作导出器(用于演示)
type noopExporter struct{}
func (e *noopExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error {
// 在实际应用中这里应该将spans发送到Jaeger或其他追踪系统
return nil
}
func (e *noopExporter) Shutdown(ctx context.Context) error {
return nil
}
// TraceMiddleware 追踪中间件工厂
func (t *Tracer) TraceMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
if !t.initialized || !t.config.Enabled {
c.Next()
return
}
// 开始HTTP span
ctx, span := t.StartHTTPSpan(c.Request.Context(), c.Request.Method, c.FullPath())
defer span.End()
// 将trace ID添加到响应头
traceID := t.GetTraceID(ctx)
if traceID != "" {
c.Header("X-Trace-ID", traceID)
}
// 将span上下文存储到gin上下文
c.Request = c.Request.WithContext(ctx)
// 处理请求
c.Next()
// 设置HTTP状态码
t.SetHTTPStatus(span, c.Writer.Status())
// 添加响应信息
t.AddSpanAttributes(span,
attribute.Int("http.status_code", c.Writer.Status()),
attribute.Int("http.response_size", c.Writer.Size()),
)
// 添加错误信息
if len(c.Errors) > 0 {
errMsg := c.Errors.String()
t.SetSpanError(span, fmt.Errorf(errMsg))
}
}
}
// GinTraceMiddleware 兼容旧的方法名,保持向后兼容
func (t *Tracer) GinTraceMiddleware() gin.HandlerFunc {
return t.TraceMiddleware()
}
// WithTracing 添加追踪到上下文的辅助函数
func WithTracing(ctx context.Context, tracer *Tracer, name string) (context.Context, trace.Span) {
return tracer.StartSpan(ctx, name)
}
// TraceFunction 追踪函数执行的辅助函数
func (t *Tracer) TraceFunction(ctx context.Context, name string, fn func(context.Context) error) error {
ctx, span := t.StartSpan(ctx, name)
defer span.End()
err := fn(ctx)
if err != nil {
t.SetSpanError(span, err)
} else {
t.SetSpanSuccess(span)
}
return err
}
// TraceFunctionWithResult 追踪带返回值的函数执行
func TraceFunctionWithResult[T any](ctx context.Context, tracer *Tracer, name string, fn func(context.Context) (T, error)) (T, error) {
ctx, span := tracer.StartSpan(ctx, name)
defer span.End()
result, err := fn(ctx)
if err != nil {
tracer.SetSpanError(span, err)
} else {
tracer.SetSpanSuccess(span)
}
return result, err
}