Initial commit: Basic project structure and dependencies
This commit is contained in:
284
internal/shared/cache/redis_cache.go
vendored
Normal file
284
internal/shared/cache/redis_cache.go
vendored
Normal file
@@ -0,0 +1,284 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// RedisCache Redis缓存实现
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
logger *zap.Logger
|
||||
prefix string
|
||||
|
||||
// 统计信息
|
||||
hits int64
|
||||
misses int64
|
||||
}
|
||||
|
||||
// NewRedisCache 创建Redis缓存实例
|
||||
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
logger: logger,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (r *RedisCache) Name() string {
|
||||
return "redis-cache"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (r *RedisCache) Initialize(ctx context.Context) error {
|
||||
// 测试连接
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to connect to Redis", zap.Error(err))
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Redis cache service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (r *RedisCache) HealthCheck(ctx context.Context) error {
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (r *RedisCache) Shutdown(ctx context.Context) error {
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
val, err := r.client.Get(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
r.misses++
|
||||
return fmt.Errorf("cache miss: key %s not found", key)
|
||||
}
|
||||
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.hits++
|
||||
return json.Unmarshal([]byte(val), dest)
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value: %w", err)
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
case string:
|
||||
expiration, _ = time.ParseDuration(v)
|
||||
default:
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
|
||||
err = r.client.Set(ctx, fullKey, data, expiration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
err := r.client.Del(ctx, fullKeys...).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
count, err := r.client.Exists(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取
|
||||
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
values, err := r.client.MGet(ctx, fullKeys...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for i, val := range values {
|
||||
if val != nil {
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(val.(string)), &data); err == nil {
|
||||
result[keys[i]] = data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置
|
||||
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
default:
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
for key, value := range data {
|
||||
fullKey := r.getFullKey(key)
|
||||
jsonData, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pipe.Set(ctx, fullKey, jsonData, expiration)
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeletePattern 按模式删除
|
||||
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
return r.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys 获取匹配的键
|
||||
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 移除前缀
|
||||
result := make([]string, len(keys))
|
||||
prefixLen := len(r.prefix) + 1 // +1 for ":"
|
||||
for i, key := range keys {
|
||||
if len(key) > prefixLen {
|
||||
result[i] = key[prefixLen:]
|
||||
} else {
|
||||
result[i] = key
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Stats 获取缓存统计
|
||||
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
|
||||
dbSize, _ := r.client.DBSize(ctx).Result()
|
||||
|
||||
return interfaces.CacheStats{
|
||||
Hits: r.hits,
|
||||
Misses: r.misses,
|
||||
Keys: dbSize,
|
||||
Memory: 0, // 暂时设为0,后续可解析Redis info
|
||||
Connections: 0, // 暂时设为0,后续可解析Redis info
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getFullKey 获取完整键名
|
||||
func (r *RedisCache) getFullKey(key string) string {
|
||||
if r.prefix == "" {
|
||||
return key
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", r.prefix, key)
|
||||
}
|
||||
|
||||
// Flush 清空所有缓存
|
||||
func (r *RedisCache) Flush(ctx context.Context) error {
|
||||
if r.prefix == "" {
|
||||
return r.client.FlushDB(ctx).Err()
|
||||
}
|
||||
|
||||
// 只删除带前缀的键
|
||||
return r.DeletePattern(ctx, "*")
|
||||
}
|
||||
|
||||
// GetClient 获取原始Redis客户端
|
||||
func (r *RedisCache) GetClient() *redis.Client {
|
||||
return r.client
|
||||
}
|
||||
195
internal/shared/database/database.go
Normal file
195
internal/shared/database/database.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Config 数据库配置
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
Name string
|
||||
SSLMode string
|
||||
Timezone string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
}
|
||||
|
||||
// DB 数据库包装器
|
||||
type DB struct {
|
||||
*gorm.DB
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewConnection 创建新的数据库连接
|
||||
func NewConnection(config Config) (*DB, error) {
|
||||
// 构建DSN
|
||||
dsn := buildDSN(config)
|
||||
|
||||
// 配置GORM
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
SingularTable: true, // 使用单数表名
|
||||
},
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取底层sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
||||
}
|
||||
|
||||
return &DB{
|
||||
DB: db,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildDSN 构建数据库连接字符串
|
||||
func buildDSN(config Config) string {
|
||||
return fmt.Sprintf(
|
||||
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s",
|
||||
config.Host,
|
||||
config.User,
|
||||
config.Password,
|
||||
config.Name,
|
||||
config.Port,
|
||||
config.SSLMode,
|
||||
config.Timezone,
|
||||
)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
// Ping 检查数据库连接
|
||||
func (db *DB) Ping() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Ping()
|
||||
}
|
||||
|
||||
// GetStats 获取连接池统计信息
|
||||
func (db *DB) GetStats() (map[string]interface{}, error) {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := sqlDB.Stats()
|
||||
return map[string]interface{}{
|
||||
"max_open_connections": stats.MaxOpenConnections,
|
||||
"open_connections": stats.OpenConnections,
|
||||
"in_use": stats.InUse,
|
||||
"idle": stats.Idle,
|
||||
"wait_count": stats.WaitCount,
|
||||
"wait_duration": stats.WaitDuration,
|
||||
"max_idle_closed": stats.MaxIdleClosed,
|
||||
"max_idle_time_closed": stats.MaxIdleTimeClosed,
|
||||
"max_lifetime_closed": stats.MaxLifetimeClosed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BeginTx 开始事务
|
||||
func (db *DB) BeginTx() *gorm.DB {
|
||||
return db.DB.Begin()
|
||||
}
|
||||
|
||||
// Migrate 执行数据库迁移
|
||||
func (db *DB) Migrate(models ...interface{}) error {
|
||||
return db.DB.AutoMigrate(models...)
|
||||
}
|
||||
|
||||
// IsHealthy 检查数据库健康状态
|
||||
func (db *DB) IsHealthy() bool {
|
||||
return db.Ping() == nil
|
||||
}
|
||||
|
||||
// WithContext 返回带上下文的数据库实例
|
||||
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
|
||||
if c, ok := ctx.(context.Context); ok {
|
||||
return db.DB.WithContext(c)
|
||||
}
|
||||
return db.DB
|
||||
}
|
||||
|
||||
// 事务包装器
|
||||
type TxWrapper struct {
|
||||
tx *gorm.DB
|
||||
}
|
||||
|
||||
// NewTxWrapper 创建事务包装器
|
||||
func (db *DB) NewTxWrapper() *TxWrapper {
|
||||
return &TxWrapper{
|
||||
tx: db.BeginTx(),
|
||||
}
|
||||
}
|
||||
|
||||
// Commit 提交事务
|
||||
func (tx *TxWrapper) Commit() error {
|
||||
return tx.tx.Commit().Error
|
||||
}
|
||||
|
||||
// Rollback 回滚事务
|
||||
func (tx *TxWrapper) Rollback() error {
|
||||
return tx.tx.Rollback().Error
|
||||
}
|
||||
|
||||
// GetDB 获取事务数据库实例
|
||||
func (tx *TxWrapper) GetDB() *gorm.DB {
|
||||
return tx.tx
|
||||
}
|
||||
|
||||
// WithTx 在事务中执行函数
|
||||
func (db *DB) WithTx(fn func(*gorm.DB) error) error {
|
||||
tx := db.BeginTx()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
313
internal/shared/events/event_bus.go
Normal file
313
internal/shared/events/event_bus.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// MemoryEventBus 内存事件总线实现
|
||||
type MemoryEventBus struct {
|
||||
subscribers map[string][]interfaces.EventHandler
|
||||
mutex sync.RWMutex
|
||||
logger *zap.Logger
|
||||
running bool
|
||||
stopCh chan struct{}
|
||||
eventQueue chan eventTask
|
||||
workerCount int
|
||||
}
|
||||
|
||||
// eventTask 事件任务
|
||||
type eventTask struct {
|
||||
event interfaces.Event
|
||||
handler interfaces.EventHandler
|
||||
retries int
|
||||
}
|
||||
|
||||
// NewMemoryEventBus 创建内存事件总线
|
||||
func NewMemoryEventBus(logger *zap.Logger, workerCount int) *MemoryEventBus {
|
||||
if workerCount <= 0 {
|
||||
workerCount = 5 // 默认5个工作协程
|
||||
}
|
||||
|
||||
return &MemoryEventBus{
|
||||
subscribers: make(map[string][]interfaces.EventHandler),
|
||||
logger: logger,
|
||||
eventQueue: make(chan eventTask, 1000), // 缓冲1000个事件
|
||||
workerCount: workerCount,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (bus *MemoryEventBus) Name() string {
|
||||
return "memory-event-bus"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (bus *MemoryEventBus) Initialize(ctx context.Context) error {
|
||||
bus.logger.Info("Memory event bus service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (bus *MemoryEventBus) HealthCheck(ctx context.Context) error {
|
||||
if !bus.running {
|
||||
return fmt.Errorf("event bus is not running")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (bus *MemoryEventBus) Shutdown(ctx context.Context) error {
|
||||
bus.Stop(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动事件总线
|
||||
func (bus *MemoryEventBus) Start(ctx context.Context) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
if bus.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
bus.running = true
|
||||
|
||||
// 启动工作协程
|
||||
for i := 0; i < bus.workerCount; i++ {
|
||||
go bus.worker(i)
|
||||
}
|
||||
|
||||
bus.logger.Info("Event bus started", zap.Int("workers", bus.workerCount))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止事件总线
|
||||
func (bus *MemoryEventBus) Stop(ctx context.Context) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
if !bus.running {
|
||||
return nil
|
||||
}
|
||||
|
||||
bus.running = false
|
||||
close(bus.stopCh)
|
||||
|
||||
// 等待所有工作协程结束或超时
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
time.Sleep(5 * time.Second) // 给工作协程5秒时间结束
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
|
||||
bus.logger.Info("Event bus stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Publish 发布事件(同步)
|
||||
func (bus *MemoryEventBus) Publish(ctx context.Context, event interfaces.Event) error {
|
||||
bus.mutex.RLock()
|
||||
handlers := bus.subscribers[event.GetType()]
|
||||
bus.mutex.RUnlock()
|
||||
|
||||
if len(handlers) == 0 {
|
||||
bus.logger.Debug("No handlers for event type", zap.String("type", event.GetType()))
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, handler := range handlers {
|
||||
if handler.IsAsync() {
|
||||
// 异步处理
|
||||
select {
|
||||
case bus.eventQueue <- eventTask{event: event, handler: handler, retries: 0}:
|
||||
default:
|
||||
bus.logger.Warn("Event queue is full, dropping event",
|
||||
zap.String("type", event.GetType()),
|
||||
zap.String("handler", handler.GetName()))
|
||||
}
|
||||
} else {
|
||||
// 同步处理
|
||||
if err := bus.handleEventWithRetry(ctx, event, handler); err != nil {
|
||||
bus.logger.Error("Failed to handle event synchronously",
|
||||
zap.String("type", event.GetType()),
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishBatch 批量发布事件
|
||||
func (bus *MemoryEventBus) PublishBatch(ctx context.Context, events []interfaces.Event) error {
|
||||
for _, event := range events {
|
||||
if err := bus.Publish(ctx, event); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe 订阅事件
|
||||
func (bus *MemoryEventBus) Subscribe(eventType string, handler interfaces.EventHandler) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
handlers := bus.subscribers[eventType]
|
||||
|
||||
// 检查是否已经订阅
|
||||
for _, h := range handlers {
|
||||
if h.GetName() == handler.GetName() {
|
||||
return fmt.Errorf("handler %s already subscribed to event type %s", handler.GetName(), eventType)
|
||||
}
|
||||
}
|
||||
|
||||
bus.subscribers[eventType] = append(handlers, handler)
|
||||
|
||||
bus.logger.Info("Handler subscribed to event",
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.String("event_type", eventType))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unsubscribe 取消订阅
|
||||
func (bus *MemoryEventBus) Unsubscribe(eventType string, handler interfaces.EventHandler) error {
|
||||
bus.mutex.Lock()
|
||||
defer bus.mutex.Unlock()
|
||||
|
||||
handlers := bus.subscribers[eventType]
|
||||
for i, h := range handlers {
|
||||
if h.GetName() == handler.GetName() {
|
||||
// 删除处理器
|
||||
bus.subscribers[eventType] = append(handlers[:i], handlers[i+1:]...)
|
||||
|
||||
bus.logger.Info("Handler unsubscribed from event",
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.String("event_type", eventType))
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("handler %s not found for event type %s", handler.GetName(), eventType)
|
||||
}
|
||||
|
||||
// GetSubscribers 获取订阅者
|
||||
func (bus *MemoryEventBus) GetSubscribers(eventType string) []interfaces.EventHandler {
|
||||
bus.mutex.RLock()
|
||||
defer bus.mutex.RUnlock()
|
||||
|
||||
handlers := bus.subscribers[eventType]
|
||||
result := make([]interfaces.EventHandler, len(handlers))
|
||||
copy(result, handlers)
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// worker 工作协程
|
||||
func (bus *MemoryEventBus) worker(id int) {
|
||||
bus.logger.Debug("Event worker started", zap.Int("worker_id", id))
|
||||
|
||||
for {
|
||||
select {
|
||||
case task := <-bus.eventQueue:
|
||||
bus.processEventTask(task)
|
||||
case <-bus.stopCh:
|
||||
bus.logger.Debug("Event worker stopped", zap.Int("worker_id", id))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// processEventTask 处理事件任务
|
||||
func (bus *MemoryEventBus) processEventTask(task eventTask) {
|
||||
ctx := context.Background()
|
||||
|
||||
err := bus.handleEventWithRetry(ctx, task.event, task.handler)
|
||||
if err != nil {
|
||||
retryConfig := task.handler.GetRetryConfig()
|
||||
|
||||
if task.retries < retryConfig.MaxRetries {
|
||||
// 重试
|
||||
delay := time.Duration(float64(retryConfig.RetryDelay) *
|
||||
(1 + retryConfig.BackoffFactor*float64(task.retries)))
|
||||
|
||||
if delay > retryConfig.MaxDelay {
|
||||
delay = retryConfig.MaxDelay
|
||||
}
|
||||
|
||||
go func() {
|
||||
time.Sleep(delay)
|
||||
task.retries++
|
||||
|
||||
select {
|
||||
case bus.eventQueue <- task:
|
||||
default:
|
||||
bus.logger.Error("Failed to requeue event for retry",
|
||||
zap.String("type", task.event.GetType()),
|
||||
zap.String("handler", task.handler.GetName()),
|
||||
zap.Int("retries", task.retries))
|
||||
}
|
||||
}()
|
||||
} else {
|
||||
bus.logger.Error("Event processing failed after max retries",
|
||||
zap.String("type", task.event.GetType()),
|
||||
zap.String("handler", task.handler.GetName()),
|
||||
zap.Int("retries", task.retries),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleEventWithRetry 处理事件并支持重试
|
||||
func (bus *MemoryEventBus) handleEventWithRetry(ctx context.Context, event interfaces.Event, handler interfaces.EventHandler) error {
|
||||
start := time.Now()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(start)
|
||||
bus.logger.Debug("Event handled",
|
||||
zap.String("type", event.GetType()),
|
||||
zap.String("handler", handler.GetName()),
|
||||
zap.Duration("duration", duration))
|
||||
}()
|
||||
|
||||
return handler.Handle(ctx, event)
|
||||
}
|
||||
|
||||
// GetStats 获取事件总线统计信息
|
||||
func (bus *MemoryEventBus) GetStats() map[string]interface{} {
|
||||
bus.mutex.RLock()
|
||||
defer bus.mutex.RUnlock()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"running": bus.running,
|
||||
"worker_count": bus.workerCount,
|
||||
"queue_length": len(bus.eventQueue),
|
||||
"queue_capacity": cap(bus.eventQueue),
|
||||
"event_types": len(bus.subscribers),
|
||||
}
|
||||
|
||||
// 各事件类型的订阅者数量
|
||||
eventTypes := make(map[string]int)
|
||||
for eventType, handlers := range bus.subscribers {
|
||||
eventTypes[eventType] = len(handlers)
|
||||
}
|
||||
stats["subscribers"] = eventTypes
|
||||
|
||||
return stats
|
||||
}
|
||||
282
internal/shared/health/health_checker.go
Normal file
282
internal/shared/health/health_checker.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package health
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HealthChecker 健康检查器实现
|
||||
type HealthChecker struct {
|
||||
services map[string]interfaces.Service
|
||||
cache map[string]*interfaces.HealthStatus
|
||||
cacheTTL time.Duration
|
||||
mutex sync.RWMutex
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewHealthChecker 创建健康检查器
|
||||
func NewHealthChecker(logger *zap.Logger) *HealthChecker {
|
||||
return &HealthChecker{
|
||||
services: make(map[string]interfaces.Service),
|
||||
cache: make(map[string]*interfaces.HealthStatus),
|
||||
cacheTTL: 30 * time.Second, // 缓存30秒
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterService 注册服务
|
||||
func (h *HealthChecker) RegisterService(service interfaces.Service) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.services[service.Name()] = service
|
||||
h.logger.Info("Registered service for health check", zap.String("service", service.Name()))
|
||||
}
|
||||
|
||||
// CheckHealth 检查单个服务健康状态
|
||||
func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *interfaces.HealthStatus {
|
||||
h.mutex.RLock()
|
||||
service, exists := h.services[serviceName]
|
||||
if !exists {
|
||||
h.mutex.RUnlock()
|
||||
return &interfaces.HealthStatus{
|
||||
Status: "DOWN",
|
||||
Message: "Service not found",
|
||||
Details: map[string]interface{}{"error": "service not registered"},
|
||||
CheckedAt: time.Now().Unix(),
|
||||
ResponseTime: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// 检查缓存
|
||||
if cached, exists := h.cache[serviceName]; exists {
|
||||
if time.Since(time.Unix(cached.CheckedAt, 0)) < h.cacheTTL {
|
||||
h.mutex.RUnlock()
|
||||
return cached
|
||||
}
|
||||
}
|
||||
h.mutex.RUnlock()
|
||||
|
||||
// 执行健康检查
|
||||
start := time.Now()
|
||||
status := &interfaces.HealthStatus{
|
||||
CheckedAt: start.Unix(),
|
||||
}
|
||||
|
||||
// 设置超时上下文
|
||||
checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := service.HealthCheck(checkCtx)
|
||||
responseTime := time.Since(start).Milliseconds()
|
||||
status.ResponseTime = responseTime
|
||||
|
||||
if err != nil {
|
||||
status.Status = "DOWN"
|
||||
status.Message = "Health check failed"
|
||||
status.Details = map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"service_name": serviceName,
|
||||
"check_time": start.Format(time.RFC3339),
|
||||
}
|
||||
h.logger.Warn("Service health check failed",
|
||||
zap.String("service", serviceName),
|
||||
zap.Error(err),
|
||||
zap.Int64("response_time_ms", responseTime))
|
||||
} else {
|
||||
status.Status = "UP"
|
||||
status.Message = "Service is healthy"
|
||||
status.Details = map[string]interface{}{
|
||||
"service_name": serviceName,
|
||||
"check_time": start.Format(time.RFC3339),
|
||||
}
|
||||
h.logger.Debug("Service health check passed",
|
||||
zap.String("service", serviceName),
|
||||
zap.Int64("response_time_ms", responseTime))
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
h.mutex.Lock()
|
||||
h.cache[serviceName] = status
|
||||
h.mutex.Unlock()
|
||||
|
||||
return status
|
||||
}
|
||||
|
||||
// CheckAllHealth 检查所有服务的健康状态
|
||||
func (h *HealthChecker) CheckAllHealth(ctx context.Context) map[string]*interfaces.HealthStatus {
|
||||
h.mutex.RLock()
|
||||
serviceNames := make([]string, 0, len(h.services))
|
||||
for name := range h.services {
|
||||
serviceNames = append(serviceNames, name)
|
||||
}
|
||||
h.mutex.RUnlock()
|
||||
|
||||
results := make(map[string]*interfaces.HealthStatus)
|
||||
var wg sync.WaitGroup
|
||||
var mutex sync.Mutex
|
||||
|
||||
// 并发检查所有服务
|
||||
for _, serviceName := range serviceNames {
|
||||
wg.Add(1)
|
||||
go func(name string) {
|
||||
defer wg.Done()
|
||||
status := h.CheckHealth(ctx, name)
|
||||
|
||||
mutex.Lock()
|
||||
results[name] = status
|
||||
mutex.Unlock()
|
||||
}(serviceName)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
return results
|
||||
}
|
||||
|
||||
// GetOverallStatus 获取整体健康状态
|
||||
func (h *HealthChecker) GetOverallStatus(ctx context.Context) *interfaces.HealthStatus {
|
||||
allStatus := h.CheckAllHealth(ctx)
|
||||
|
||||
overall := &interfaces.HealthStatus{
|
||||
CheckedAt: time.Now().Unix(),
|
||||
ResponseTime: 0,
|
||||
Details: make(map[string]interface{}),
|
||||
}
|
||||
|
||||
var totalResponseTime int64
|
||||
healthyCount := 0
|
||||
totalCount := len(allStatus)
|
||||
|
||||
for serviceName, status := range allStatus {
|
||||
overall.Details[serviceName] = map[string]interface{}{
|
||||
"status": status.Status,
|
||||
"message": status.Message,
|
||||
"response_time": status.ResponseTime,
|
||||
}
|
||||
|
||||
totalResponseTime += status.ResponseTime
|
||||
if status.Status == "UP" {
|
||||
healthyCount++
|
||||
}
|
||||
}
|
||||
|
||||
if totalCount > 0 {
|
||||
overall.ResponseTime = totalResponseTime / int64(totalCount)
|
||||
}
|
||||
|
||||
// 确定整体状态
|
||||
if healthyCount == totalCount {
|
||||
overall.Status = "UP"
|
||||
overall.Message = "All services are healthy"
|
||||
} else if healthyCount == 0 {
|
||||
overall.Status = "DOWN"
|
||||
overall.Message = "All services are down"
|
||||
} else {
|
||||
overall.Status = "DEGRADED"
|
||||
overall.Message = fmt.Sprintf("%d of %d services are healthy", healthyCount, totalCount)
|
||||
}
|
||||
|
||||
return overall
|
||||
}
|
||||
|
||||
// GetServiceNames 获取所有注册的服务名称
|
||||
func (h *HealthChecker) GetServiceNames() []string {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
names := make([]string, 0, len(h.services))
|
||||
for name := range h.services {
|
||||
names = append(names, name)
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
// RemoveService 移除服务
|
||||
func (h *HealthChecker) RemoveService(serviceName string) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
delete(h.services, serviceName)
|
||||
delete(h.cache, serviceName)
|
||||
|
||||
h.logger.Info("Removed service from health check", zap.String("service", serviceName))
|
||||
}
|
||||
|
||||
// ClearCache 清除缓存
|
||||
func (h *HealthChecker) ClearCache() {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.cache = make(map[string]*interfaces.HealthStatus)
|
||||
h.logger.Debug("Health check cache cleared")
|
||||
}
|
||||
|
||||
// GetCacheStats 获取缓存统计
|
||||
func (h *HealthChecker) GetCacheStats() map[string]interface{} {
|
||||
h.mutex.RLock()
|
||||
defer h.mutex.RUnlock()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_services": len(h.services),
|
||||
"cached_results": len(h.cache),
|
||||
"cache_ttl_seconds": h.cacheTTL.Seconds(),
|
||||
}
|
||||
|
||||
// 计算缓存命中率
|
||||
if len(h.services) > 0 {
|
||||
hitRate := float64(len(h.cache)) / float64(len(h.services)) * 100
|
||||
stats["cache_hit_rate"] = fmt.Sprintf("%.2f%%", hitRate)
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// SetCacheTTL 设置缓存TTL
|
||||
func (h *HealthChecker) SetCacheTTL(ttl time.Duration) {
|
||||
h.mutex.Lock()
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.cacheTTL = ttl
|
||||
h.logger.Info("Updated health check cache TTL", zap.Duration("ttl", ttl))
|
||||
}
|
||||
|
||||
// StartPeriodicCheck 启动定期健康检查
|
||||
func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Duration) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
h.logger.Info("Started periodic health check", zap.Duration("interval", interval))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.logger.Info("Stopped periodic health check")
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.performPeriodicCheck(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// performPeriodicCheck 执行定期检查
|
||||
func (h *HealthChecker) performPeriodicCheck(ctx context.Context) {
|
||||
overall := h.GetOverallStatus(ctx)
|
||||
|
||||
h.logger.Info("Periodic health check completed",
|
||||
zap.String("overall_status", overall.Status),
|
||||
zap.String("message", overall.Message),
|
||||
zap.Int64("response_time_ms", overall.ResponseTime))
|
||||
|
||||
// 如果有服务下线,记录警告
|
||||
if overall.Status != "UP" {
|
||||
h.logger.Warn("Some services are not healthy",
|
||||
zap.String("status", overall.Status),
|
||||
zap.Any("details", overall.Details))
|
||||
}
|
||||
}
|
||||
260
internal/shared/http/response.go
Normal file
260
internal/shared/http/response.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ResponseBuilder 响应构建器实现
|
||||
type ResponseBuilder struct{}
|
||||
|
||||
// NewResponseBuilder 创建响应构建器
|
||||
func NewResponseBuilder() interfaces.ResponseBuilder {
|
||||
return &ResponseBuilder{}
|
||||
}
|
||||
|
||||
// Success 成功响应
|
||||
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Success"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// Created 创建成功响应
|
||||
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Created successfully"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: msg,
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, response)
|
||||
}
|
||||
|
||||
// Error 错误响应
|
||||
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
|
||||
// 根据错误类型确定状态码
|
||||
statusCode := http.StatusInternalServerError
|
||||
message := "Internal server error"
|
||||
errorDetail := err.Error()
|
||||
|
||||
// 这里可以根据不同的错误类型设置不同的状态码
|
||||
// 例如:ValidationError -> 400, NotFoundError -> 404, etc.
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
Errors: errorDetail,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// BadRequest 400错误响应
|
||||
func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
if len(errors) > 0 {
|
||||
response.Errors = errors[0]
|
||||
}
|
||||
|
||||
c.JSON(http.StatusBadRequest, response)
|
||||
}
|
||||
|
||||
// Unauthorized 401错误响应
|
||||
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
|
||||
msg := "Unauthorized"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnauthorized, response)
|
||||
}
|
||||
|
||||
// Forbidden 403错误响应
|
||||
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
|
||||
msg := "Forbidden"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusForbidden, response)
|
||||
}
|
||||
|
||||
// NotFound 404错误响应
|
||||
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
|
||||
msg := "Resource not found"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusNotFound, response)
|
||||
}
|
||||
|
||||
// Conflict 409错误响应
|
||||
func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: message,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusConflict, response)
|
||||
}
|
||||
|
||||
// InternalError 500错误响应
|
||||
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
|
||||
msg := "Internal server error"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, response)
|
||||
}
|
||||
|
||||
// Paginated 分页响应
|
||||
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: "Success",
|
||||
Data: data,
|
||||
Pagination: &pagination,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
// getRequestID 从上下文获取请求ID
|
||||
func (r *ResponseBuilder) getRequestID(c *gin.Context) string {
|
||||
if requestID, exists := c.Get("request_id"); exists {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// BuildPagination 构建分页元数据
|
||||
func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
|
||||
totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
|
||||
|
||||
if totalPages < 1 {
|
||||
totalPages = 1
|
||||
}
|
||||
|
||||
return interfaces.PaginationMeta{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Total: total,
|
||||
TotalPages: totalPages,
|
||||
HasNext: page < totalPages,
|
||||
HasPrev: page > 1,
|
||||
}
|
||||
}
|
||||
|
||||
// CustomResponse 自定义响应
|
||||
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: statusCode >= 200 && statusCode < 300,
|
||||
Message: http.StatusText(statusCode),
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(statusCode, response)
|
||||
}
|
||||
|
||||
// ValidationError 验证错误响应
|
||||
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: "Validation failed",
|
||||
Errors: errors,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
|
||||
c.JSON(http.StatusUnprocessableEntity, response)
|
||||
}
|
||||
|
||||
// TooManyRequests 限流错误响应
|
||||
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
|
||||
msg := "Too many requests"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: msg,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
Meta: map[string]interface{}{
|
||||
"retry_after": "60s",
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, response)
|
||||
}
|
||||
258
internal/shared/http/router.go
Normal file
258
internal/shared/http/router.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GinRouter Gin路由器实现
|
||||
type GinRouter struct {
|
||||
engine *gin.Engine
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
middlewares []interfaces.Middleware
|
||||
server *http.Server
|
||||
}
|
||||
|
||||
// NewGinRouter 创建Gin路由器
|
||||
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *GinRouter {
|
||||
// 设置Gin模式
|
||||
if cfg.App.IsProduction() {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
} else {
|
||||
gin.SetMode(gin.DebugMode)
|
||||
}
|
||||
|
||||
// 创建Gin引擎
|
||||
engine := gin.New()
|
||||
|
||||
return &GinRouter{
|
||||
engine: engine,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
middlewares: make([]interfaces.Middleware, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册处理器
|
||||
func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
|
||||
// 应用处理器中间件
|
||||
middlewares := handler.GetMiddlewares()
|
||||
|
||||
// 注册路由
|
||||
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
|
||||
|
||||
r.logger.Info("Registered HTTP handler",
|
||||
zap.String("method", handler.GetMethod()),
|
||||
zap.String("path", handler.GetPath()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterMiddleware 注册中间件
|
||||
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
|
||||
r.middlewares = append(r.middlewares, middleware)
|
||||
|
||||
r.logger.Info("Registered middleware",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterGroup 注册路由组
|
||||
func (r *GinRouter) RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes {
|
||||
return r.engine.Group(prefix, middlewares...)
|
||||
}
|
||||
|
||||
// GetRoutes 获取路由信息
|
||||
func (r *GinRouter) GetRoutes() gin.RoutesInfo {
|
||||
return r.engine.Routes()
|
||||
}
|
||||
|
||||
// Start 启动路由器
|
||||
func (r *GinRouter) Start(addr string) error {
|
||||
// 应用中间件(按优先级排序)
|
||||
r.applyMiddlewares()
|
||||
|
||||
// 创建HTTP服务器
|
||||
r.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: r.engine,
|
||||
ReadTimeout: r.config.Server.ReadTimeout,
|
||||
WriteTimeout: r.config.Server.WriteTimeout,
|
||||
IdleTimeout: r.config.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
|
||||
|
||||
// 启动服务器
|
||||
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止路由器
|
||||
func (r *GinRouter) Stop(ctx context.Context) error {
|
||||
if r.server == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Info("Stopping HTTP server...")
|
||||
|
||||
// 优雅关闭服务器
|
||||
if err := r.server.Shutdown(ctx); err != nil {
|
||||
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("HTTP server stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEngine 获取Gin引擎
|
||||
func (r *GinRouter) GetEngine() *gin.Engine {
|
||||
return r.engine
|
||||
}
|
||||
|
||||
// applyMiddlewares 应用中间件
|
||||
func (r *GinRouter) applyMiddlewares() {
|
||||
// 按优先级排序中间件
|
||||
sort.Slice(r.middlewares, func(i, j int) bool {
|
||||
return r.middlewares[i].GetPriority() > r.middlewares[j].GetPriority()
|
||||
})
|
||||
|
||||
// 应用全局中间件
|
||||
for _, middleware := range r.middlewares {
|
||||
if middleware.IsGlobal() {
|
||||
r.engine.Use(middleware.Handle())
|
||||
r.logger.Debug("Applied global middleware",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetupDefaultRoutes 设置默认路由
|
||||
func (r *GinRouter) SetupDefaultRoutes() {
|
||||
// 健康检查
|
||||
r.engine.GET("/health", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"service": r.config.App.Name,
|
||||
"version": r.config.App.Version,
|
||||
})
|
||||
})
|
||||
|
||||
// API信息
|
||||
r.engine.GET("/info", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": r.config.App.Name,
|
||||
"version": r.config.App.Version,
|
||||
"environment": r.config.App.Env,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// 404处理
|
||||
r.engine.NoRoute(func(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Route not found",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
|
||||
// 405处理
|
||||
r.engine.NoMethod(func(c *gin.Context) {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"success": false,
|
||||
"message": "Method not allowed",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// PrintRoutes 打印路由信息
|
||||
func (r *GinRouter) PrintRoutes() {
|
||||
routes := r.GetRoutes()
|
||||
|
||||
r.logger.Info("Registered routes:")
|
||||
for _, route := range routes {
|
||||
r.logger.Info("Route",
|
||||
zap.String("method", route.Method),
|
||||
zap.String("path", route.Path),
|
||||
zap.String("handler", route.Handler))
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取路由器统计信息
|
||||
func (r *GinRouter) GetStats() map[string]interface{} {
|
||||
routes := r.GetRoutes()
|
||||
|
||||
stats := map[string]interface{}{
|
||||
"total_routes": len(routes),
|
||||
"total_middlewares": len(r.middlewares),
|
||||
"server_config": map[string]interface{}{
|
||||
"read_timeout": r.config.Server.ReadTimeout,
|
||||
"write_timeout": r.config.Server.WriteTimeout,
|
||||
"idle_timeout": r.config.Server.IdleTimeout,
|
||||
},
|
||||
}
|
||||
|
||||
// 按方法统计路由数量
|
||||
methodStats := make(map[string]int)
|
||||
for _, route := range routes {
|
||||
methodStats[route.Method]++
|
||||
}
|
||||
stats["routes_by_method"] = methodStats
|
||||
|
||||
// 中间件统计
|
||||
middlewareStats := make([]map[string]interface{}, 0, len(r.middlewares))
|
||||
for _, middleware := range r.middlewares {
|
||||
middlewareStats = append(middlewareStats, map[string]interface{}{
|
||||
"name": middleware.GetName(),
|
||||
"priority": middleware.GetPriority(),
|
||||
"global": middleware.IsGlobal(),
|
||||
})
|
||||
}
|
||||
stats["middlewares"] = middlewareStats
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// EnableMetrics 启用指标收集
|
||||
func (r *GinRouter) EnableMetrics(collector interfaces.MetricsCollector) {
|
||||
r.engine.Use(func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
duration := time.Since(start).Seconds()
|
||||
collector.RecordHTTPRequest(c.Request.Method, c.FullPath(), c.Writer.Status(), duration)
|
||||
})
|
||||
}
|
||||
|
||||
// EnableProfiling 启用性能分析
|
||||
func (r *GinRouter) EnableProfiling() {
|
||||
if r.config.Development.EnableProfiler {
|
||||
// 这里可以集成pprof
|
||||
r.logger.Info("Profiling enabled")
|
||||
}
|
||||
}
|
||||
273
internal/shared/http/validator.go
Normal file
273
internal/shared/http/validator.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
// RequestValidator 请求验证器实现
|
||||
type RequestValidator struct {
|
||||
validator *validator.Validate
|
||||
response interfaces.ResponseBuilder
|
||||
}
|
||||
|
||||
// NewRequestValidator 创建请求验证器
|
||||
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
|
||||
v := validator.New()
|
||||
|
||||
// 注册自定义验证器
|
||||
registerCustomValidators(v)
|
||||
|
||||
return &RequestValidator{
|
||||
validator: v,
|
||||
response: response,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate 验证请求体
|
||||
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateQuery 验证查询参数
|
||||
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindQuery(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid query parameters", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateParam 验证路径参数
|
||||
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindUri(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid path parameters", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindAndValidate 绑定并验证请求
|
||||
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
|
||||
// 绑定请求体
|
||||
if err := c.ShouldBindJSON(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid request body", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证数据
|
||||
return v.Validate(c, dto)
|
||||
}
|
||||
|
||||
// formatValidationErrors 格式化验证错误
|
||||
func (v *RequestValidator) formatValidationErrors(err error) map[string][]string {
|
||||
errors := make(map[string][]string)
|
||||
|
||||
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
||||
for _, fieldError := range validationErrors {
|
||||
fieldName := v.getFieldName(fieldError)
|
||||
errorMessage := v.getErrorMessage(fieldError)
|
||||
|
||||
if _, exists := errors[fieldName]; !exists {
|
||||
errors[fieldName] = []string{}
|
||||
}
|
||||
errors[fieldName] = append(errors[fieldName], errorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// getFieldName 获取字段名(JSON标签优先)
|
||||
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
|
||||
// 可以通过反射获取JSON标签,这里简化处理
|
||||
fieldName := fieldError.Field()
|
||||
|
||||
// 转换为snake_case(可选)
|
||||
return v.toSnakeCase(fieldName)
|
||||
}
|
||||
|
||||
// getErrorMessage 获取错误消息
|
||||
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
|
||||
field := fieldError.Field()
|
||||
tag := fieldError.Tag()
|
||||
param := fieldError.Param()
|
||||
|
||||
switch tag {
|
||||
case "required":
|
||||
return fmt.Sprintf("%s is required", field)
|
||||
case "email":
|
||||
return fmt.Sprintf("%s must be a valid email address", field)
|
||||
case "min":
|
||||
return fmt.Sprintf("%s must be at least %s characters", field, param)
|
||||
case "max":
|
||||
return fmt.Sprintf("%s must be at most %s characters", field, param)
|
||||
case "len":
|
||||
return fmt.Sprintf("%s must be exactly %s characters", field, param)
|
||||
case "gt":
|
||||
return fmt.Sprintf("%s must be greater than %s", field, param)
|
||||
case "gte":
|
||||
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
|
||||
case "lt":
|
||||
return fmt.Sprintf("%s must be less than %s", field, param)
|
||||
case "lte":
|
||||
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
|
||||
case "oneof":
|
||||
return fmt.Sprintf("%s must be one of [%s]", field, param)
|
||||
case "url":
|
||||
return fmt.Sprintf("%s must be a valid URL", field)
|
||||
case "alpha":
|
||||
return fmt.Sprintf("%s must contain only alphabetic characters", field)
|
||||
case "alphanum":
|
||||
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
|
||||
case "numeric":
|
||||
return fmt.Sprintf("%s must be numeric", field)
|
||||
case "phone":
|
||||
return fmt.Sprintf("%s must be a valid phone number", field)
|
||||
case "username":
|
||||
return fmt.Sprintf("%s must be a valid username", field)
|
||||
default:
|
||||
return fmt.Sprintf("%s is invalid", field)
|
||||
}
|
||||
}
|
||||
|
||||
// toSnakeCase 转换为snake_case
|
||||
func (v *RequestValidator) toSnakeCase(str string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range str {
|
||||
if i > 0 && (r >= 'A' && r <= 'Z') {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// registerCustomValidators 注册自定义验证器
|
||||
func registerCustomValidators(v *validator.Validate) {
|
||||
// 注册手机号验证器
|
||||
v.RegisterValidation("phone", validatePhone)
|
||||
|
||||
// 注册用户名验证器
|
||||
v.RegisterValidation("username", validateUsername)
|
||||
|
||||
// 注册密码强度验证器
|
||||
v.RegisterValidation("strong_password", validateStrongPassword)
|
||||
}
|
||||
|
||||
// validatePhone 验证手机号
|
||||
func validatePhone(fl validator.FieldLevel) bool {
|
||||
phone := fl.Field().String()
|
||||
if phone == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 简单的手机号验证(可根据需要完善)
|
||||
if len(phone) < 10 || len(phone) > 15 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否以+开头或全是数字
|
||||
if strings.HasPrefix(phone, "+") {
|
||||
phone = phone[1:]
|
||||
}
|
||||
|
||||
for _, r := range phone {
|
||||
if r < '0' || r > '9' {
|
||||
if r != '-' && r != ' ' && r != '(' && r != ')' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateUsername 验证用户名
|
||||
func validateUsername(fl validator.FieldLevel) bool {
|
||||
username := fl.Field().String()
|
||||
if username == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 用户名规则:3-30个字符,只能包含字母、数字、下划线,不能以数字开头
|
||||
if len(username) < 3 || len(username) > 30 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不能以数字开头
|
||||
if username[0] >= '0' && username[0] <= '9' {
|
||||
return false
|
||||
}
|
||||
|
||||
// 只能包含字母、数字、下划线
|
||||
for _, r := range username {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateStrongPassword 验证密码强度
|
||||
func validateStrongPassword(fl validator.FieldLevel) bool {
|
||||
password := fl.Field().String()
|
||||
if password == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 密码强度规则:至少8个字符,包含大小写字母、数字
|
||||
if len(password) < 8 {
|
||||
return false
|
||||
}
|
||||
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
|
||||
for _, r := range password {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
hasUpper = true
|
||||
case r >= 'a' && r <= 'z':
|
||||
hasLower = true
|
||||
case r >= '0' && r <= '9':
|
||||
hasDigit = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasUpper && hasLower && hasDigit
|
||||
}
|
||||
|
||||
// ValidateStruct 直接验证结构体(不通过HTTP上下文)
|
||||
func (v *RequestValidator) ValidateStruct(dto interface{}) error {
|
||||
return v.validator.Struct(dto)
|
||||
}
|
||||
|
||||
// GetValidator 获取原始验证器(用于特殊情况)
|
||||
func (v *RequestValidator) GetValidator() *validator.Validate {
|
||||
return v.validator
|
||||
}
|
||||
92
internal/shared/interfaces/event.go
Normal file
92
internal/shared/interfaces/event.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Event 事件接口
|
||||
type Event interface {
|
||||
// 事件基础信息
|
||||
GetID() string
|
||||
GetType() string
|
||||
GetVersion() string
|
||||
GetTimestamp() time.Time
|
||||
|
||||
// 事件数据
|
||||
GetPayload() interface{}
|
||||
GetMetadata() map[string]interface{}
|
||||
|
||||
// 事件来源
|
||||
GetSource() string
|
||||
GetAggregateID() string
|
||||
GetAggregateType() string
|
||||
|
||||
// 序列化
|
||||
Marshal() ([]byte, error)
|
||||
Unmarshal(data []byte) error
|
||||
}
|
||||
|
||||
// EventHandler 事件处理器接口
|
||||
type EventHandler interface {
|
||||
// 处理器标识
|
||||
GetName() string
|
||||
GetEventTypes() []string
|
||||
|
||||
// 事件处理
|
||||
Handle(ctx context.Context, event Event) error
|
||||
|
||||
// 处理器配置
|
||||
IsAsync() bool
|
||||
GetRetryConfig() RetryConfig
|
||||
}
|
||||
|
||||
// DomainEvent 领域事件基础接口
|
||||
type DomainEvent interface {
|
||||
Event
|
||||
|
||||
// 领域特定信息
|
||||
GetDomainVersion() string
|
||||
GetCausationID() string
|
||||
GetCorrelationID() string
|
||||
}
|
||||
|
||||
// RetryConfig 重试配置
|
||||
type RetryConfig struct {
|
||||
MaxRetries int `json:"max_retries"`
|
||||
RetryDelay time.Duration `json:"retry_delay"`
|
||||
BackoffFactor float64 `json:"backoff_factor"`
|
||||
MaxDelay time.Duration `json:"max_delay"`
|
||||
}
|
||||
|
||||
// EventStore 事件存储接口
|
||||
type EventStore interface {
|
||||
// 事件存储
|
||||
SaveEvent(ctx context.Context, event Event) error
|
||||
SaveEvents(ctx context.Context, events []Event) error
|
||||
|
||||
// 事件查询
|
||||
GetEvents(ctx context.Context, aggregateID string, fromVersion int) ([]Event, error)
|
||||
GetEventsByType(ctx context.Context, eventType string, limit int) ([]Event, error)
|
||||
GetEventsSince(ctx context.Context, timestamp time.Time, limit int) ([]Event, error)
|
||||
|
||||
// 快照支持
|
||||
SaveSnapshot(ctx context.Context, aggregateID string, snapshot interface{}) error
|
||||
GetSnapshot(ctx context.Context, aggregateID string) (interface{}, error)
|
||||
}
|
||||
|
||||
// EventBus 事件总线接口
|
||||
type EventBus interface {
|
||||
// 事件发布
|
||||
Publish(ctx context.Context, event Event) error
|
||||
PublishBatch(ctx context.Context, events []Event) error
|
||||
|
||||
// 事件订阅
|
||||
Subscribe(eventType string, handler EventHandler) error
|
||||
Unsubscribe(eventType string, handler EventHandler) error
|
||||
|
||||
// 订阅管理
|
||||
GetSubscribers(eventType string) []EventHandler
|
||||
Start(ctx context.Context) error
|
||||
Stop(ctx context.Context) error
|
||||
}
|
||||
152
internal/shared/interfaces/http.go
Normal file
152
internal/shared/interfaces/http.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// HTTPHandler HTTP处理器接口
|
||||
type HTTPHandler interface {
|
||||
// 处理器信息
|
||||
GetPath() string
|
||||
GetMethod() string
|
||||
GetMiddlewares() []gin.HandlerFunc
|
||||
|
||||
// 处理函数
|
||||
Handle(c *gin.Context)
|
||||
|
||||
// 权限验证
|
||||
RequiresAuth() bool
|
||||
GetPermissions() []string
|
||||
}
|
||||
|
||||
// RESTHandler REST风格处理器接口
|
||||
type RESTHandler interface {
|
||||
HTTPHandler
|
||||
|
||||
// CRUD操作
|
||||
Create(c *gin.Context)
|
||||
GetByID(c *gin.Context)
|
||||
Update(c *gin.Context)
|
||||
Delete(c *gin.Context)
|
||||
List(c *gin.Context)
|
||||
}
|
||||
|
||||
// Middleware 中间件接口
|
||||
type Middleware interface {
|
||||
// 中间件名称
|
||||
GetName() string
|
||||
// 中间件优先级
|
||||
GetPriority() int
|
||||
// 中间件处理函数
|
||||
Handle() gin.HandlerFunc
|
||||
// 是否全局中间件
|
||||
IsGlobal() bool
|
||||
}
|
||||
|
||||
// Router 路由器接口
|
||||
type Router interface {
|
||||
// 路由注册
|
||||
RegisterHandler(handler HTTPHandler) error
|
||||
RegisterMiddleware(middleware Middleware) error
|
||||
RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes
|
||||
|
||||
// 路由管理
|
||||
GetRoutes() gin.RoutesInfo
|
||||
Start(addr string) error
|
||||
Stop(ctx context.Context) error
|
||||
|
||||
// 引擎获取
|
||||
GetEngine() *gin.Engine
|
||||
}
|
||||
|
||||
// ResponseBuilder 响应构建器接口
|
||||
type ResponseBuilder interface {
|
||||
// 成功响应
|
||||
Success(c *gin.Context, data interface{}, message ...string)
|
||||
Created(c *gin.Context, data interface{}, message ...string)
|
||||
|
||||
// 错误响应
|
||||
Error(c *gin.Context, err error)
|
||||
BadRequest(c *gin.Context, message string, errors ...interface{})
|
||||
Unauthorized(c *gin.Context, message ...string)
|
||||
Forbidden(c *gin.Context, message ...string)
|
||||
NotFound(c *gin.Context, message ...string)
|
||||
Conflict(c *gin.Context, message string)
|
||||
InternalError(c *gin.Context, message ...string)
|
||||
|
||||
// 分页响应
|
||||
Paginated(c *gin.Context, data interface{}, pagination PaginationMeta)
|
||||
}
|
||||
|
||||
// RequestValidator 请求验证器接口
|
||||
type RequestValidator interface {
|
||||
// 验证请求
|
||||
Validate(c *gin.Context, dto interface{}) error
|
||||
ValidateQuery(c *gin.Context, dto interface{}) error
|
||||
ValidateParam(c *gin.Context, dto interface{}) error
|
||||
|
||||
// 绑定和验证
|
||||
BindAndValidate(c *gin.Context, dto interface{}) error
|
||||
}
|
||||
|
||||
// PaginationMeta 分页元数据
|
||||
type PaginationMeta struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Total int64 `json:"total"`
|
||||
TotalPages int `json:"total_pages"`
|
||||
HasNext bool `json:"has_next"`
|
||||
HasPrev bool `json:"has_prev"`
|
||||
}
|
||||
|
||||
// APIResponse 标准API响应结构
|
||||
type APIResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
Errors interface{} `json:"errors,omitempty"`
|
||||
Pagination *PaginationMeta `json:"pagination,omitempty"`
|
||||
Meta map[string]interface{} `json:"meta,omitempty"`
|
||||
RequestID string `json:"request_id"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
}
|
||||
|
||||
// HealthChecker 健康检查器接口
|
||||
type HealthChecker interface {
|
||||
// 健康检查
|
||||
CheckHealth(ctx context.Context) HealthStatus
|
||||
GetName() string
|
||||
GetDependencies() []string
|
||||
}
|
||||
|
||||
// HealthStatus 健康状态
|
||||
type HealthStatus struct {
|
||||
Status string `json:"status"` // UP, DOWN, DEGRADED
|
||||
Message string `json:"message"`
|
||||
Details map[string]interface{} `json:"details"`
|
||||
CheckedAt int64 `json:"checked_at"`
|
||||
ResponseTime int64 `json:"response_time_ms"`
|
||||
}
|
||||
|
||||
// MetricsCollector 指标收集器接口
|
||||
type MetricsCollector interface {
|
||||
// HTTP指标
|
||||
RecordHTTPRequest(method, path string, status int, duration float64)
|
||||
RecordHTTPDuration(method, path string, duration float64)
|
||||
|
||||
// 业务指标
|
||||
IncrementCounter(name string, labels map[string]string)
|
||||
RecordGauge(name string, value float64, labels map[string]string)
|
||||
RecordHistogram(name string, value float64, labels map[string]string)
|
||||
|
||||
// 自定义指标
|
||||
RegisterCounter(name, help string, labels []string) error
|
||||
RegisterGauge(name, help string, labels []string) error
|
||||
RegisterHistogram(name, help string, labels []string, buckets []float64) error
|
||||
|
||||
// 指标导出
|
||||
GetHandler() http.Handler
|
||||
}
|
||||
74
internal/shared/interfaces/repository.go
Normal file
74
internal/shared/interfaces/repository.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Entity 通用实体接口
|
||||
type Entity interface {
|
||||
GetID() string
|
||||
GetCreatedAt() time.Time
|
||||
GetUpdatedAt() time.Time
|
||||
}
|
||||
|
||||
// BaseRepository 基础仓储接口
|
||||
type BaseRepository interface {
|
||||
// 基础操作
|
||||
Delete(ctx context.Context, id string) error
|
||||
Count(ctx context.Context, options CountOptions) (int64, error)
|
||||
Exists(ctx context.Context, id string) (bool, error)
|
||||
|
||||
// 软删除支持
|
||||
SoftDelete(ctx context.Context, id string) error
|
||||
Restore(ctx context.Context, id string) error
|
||||
}
|
||||
|
||||
// Repository 通用仓储接口,支持泛型
|
||||
type Repository[T any] interface {
|
||||
BaseRepository
|
||||
|
||||
// 基础CRUD操作
|
||||
Create(ctx context.Context, entity T) error
|
||||
GetByID(ctx context.Context, id string) (T, error)
|
||||
Update(ctx context.Context, entity T) error
|
||||
|
||||
// 批量操作
|
||||
CreateBatch(ctx context.Context, entities []T) error
|
||||
GetByIDs(ctx context.Context, ids []string) ([]T, error)
|
||||
UpdateBatch(ctx context.Context, entities []T) error
|
||||
DeleteBatch(ctx context.Context, ids []string) error
|
||||
|
||||
// 查询操作
|
||||
List(ctx context.Context, options ListOptions) ([]T, error)
|
||||
|
||||
// 事务支持
|
||||
WithTx(tx interface{}) Repository[T]
|
||||
}
|
||||
|
||||
// ListOptions 列表查询选项
|
||||
type ListOptions struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Sort string `json:"sort"`
|
||||
Order string `json:"order"`
|
||||
Filters map[string]interface{} `json:"filters"`
|
||||
Search string `json:"search"`
|
||||
Include []string `json:"include"`
|
||||
}
|
||||
|
||||
// CountOptions 计数查询选项
|
||||
type CountOptions struct {
|
||||
Filters map[string]interface{} `json:"filters"`
|
||||
Search string `json:"search"`
|
||||
}
|
||||
|
||||
// CachedRepository 支持缓存的仓储接口
|
||||
type CachedRepository[T Entity] interface {
|
||||
Repository[T]
|
||||
|
||||
// 缓存操作
|
||||
InvalidateCache(ctx context.Context, keys ...string) error
|
||||
WarmupCache(ctx context.Context) error
|
||||
GetCacheKey(id string) string
|
||||
}
|
||||
101
internal/shared/interfaces/service.go
Normal file
101
internal/shared/interfaces/service.go
Normal file
@@ -0,0 +1,101 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Service 通用服务接口
|
||||
type Service interface {
|
||||
// 服务名称
|
||||
Name() string
|
||||
// 服务初始化
|
||||
Initialize(ctx context.Context) error
|
||||
// 服务健康检查
|
||||
HealthCheck(ctx context.Context) error
|
||||
// 服务关闭
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
// DomainService 领域服务接口,支持泛型
|
||||
type DomainService[T Entity] interface {
|
||||
Service
|
||||
|
||||
// 基础业务操作
|
||||
Create(ctx context.Context, dto interface{}) (*T, error)
|
||||
GetByID(ctx context.Context, id string) (*T, error)
|
||||
Update(ctx context.Context, id string, dto interface{}) (*T, error)
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// 列表和查询
|
||||
List(ctx context.Context, options ListOptions) ([]*T, error)
|
||||
Search(ctx context.Context, query string, options ListOptions) ([]*T, error)
|
||||
Count(ctx context.Context, options CountOptions) (int64, error)
|
||||
|
||||
// 业务规则验证
|
||||
Validate(ctx context.Context, entity *T) error
|
||||
ValidateCreate(ctx context.Context, dto interface{}) error
|
||||
ValidateUpdate(ctx context.Context, id string, dto interface{}) error
|
||||
}
|
||||
|
||||
// EventService 事件服务接口
|
||||
type EventService interface {
|
||||
Service
|
||||
|
||||
// 事件发布
|
||||
Publish(ctx context.Context, event Event) error
|
||||
PublishBatch(ctx context.Context, events []Event) error
|
||||
|
||||
// 事件订阅
|
||||
Subscribe(eventType string, handler EventHandler) error
|
||||
Unsubscribe(eventType string, handler EventHandler) error
|
||||
|
||||
// 异步处理
|
||||
PublishAsync(ctx context.Context, event Event) error
|
||||
}
|
||||
|
||||
// CacheService 缓存服务接口
|
||||
type CacheService interface {
|
||||
Service
|
||||
|
||||
// 基础缓存操作
|
||||
Get(ctx context.Context, key string, dest interface{}) error
|
||||
Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error
|
||||
Delete(ctx context.Context, keys ...string) error
|
||||
Exists(ctx context.Context, key string) (bool, error)
|
||||
|
||||
// 批量操作
|
||||
GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error)
|
||||
SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error
|
||||
|
||||
// 模式操作
|
||||
DeletePattern(ctx context.Context, pattern string) error
|
||||
Keys(ctx context.Context, pattern string) ([]string, error)
|
||||
|
||||
// 缓存统计
|
||||
Stats(ctx context.Context) (CacheStats, error)
|
||||
}
|
||||
|
||||
// CacheStats 缓存统计信息
|
||||
type CacheStats struct {
|
||||
Hits int64 `json:"hits"`
|
||||
Misses int64 `json:"misses"`
|
||||
Keys int64 `json:"keys"`
|
||||
Memory int64 `json:"memory"`
|
||||
Connections int64 `json:"connections"`
|
||||
}
|
||||
|
||||
// TransactionService 事务服务接口
|
||||
type TransactionService interface {
|
||||
Service
|
||||
|
||||
// 事务操作
|
||||
Begin(ctx context.Context) (Transaction, error)
|
||||
RunInTransaction(ctx context.Context, fn func(Transaction) error) error
|
||||
}
|
||||
|
||||
// Transaction 事务接口
|
||||
type Transaction interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
GetDB() interface{}
|
||||
}
|
||||
241
internal/shared/logger/logger.go
Normal file
241
internal/shared/logger/logger.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// Logger 日志接口
|
||||
type Logger interface {
|
||||
Debug(msg string, fields ...zapcore.Field)
|
||||
Info(msg string, fields ...zapcore.Field)
|
||||
Warn(msg string, fields ...zapcore.Field)
|
||||
Error(msg string, fields ...zapcore.Field)
|
||||
Fatal(msg string, fields ...zapcore.Field)
|
||||
Panic(msg string, fields ...zapcore.Field)
|
||||
|
||||
With(fields ...zapcore.Field) Logger
|
||||
WithContext(ctx context.Context) Logger
|
||||
Sync() error
|
||||
}
|
||||
|
||||
// ZapLogger Zap日志实现
|
||||
type ZapLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// Config 日志配置
|
||||
type Config struct {
|
||||
Level string
|
||||
Format string
|
||||
Output string
|
||||
FilePath string
|
||||
MaxSize int
|
||||
MaxBackups int
|
||||
MaxAge int
|
||||
Compress bool
|
||||
}
|
||||
|
||||
// NewLogger 创建新的日志实例
|
||||
func NewLogger(config Config) (Logger, error) {
|
||||
// 设置日志级别
|
||||
level, err := zapcore.ParseLevel(config.Level)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的日志级别: %w", err)
|
||||
}
|
||||
|
||||
// 配置编码器
|
||||
var encoder zapcore.Encoder
|
||||
encoderConfig := getEncoderConfig()
|
||||
|
||||
switch config.Format {
|
||||
case "json":
|
||||
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
||||
case "console":
|
||||
encoder = zapcore.NewConsoleEncoder(encoderConfig)
|
||||
default:
|
||||
encoder = zapcore.NewJSONEncoder(encoderConfig)
|
||||
}
|
||||
|
||||
// 配置输出
|
||||
var writeSyncer zapcore.WriteSyncer
|
||||
switch config.Output {
|
||||
case "stdout":
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
case "stderr":
|
||||
writeSyncer = zapcore.AddSync(os.Stderr)
|
||||
case "file":
|
||||
if config.FilePath == "" {
|
||||
config.FilePath = "logs/app.log"
|
||||
}
|
||||
// 确保目录存在
|
||||
if err := os.MkdirAll("logs", 0755); err != nil {
|
||||
return nil, fmt.Errorf("创建日志目录失败: %w", err)
|
||||
}
|
||||
|
||||
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开日志文件失败: %w", err)
|
||||
}
|
||||
writeSyncer = zapcore.AddSync(file)
|
||||
default:
|
||||
writeSyncer = zapcore.AddSync(os.Stdout)
|
||||
}
|
||||
|
||||
// 创建核心
|
||||
core := zapcore.NewCore(encoder, writeSyncer, level)
|
||||
|
||||
// 创建logger
|
||||
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
|
||||
|
||||
return &ZapLogger{
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getEncoderConfig 获取编码器配置
|
||||
func getEncoderConfig() zapcore.EncoderConfig {
|
||||
return zapcore.EncoderConfig{
|
||||
TimeKey: "timestamp",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
FunctionKey: zapcore.OmitKey,
|
||||
MessageKey: "message",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
}
|
||||
}
|
||||
|
||||
// Debug 调试日志
|
||||
func (l *ZapLogger) Debug(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Debug(msg, fields...)
|
||||
}
|
||||
|
||||
// Info 信息日志
|
||||
func (l *ZapLogger) Info(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Info(msg, fields...)
|
||||
}
|
||||
|
||||
// Warn 警告日志
|
||||
func (l *ZapLogger) Warn(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Warn(msg, fields...)
|
||||
}
|
||||
|
||||
// Error 错误日志
|
||||
func (l *ZapLogger) Error(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Error(msg, fields...)
|
||||
}
|
||||
|
||||
// Fatal 致命错误日志
|
||||
func (l *ZapLogger) Fatal(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Fatal(msg, fields...)
|
||||
}
|
||||
|
||||
// Panic 恐慌日志
|
||||
func (l *ZapLogger) Panic(msg string, fields ...zapcore.Field) {
|
||||
l.logger.Panic(msg, fields...)
|
||||
}
|
||||
|
||||
// With 添加字段
|
||||
func (l *ZapLogger) With(fields ...zapcore.Field) Logger {
|
||||
return &ZapLogger{
|
||||
logger: l.logger.With(fields...),
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext 从上下文添加字段
|
||||
func (l *ZapLogger) WithContext(ctx context.Context) Logger {
|
||||
// 从上下文中提取常用字段
|
||||
fields := []zapcore.Field{}
|
||||
|
||||
if traceID := getTraceIDFromContext(ctx); traceID != "" {
|
||||
fields = append(fields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
if userID := getUserIDFromContext(ctx); userID != "" {
|
||||
fields = append(fields, zap.String("user_id", userID))
|
||||
}
|
||||
|
||||
if requestID := getRequestIDFromContext(ctx); requestID != "" {
|
||||
fields = append(fields, zap.String("request_id", requestID))
|
||||
}
|
||||
|
||||
return l.With(fields...)
|
||||
}
|
||||
|
||||
// Sync 同步日志
|
||||
func (l *ZapLogger) Sync() error {
|
||||
return l.logger.Sync()
|
||||
}
|
||||
|
||||
// getTraceIDFromContext 从上下文获取追踪ID
|
||||
func getTraceIDFromContext(ctx context.Context) string {
|
||||
if traceID := ctx.Value("trace_id"); traceID != nil {
|
||||
if id, ok := traceID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getUserIDFromContext 从上下文获取用户ID
|
||||
func getUserIDFromContext(ctx context.Context) string {
|
||||
if userID := ctx.Value("user_id"); userID != nil {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRequestIDFromContext 从上下文获取请求ID
|
||||
func getRequestIDFromContext(ctx context.Context) string {
|
||||
if requestID := ctx.Value("request_id"); requestID != nil {
|
||||
if id, ok := requestID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Field 创建日志字段的便捷函数
|
||||
func String(key, val string) zapcore.Field {
|
||||
return zap.String(key, val)
|
||||
}
|
||||
|
||||
func Int(key string, val int) zapcore.Field {
|
||||
return zap.Int(key, val)
|
||||
}
|
||||
|
||||
func Int64(key string, val int64) zapcore.Field {
|
||||
return zap.Int64(key, val)
|
||||
}
|
||||
|
||||
func Float64(key string, val float64) zapcore.Field {
|
||||
return zap.Float64(key, val)
|
||||
}
|
||||
|
||||
func Bool(key string, val bool) zapcore.Field {
|
||||
return zap.Bool(key, val)
|
||||
}
|
||||
|
||||
func Error(err error) zapcore.Field {
|
||||
return zap.Error(err)
|
||||
}
|
||||
|
||||
func Any(key string, val interface{}) zapcore.Field {
|
||||
return zap.Any(key, val)
|
||||
}
|
||||
|
||||
func Duration(key string, val interface{}) zapcore.Field {
|
||||
return zap.Any(key, val)
|
||||
}
|
||||
261
internal/shared/middleware/auth.go
Normal file
261
internal/shared/middleware/auth.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// JWTAuthMiddleware JWT认证中间件
|
||||
type JWTAuthMiddleware struct {
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewJWTAuthMiddleware 创建JWT认证中间件
|
||||
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
|
||||
return &JWTAuthMiddleware{
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *JWTAuthMiddleware) GetName() string {
|
||||
return "jwt_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *JWTAuthMiddleware) GetPriority() int {
|
||||
return 60 // 中等优先级,在日志之后,业务处理之前
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
m.respondUnauthorized(c, "Missing authorization header")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
m.respondUnauthorized(c, "Invalid authorization header format")
|
||||
return
|
||||
}
|
||||
|
||||
// 提取token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
if tokenString == "" {
|
||||
m.respondUnauthorized(c, "Missing token")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := m.validateToken(tokenString)
|
||||
if err != nil {
|
||||
m.logger.Warn("Invalid token",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", c.GetString("request_id")))
|
||||
m.respondUnauthorized(c, "Invalid token")
|
||||
return
|
||||
}
|
||||
|
||||
// 将用户信息添加到上下文
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *JWTAuthMiddleware) IsGlobal() bool {
|
||||
return false // 不是全局中间件,需要手动应用到需要认证的路由
|
||||
}
|
||||
|
||||
// JWTClaims JWT声明结构
|
||||
type JWTClaims struct {
|
||||
UserID string `json:"user_id"`
|
||||
Username string `json:"username"`
|
||||
Email string `json:"email"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// validateToken 验证JWT token
|
||||
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 验证签名方法
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*JWTClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// respondUnauthorized 返回未授权响应
|
||||
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "Unauthorized",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
// GenerateToken 生成JWT token
|
||||
func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &JWTClaims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Email: email,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "tyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"tyapi-client"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// GenerateRefreshToken 生成刷新token
|
||||
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
|
||||
now := time.Now()
|
||||
|
||||
claims := &jwt.RegisteredClaims{
|
||||
Issuer: "tyapi-server",
|
||||
Subject: userID,
|
||||
Audience: []string{"tyapi-refresh"},
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(m.config.JWT.Secret))
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 验证刷新token
|
||||
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, jwt.ErrSignatureInvalid
|
||||
}
|
||||
return []byte(m.config.JWT.Secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(*jwt.RegisteredClaims)
|
||||
if !ok || !token.Valid {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
// 检查是否为刷新token
|
||||
if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
|
||||
return "", jwt.ErrSignatureInvalid
|
||||
}
|
||||
|
||||
return claims.Subject, nil
|
||||
}
|
||||
|
||||
// OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
|
||||
type OptionalAuthMiddleware struct {
|
||||
jwtAuth *JWTAuthMiddleware
|
||||
}
|
||||
|
||||
// NewOptionalAuthMiddleware 创建可选认证中间件
|
||||
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
|
||||
return &OptionalAuthMiddleware{
|
||||
jwtAuth: jwtAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *OptionalAuthMiddleware) GetName() string {
|
||||
return "optional_auth"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *OptionalAuthMiddleware) GetPriority() int {
|
||||
return 60 // 与JWT认证中间件相同
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
// 没有认证头部,设置匿名用户标识
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 提取并验证token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
claims, err := m.jwtAuth.validateToken(tokenString)
|
||||
if err != nil {
|
||||
// token无效,但不返回错误,设置为未认证
|
||||
c.Set("is_authenticated", false)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// token有效,设置用户信息
|
||||
c.Set("is_authenticated", true)
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("email", claims.Email)
|
||||
c.Set("token_claims", claims)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *OptionalAuthMiddleware) IsGlobal() bool {
|
||||
return false
|
||||
}
|
||||
104
internal/shared/middleware/cors.go
Normal file
104
internal/shared/middleware/cors.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// CORSMiddleware CORS中间件
|
||||
type CORSMiddleware struct {
|
||||
config *config.Config
|
||||
}
|
||||
|
||||
// NewCORSMiddleware 创建CORS中间件
|
||||
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
|
||||
return &CORSMiddleware{
|
||||
config: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *CORSMiddleware) GetName() string {
|
||||
return "cors"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *CORSMiddleware) GetPriority() int {
|
||||
return 100 // 高优先级,最先执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.config.Development.EnableCors {
|
||||
// 如果没有启用CORS,返回空处理函数
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
config := cors.Config{
|
||||
AllowAllOrigins: false,
|
||||
AllowOrigins: m.getAllowedOrigins(),
|
||||
AllowMethods: m.getAllowedMethods(),
|
||||
AllowHeaders: m.getAllowedHeaders(),
|
||||
ExposeHeaders: []string{
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"X-Request-ID",
|
||||
"X-Response-Time",
|
||||
},
|
||||
AllowCredentials: true,
|
||||
MaxAge: 86400, // 24小时
|
||||
}
|
||||
|
||||
return cors.New(config)
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *CORSMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getAllowedOrigins 获取允许的来源
|
||||
func (m *CORSMiddleware) getAllowedOrigins() []string {
|
||||
if m.config.Development.CorsOrigins == "" {
|
||||
return []string{"http://localhost:3000", "http://localhost:8080"}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的origins字符串
|
||||
return []string{m.config.Development.CorsOrigins}
|
||||
}
|
||||
|
||||
// getAllowedMethods 获取允许的方法
|
||||
func (m *CORSMiddleware) getAllowedMethods() []string {
|
||||
if m.config.Development.CorsMethods == "" {
|
||||
return []string{
|
||||
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的methods字符串
|
||||
return []string{m.config.Development.CorsMethods}
|
||||
}
|
||||
|
||||
// getAllowedHeaders 获取允许的头部
|
||||
func (m *CORSMiddleware) getAllowedHeaders() []string {
|
||||
if m.config.Development.CorsHeaders == "" {
|
||||
return []string{
|
||||
"Origin",
|
||||
"Content-Length",
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"X-Requested-With",
|
||||
"Accept",
|
||||
"Accept-Encoding",
|
||||
"Accept-Language",
|
||||
"X-Request-ID",
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 解析配置中的headers字符串
|
||||
return []string{m.config.Development.CorsHeaders}
|
||||
}
|
||||
166
internal/shared/middleware/ratelimit.go
Normal file
166
internal/shared/middleware/ratelimit.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RateLimitMiddleware) GetName() string {
|
||||
return "ratelimit"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RateLimitMiddleware) GetPriority() int {
|
||||
return 90 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取客户端标识(IP地址)
|
||||
clientID := m.getClientID(c)
|
||||
|
||||
// 获取或创建限流器
|
||||
limiter := m.getLimiter(clientID)
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
c.Header("Retry-After", "60")
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"success": false,
|
||||
"message": "Rate limit exceeded",
|
||||
"error": "Too many requests",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RateLimitMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// getClientID 获取客户端标识
|
||||
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
|
||||
// 优先使用X-Forwarded-For头部
|
||||
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
|
||||
return xff
|
||||
}
|
||||
|
||||
// 使用X-Real-IP头部
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return xri
|
||||
}
|
||||
|
||||
// 使用RemoteAddr
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// getLimiter 获取或创建限流器
|
||||
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
|
||||
m.mutex.RLock()
|
||||
limiter, exists := m.limiters[clientID]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists := m.limiters[clientID]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
// 创建新的限流器
|
||||
// rate.Every计算每个请求之间的间隔
|
||||
rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
|
||||
limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst)
|
||||
|
||||
m.limiters[clientID] = limiter
|
||||
|
||||
// 启动清理协程(仅第一次创建时)
|
||||
if len(m.limiters) == 1 {
|
||||
go m.cleanupRoutine()
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupRoutine 定期清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanupRoutine() {
|
||||
ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
m.cleanup()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup 清理不活跃的限流器
|
||||
func (m *RateLimitMiddleware) cleanup() {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for clientID, limiter := range m.limiters {
|
||||
// 如果限流器在过去1小时内没有被使用,则删除它
|
||||
if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour {
|
||||
delete(m.limiters, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetStats 获取限流统计
|
||||
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"active_limiters": len(m.limiters),
|
||||
"rate_limit": map[string]interface{}{
|
||||
"requests": m.config.RateLimit.Requests,
|
||||
"window": m.config.RateLimit.Window,
|
||||
"burst": m.config.RateLimit.Burst,
|
||||
},
|
||||
}
|
||||
}
|
||||
241
internal/shared/middleware/request_logger.go
Normal file
241
internal/shared/middleware/request_logger.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RequestLoggerMiddleware 请求日志中间件
|
||||
type RequestLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRequestLoggerMiddleware 创建请求日志中间件
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
|
||||
return &RequestLoggerMiddleware{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestLoggerMiddleware) GetName() string {
|
||||
return "request_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestLoggerMiddleware) GetPriority() int {
|
||||
return 80 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 使用zap logger记录请求信息
|
||||
m.logger.Info("HTTP Request",
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
)
|
||||
|
||||
// 返回空字符串,因为我们已经用zap记录了
|
||||
return ""
|
||||
})
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestLoggerMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestIDMiddleware 请求ID中间件
|
||||
type RequestIDMiddleware struct{}
|
||||
|
||||
// NewRequestIDMiddleware 创建请求ID中间件
|
||||
func NewRequestIDMiddleware() *RequestIDMiddleware {
|
||||
return &RequestIDMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestIDMiddleware) GetName() string {
|
||||
return "request_id"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestIDMiddleware) GetPriority() int {
|
||||
return 95 // 最高优先级,第一个执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取或生成请求ID
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 设置请求ID到上下文和响应头
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
|
||||
// 添加到响应头,方便客户端追踪
|
||||
c.Writer.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware 安全头部中间件
|
||||
type SecurityHeadersMiddleware struct{}
|
||||
|
||||
// NewSecurityHeadersMiddleware 创建安全头部中间件
|
||||
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
|
||||
return &SecurityHeadersMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *SecurityHeadersMiddleware) GetName() string {
|
||||
return "security_headers"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *SecurityHeadersMiddleware) GetPriority() int {
|
||||
return 85 // 高优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 设置安全头部
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("Content-Security-Policy", "default-src 'self'")
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// ResponseTimeMiddleware 响应时间中间件
|
||||
type ResponseTimeMiddleware struct{}
|
||||
|
||||
// NewResponseTimeMiddleware 创建响应时间中间件
|
||||
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
|
||||
return &ResponseTimeMiddleware{}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ResponseTimeMiddleware) GetName() string {
|
||||
return "response_time"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ResponseTimeMiddleware) GetPriority() int {
|
||||
return 75 // 中等优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
|
||||
c.Next()
|
||||
|
||||
// 计算响应时间并添加到头部
|
||||
duration := time.Since(start)
|
||||
c.Header("X-Response-Time", duration.String())
|
||||
|
||||
// 记录到上下文中,供其他中间件使用
|
||||
c.Set("response_time", duration)
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ResponseTimeMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// RequestBodyLoggerMiddleware 请求体日志中间件(用于调试)
|
||||
type RequestBodyLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
enable bool
|
||||
}
|
||||
|
||||
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
|
||||
return &RequestBodyLoggerMiddleware{
|
||||
logger: logger,
|
||||
enable: enable,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *RequestBodyLoggerMiddleware) GetName() string {
|
||||
return "request_body_logger"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
|
||||
return 70 // 较低优先级
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
if !m.enable {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// 只记录POST, PUT, PATCH请求的body
|
||||
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err == nil {
|
||||
// 重新设置body供后续处理使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
|
||||
m.logger.Debug("Request Body",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("body", string(bodyBytes)),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
|
||||
return false // 可选中间件,不是全局的
|
||||
}
|
||||
Reference in New Issue
Block a user