Initial commit: Basic project structure and dependencies

This commit is contained in:
2025-06-30 19:21:56 +08:00
commit 03e615a8fd
50 changed files with 11664 additions and 0 deletions

284
internal/shared/cache/redis_cache.go vendored Normal file
View File

@@ -0,0 +1,284 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// RedisCache Redis缓存实现
type RedisCache struct {
client *redis.Client
logger *zap.Logger
prefix string
// 统计信息
hits int64
misses int64
}
// NewRedisCache 创建Redis缓存实例
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
return &RedisCache{
client: client,
logger: logger,
prefix: prefix,
}
}
// Name 返回服务名称
func (r *RedisCache) Name() string {
return "redis-cache"
}
// Initialize 初始化服务
func (r *RedisCache) Initialize(ctx context.Context) error {
// 测试连接
_, err := r.client.Ping(ctx).Result()
if err != nil {
r.logger.Error("Failed to connect to Redis", zap.Error(err))
return fmt.Errorf("redis connection failed: %w", err)
}
r.logger.Info("Redis cache service initialized")
return nil
}
// HealthCheck 健康检查
func (r *RedisCache) HealthCheck(ctx context.Context) error {
_, err := r.client.Ping(ctx).Result()
return err
}
// Shutdown 关闭服务
func (r *RedisCache) Shutdown(ctx context.Context) error {
return r.client.Close()
}
// Get 获取缓存值
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
fullKey := r.getFullKey(key)
val, err := r.client.Get(ctx, fullKey).Result()
if err != nil {
if err == redis.Nil {
r.misses++
return fmt.Errorf("cache miss: key %s not found", key)
}
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
return err
}
r.hits++
return json.Unmarshal([]byte(val), dest)
}
// Set 设置缓存值
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
fullKey := r.getFullKey(key)
data, err := json.Marshal(value)
if err != nil {
return fmt.Errorf("failed to marshal value: %w", err)
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
case string:
expiration, _ = time.ParseDuration(v)
default:
expiration = 24 * time.Hour // 默认24小时
}
} else {
expiration = 24 * time.Hour // 默认24小时
}
err = r.client.Set(ctx, fullKey, data, expiration).Err()
if err != nil {
r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
return err
}
return nil
}
// Delete 删除缓存
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
err := r.client.Del(ctx, fullKeys...).Err()
if err != nil {
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
return err
}
return nil
}
// Exists 检查键是否存在
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
fullKey := r.getFullKey(key)
count, err := r.client.Exists(ctx, fullKey).Result()
if err != nil {
return false, err
}
return count > 0, nil
}
// GetMultiple 批量获取
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
if len(keys) == 0 {
return make(map[string]interface{}), nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
values, err := r.client.MGet(ctx, fullKeys...).Result()
if err != nil {
return nil, err
}
result := make(map[string]interface{})
for i, val := range values {
if val != nil {
var data interface{}
if err := json.Unmarshal([]byte(val.(string)), &data); err == nil {
result[keys[i]] = data
}
}
}
return result, nil
}
// SetMultiple 批量设置
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
if len(data) == 0 {
return nil
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
default:
expiration = 24 * time.Hour
}
} else {
expiration = 24 * time.Hour
}
pipe := r.client.Pipeline()
for key, value := range data {
fullKey := r.getFullKey(key)
jsonData, err := json.Marshal(value)
if err != nil {
continue
}
pipe.Set(ctx, fullKey, jsonData, expiration)
}
_, err := pipe.Exec(ctx)
return err
}
// DeletePattern 按模式删除
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
fullPattern := r.getFullKey(pattern)
keys, err := r.client.Keys(ctx, fullPattern).Result()
if err != nil {
return err
}
if len(keys) > 0 {
return r.client.Del(ctx, keys...).Err()
}
return nil
}
// Keys 获取匹配的键
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
fullPattern := r.getFullKey(pattern)
keys, err := r.client.Keys(ctx, fullPattern).Result()
if err != nil {
return nil, err
}
// 移除前缀
result := make([]string, len(keys))
prefixLen := len(r.prefix) + 1 // +1 for ":"
for i, key := range keys {
if len(key) > prefixLen {
result[i] = key[prefixLen:]
} else {
result[i] = key
}
}
return result, nil
}
// Stats 获取缓存统计
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
dbSize, _ := r.client.DBSize(ctx).Result()
return interfaces.CacheStats{
Hits: r.hits,
Misses: r.misses,
Keys: dbSize,
Memory: 0, // 暂时设为0后续可解析Redis info
Connections: 0, // 暂时设为0后续可解析Redis info
}, nil
}
// getFullKey 获取完整键名
func (r *RedisCache) getFullKey(key string) string {
if r.prefix == "" {
return key
}
return fmt.Sprintf("%s:%s", r.prefix, key)
}
// Flush 清空所有缓存
func (r *RedisCache) Flush(ctx context.Context) error {
if r.prefix == "" {
return r.client.FlushDB(ctx).Err()
}
// 只删除带前缀的键
return r.DeletePattern(ctx, "*")
}
// GetClient 获取原始Redis客户端
func (r *RedisCache) GetClient() *redis.Client {
return r.client
}

View File

@@ -0,0 +1,195 @@
package database
import (
"context"
"fmt"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// Config 数据库配置
type Config struct {
Host string
Port string
User string
Password string
Name string
SSLMode string
Timezone string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
}
// DB 数据库包装器
type DB struct {
*gorm.DB
config Config
}
// NewConnection 创建新的数据库连接
func NewConnection(config Config) (*DB, error) {
// 构建DSN
dsn := buildDSN(config)
// 配置GORM
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
NamingStrategy: schema.NamingStrategy{
SingularTable: true, // 使用单数表名
},
DisableForeignKeyConstraintWhenMigrating: true,
}
// 连接数据库
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
if err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
// 获取底层sql.DB
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
}
// 配置连接池
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
}
return &DB{
DB: db,
config: config,
}, nil
}
// buildDSN 构建数据库连接字符串
func buildDSN(config Config) string {
return fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s",
config.Host,
config.User,
config.Password,
config.Name,
config.Port,
config.SSLMode,
config.Timezone,
)
}
// Close 关闭数据库连接
func (db *DB) Close() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// Ping 检查数据库连接
func (db *DB) Ping() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}
// GetStats 获取连接池统计信息
func (db *DB) GetStats() (map[string]interface{}, error) {
sqlDB, err := db.DB.DB()
if err != nil {
return nil, err
}
stats := sqlDB.Stats()
return map[string]interface{}{
"max_open_connections": stats.MaxOpenConnections,
"open_connections": stats.OpenConnections,
"in_use": stats.InUse,
"idle": stats.Idle,
"wait_count": stats.WaitCount,
"wait_duration": stats.WaitDuration,
"max_idle_closed": stats.MaxIdleClosed,
"max_idle_time_closed": stats.MaxIdleTimeClosed,
"max_lifetime_closed": stats.MaxLifetimeClosed,
}, nil
}
// BeginTx 开始事务
func (db *DB) BeginTx() *gorm.DB {
return db.DB.Begin()
}
// Migrate 执行数据库迁移
func (db *DB) Migrate(models ...interface{}) error {
return db.DB.AutoMigrate(models...)
}
// IsHealthy 检查数据库健康状态
func (db *DB) IsHealthy() bool {
return db.Ping() == nil
}
// WithContext 返回带上下文的数据库实例
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
if c, ok := ctx.(context.Context); ok {
return db.DB.WithContext(c)
}
return db.DB
}
// 事务包装器
type TxWrapper struct {
tx *gorm.DB
}
// NewTxWrapper 创建事务包装器
func (db *DB) NewTxWrapper() *TxWrapper {
return &TxWrapper{
tx: db.BeginTx(),
}
}
// Commit 提交事务
func (tx *TxWrapper) Commit() error {
return tx.tx.Commit().Error
}
// Rollback 回滚事务
func (tx *TxWrapper) Rollback() error {
return tx.tx.Rollback().Error
}
// GetDB 获取事务数据库实例
func (tx *TxWrapper) GetDB() *gorm.DB {
return tx.tx
}
// WithTx 在事务中执行函数
func (db *DB) WithTx(fn func(*gorm.DB) error) error {
tx := db.BeginTx()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
panic(r)
}
}()
if err := fn(tx); err != nil {
tx.Rollback()
return err
}
return tx.Commit().Error
}

View File

@@ -0,0 +1,313 @@
package events
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// MemoryEventBus 内存事件总线实现
type MemoryEventBus struct {
subscribers map[string][]interfaces.EventHandler
mutex sync.RWMutex
logger *zap.Logger
running bool
stopCh chan struct{}
eventQueue chan eventTask
workerCount int
}
// eventTask 事件任务
type eventTask struct {
event interfaces.Event
handler interfaces.EventHandler
retries int
}
// NewMemoryEventBus 创建内存事件总线
func NewMemoryEventBus(logger *zap.Logger, workerCount int) *MemoryEventBus {
if workerCount <= 0 {
workerCount = 5 // 默认5个工作协程
}
return &MemoryEventBus{
subscribers: make(map[string][]interfaces.EventHandler),
logger: logger,
eventQueue: make(chan eventTask, 1000), // 缓冲1000个事件
workerCount: workerCount,
stopCh: make(chan struct{}),
}
}
// Name 返回服务名称
func (bus *MemoryEventBus) Name() string {
return "memory-event-bus"
}
// Initialize 初始化服务
func (bus *MemoryEventBus) Initialize(ctx context.Context) error {
bus.logger.Info("Memory event bus service initialized")
return nil
}
// HealthCheck 健康检查
func (bus *MemoryEventBus) HealthCheck(ctx context.Context) error {
if !bus.running {
return fmt.Errorf("event bus is not running")
}
return nil
}
// Shutdown 关闭服务
func (bus *MemoryEventBus) Shutdown(ctx context.Context) error {
bus.Stop(ctx)
return nil
}
// Start 启动事件总线
func (bus *MemoryEventBus) Start(ctx context.Context) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
if bus.running {
return nil
}
bus.running = true
// 启动工作协程
for i := 0; i < bus.workerCount; i++ {
go bus.worker(i)
}
bus.logger.Info("Event bus started", zap.Int("workers", bus.workerCount))
return nil
}
// Stop 停止事件总线
func (bus *MemoryEventBus) Stop(ctx context.Context) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
if !bus.running {
return nil
}
bus.running = false
close(bus.stopCh)
// 等待所有工作协程结束或超时
done := make(chan struct{})
go func() {
time.Sleep(5 * time.Second) // 给工作协程5秒时间结束
close(done)
}()
select {
case <-done:
case <-ctx.Done():
}
bus.logger.Info("Event bus stopped")
return nil
}
// Publish 发布事件(同步)
func (bus *MemoryEventBus) Publish(ctx context.Context, event interfaces.Event) error {
bus.mutex.RLock()
handlers := bus.subscribers[event.GetType()]
bus.mutex.RUnlock()
if len(handlers) == 0 {
bus.logger.Debug("No handlers for event type", zap.String("type", event.GetType()))
return nil
}
for _, handler := range handlers {
if handler.IsAsync() {
// 异步处理
select {
case bus.eventQueue <- eventTask{event: event, handler: handler, retries: 0}:
default:
bus.logger.Warn("Event queue is full, dropping event",
zap.String("type", event.GetType()),
zap.String("handler", handler.GetName()))
}
} else {
// 同步处理
if err := bus.handleEventWithRetry(ctx, event, handler); err != nil {
bus.logger.Error("Failed to handle event synchronously",
zap.String("type", event.GetType()),
zap.String("handler", handler.GetName()),
zap.Error(err))
}
}
}
return nil
}
// PublishBatch 批量发布事件
func (bus *MemoryEventBus) PublishBatch(ctx context.Context, events []interfaces.Event) error {
for _, event := range events {
if err := bus.Publish(ctx, event); err != nil {
return err
}
}
return nil
}
// Subscribe 订阅事件
func (bus *MemoryEventBus) Subscribe(eventType string, handler interfaces.EventHandler) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
handlers := bus.subscribers[eventType]
// 检查是否已经订阅
for _, h := range handlers {
if h.GetName() == handler.GetName() {
return fmt.Errorf("handler %s already subscribed to event type %s", handler.GetName(), eventType)
}
}
bus.subscribers[eventType] = append(handlers, handler)
bus.logger.Info("Handler subscribed to event",
zap.String("handler", handler.GetName()),
zap.String("event_type", eventType))
return nil
}
// Unsubscribe 取消订阅
func (bus *MemoryEventBus) Unsubscribe(eventType string, handler interfaces.EventHandler) error {
bus.mutex.Lock()
defer bus.mutex.Unlock()
handlers := bus.subscribers[eventType]
for i, h := range handlers {
if h.GetName() == handler.GetName() {
// 删除处理器
bus.subscribers[eventType] = append(handlers[:i], handlers[i+1:]...)
bus.logger.Info("Handler unsubscribed from event",
zap.String("handler", handler.GetName()),
zap.String("event_type", eventType))
return nil
}
}
return fmt.Errorf("handler %s not found for event type %s", handler.GetName(), eventType)
}
// GetSubscribers 获取订阅者
func (bus *MemoryEventBus) GetSubscribers(eventType string) []interfaces.EventHandler {
bus.mutex.RLock()
defer bus.mutex.RUnlock()
handlers := bus.subscribers[eventType]
result := make([]interfaces.EventHandler, len(handlers))
copy(result, handlers)
return result
}
// worker 工作协程
func (bus *MemoryEventBus) worker(id int) {
bus.logger.Debug("Event worker started", zap.Int("worker_id", id))
for {
select {
case task := <-bus.eventQueue:
bus.processEventTask(task)
case <-bus.stopCh:
bus.logger.Debug("Event worker stopped", zap.Int("worker_id", id))
return
}
}
}
// processEventTask 处理事件任务
func (bus *MemoryEventBus) processEventTask(task eventTask) {
ctx := context.Background()
err := bus.handleEventWithRetry(ctx, task.event, task.handler)
if err != nil {
retryConfig := task.handler.GetRetryConfig()
if task.retries < retryConfig.MaxRetries {
// 重试
delay := time.Duration(float64(retryConfig.RetryDelay) *
(1 + retryConfig.BackoffFactor*float64(task.retries)))
if delay > retryConfig.MaxDelay {
delay = retryConfig.MaxDelay
}
go func() {
time.Sleep(delay)
task.retries++
select {
case bus.eventQueue <- task:
default:
bus.logger.Error("Failed to requeue event for retry",
zap.String("type", task.event.GetType()),
zap.String("handler", task.handler.GetName()),
zap.Int("retries", task.retries))
}
}()
} else {
bus.logger.Error("Event processing failed after max retries",
zap.String("type", task.event.GetType()),
zap.String("handler", task.handler.GetName()),
zap.Int("retries", task.retries),
zap.Error(err))
}
}
}
// handleEventWithRetry 处理事件并支持重试
func (bus *MemoryEventBus) handleEventWithRetry(ctx context.Context, event interfaces.Event, handler interfaces.EventHandler) error {
start := time.Now()
defer func() {
duration := time.Since(start)
bus.logger.Debug("Event handled",
zap.String("type", event.GetType()),
zap.String("handler", handler.GetName()),
zap.Duration("duration", duration))
}()
return handler.Handle(ctx, event)
}
// GetStats 获取事件总线统计信息
func (bus *MemoryEventBus) GetStats() map[string]interface{} {
bus.mutex.RLock()
defer bus.mutex.RUnlock()
stats := map[string]interface{}{
"running": bus.running,
"worker_count": bus.workerCount,
"queue_length": len(bus.eventQueue),
"queue_capacity": cap(bus.eventQueue),
"event_types": len(bus.subscribers),
}
// 各事件类型的订阅者数量
eventTypes := make(map[string]int)
for eventType, handlers := range bus.subscribers {
eventTypes[eventType] = len(handlers)
}
stats["subscribers"] = eventTypes
return stats
}

View File

@@ -0,0 +1,282 @@
package health
import (
"context"
"fmt"
"sync"
"time"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
)
// HealthChecker 健康检查器实现
type HealthChecker struct {
services map[string]interfaces.Service
cache map[string]*interfaces.HealthStatus
cacheTTL time.Duration
mutex sync.RWMutex
logger *zap.Logger
}
// NewHealthChecker 创建健康检查器
func NewHealthChecker(logger *zap.Logger) *HealthChecker {
return &HealthChecker{
services: make(map[string]interfaces.Service),
cache: make(map[string]*interfaces.HealthStatus),
cacheTTL: 30 * time.Second, // 缓存30秒
logger: logger,
}
}
// RegisterService 注册服务
func (h *HealthChecker) RegisterService(service interfaces.Service) {
h.mutex.Lock()
defer h.mutex.Unlock()
h.services[service.Name()] = service
h.logger.Info("Registered service for health check", zap.String("service", service.Name()))
}
// CheckHealth 检查单个服务健康状态
func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *interfaces.HealthStatus {
h.mutex.RLock()
service, exists := h.services[serviceName]
if !exists {
h.mutex.RUnlock()
return &interfaces.HealthStatus{
Status: "DOWN",
Message: "Service not found",
Details: map[string]interface{}{"error": "service not registered"},
CheckedAt: time.Now().Unix(),
ResponseTime: 0,
}
}
// 检查缓存
if cached, exists := h.cache[serviceName]; exists {
if time.Since(time.Unix(cached.CheckedAt, 0)) < h.cacheTTL {
h.mutex.RUnlock()
return cached
}
}
h.mutex.RUnlock()
// 执行健康检查
start := time.Now()
status := &interfaces.HealthStatus{
CheckedAt: start.Unix(),
}
// 设置超时上下文
checkCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
err := service.HealthCheck(checkCtx)
responseTime := time.Since(start).Milliseconds()
status.ResponseTime = responseTime
if err != nil {
status.Status = "DOWN"
status.Message = "Health check failed"
status.Details = map[string]interface{}{
"error": err.Error(),
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Warn("Service health check failed",
zap.String("service", serviceName),
zap.Error(err),
zap.Int64("response_time_ms", responseTime))
} else {
status.Status = "UP"
status.Message = "Service is healthy"
status.Details = map[string]interface{}{
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Debug("Service health check passed",
zap.String("service", serviceName),
zap.Int64("response_time_ms", responseTime))
}
// 更新缓存
h.mutex.Lock()
h.cache[serviceName] = status
h.mutex.Unlock()
return status
}
// CheckAllHealth 检查所有服务的健康状态
func (h *HealthChecker) CheckAllHealth(ctx context.Context) map[string]*interfaces.HealthStatus {
h.mutex.RLock()
serviceNames := make([]string, 0, len(h.services))
for name := range h.services {
serviceNames = append(serviceNames, name)
}
h.mutex.RUnlock()
results := make(map[string]*interfaces.HealthStatus)
var wg sync.WaitGroup
var mutex sync.Mutex
// 并发检查所有服务
for _, serviceName := range serviceNames {
wg.Add(1)
go func(name string) {
defer wg.Done()
status := h.CheckHealth(ctx, name)
mutex.Lock()
results[name] = status
mutex.Unlock()
}(serviceName)
}
wg.Wait()
return results
}
// GetOverallStatus 获取整体健康状态
func (h *HealthChecker) GetOverallStatus(ctx context.Context) *interfaces.HealthStatus {
allStatus := h.CheckAllHealth(ctx)
overall := &interfaces.HealthStatus{
CheckedAt: time.Now().Unix(),
ResponseTime: 0,
Details: make(map[string]interface{}),
}
var totalResponseTime int64
healthyCount := 0
totalCount := len(allStatus)
for serviceName, status := range allStatus {
overall.Details[serviceName] = map[string]interface{}{
"status": status.Status,
"message": status.Message,
"response_time": status.ResponseTime,
}
totalResponseTime += status.ResponseTime
if status.Status == "UP" {
healthyCount++
}
}
if totalCount > 0 {
overall.ResponseTime = totalResponseTime / int64(totalCount)
}
// 确定整体状态
if healthyCount == totalCount {
overall.Status = "UP"
overall.Message = "All services are healthy"
} else if healthyCount == 0 {
overall.Status = "DOWN"
overall.Message = "All services are down"
} else {
overall.Status = "DEGRADED"
overall.Message = fmt.Sprintf("%d of %d services are healthy", healthyCount, totalCount)
}
return overall
}
// GetServiceNames 获取所有注册的服务名称
func (h *HealthChecker) GetServiceNames() []string {
h.mutex.RLock()
defer h.mutex.RUnlock()
names := make([]string, 0, len(h.services))
for name := range h.services {
names = append(names, name)
}
return names
}
// RemoveService 移除服务
func (h *HealthChecker) RemoveService(serviceName string) {
h.mutex.Lock()
defer h.mutex.Unlock()
delete(h.services, serviceName)
delete(h.cache, serviceName)
h.logger.Info("Removed service from health check", zap.String("service", serviceName))
}
// ClearCache 清除缓存
func (h *HealthChecker) ClearCache() {
h.mutex.Lock()
defer h.mutex.Unlock()
h.cache = make(map[string]*interfaces.HealthStatus)
h.logger.Debug("Health check cache cleared")
}
// GetCacheStats 获取缓存统计
func (h *HealthChecker) GetCacheStats() map[string]interface{} {
h.mutex.RLock()
defer h.mutex.RUnlock()
stats := map[string]interface{}{
"total_services": len(h.services),
"cached_results": len(h.cache),
"cache_ttl_seconds": h.cacheTTL.Seconds(),
}
// 计算缓存命中率
if len(h.services) > 0 {
hitRate := float64(len(h.cache)) / float64(len(h.services)) * 100
stats["cache_hit_rate"] = fmt.Sprintf("%.2f%%", hitRate)
}
return stats
}
// SetCacheTTL 设置缓存TTL
func (h *HealthChecker) SetCacheTTL(ttl time.Duration) {
h.mutex.Lock()
defer h.mutex.Unlock()
h.cacheTTL = ttl
h.logger.Info("Updated health check cache TTL", zap.Duration("ttl", ttl))
}
// StartPeriodicCheck 启动定期健康检查
func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Duration) {
ticker := time.NewTicker(interval)
defer ticker.Stop()
h.logger.Info("Started periodic health check", zap.Duration("interval", interval))
for {
select {
case <-ctx.Done():
h.logger.Info("Stopped periodic health check")
return
case <-ticker.C:
h.performPeriodicCheck(ctx)
}
}
}
// performPeriodicCheck 执行定期检查
func (h *HealthChecker) performPeriodicCheck(ctx context.Context) {
overall := h.GetOverallStatus(ctx)
h.logger.Info("Periodic health check completed",
zap.String("overall_status", overall.Status),
zap.String("message", overall.Message),
zap.Int64("response_time_ms", overall.ResponseTime))
// 如果有服务下线,记录警告
if overall.Status != "UP" {
h.logger.Warn("Some services are not healthy",
zap.String("status", overall.Status),
zap.Any("details", overall.Details))
}
}

View File

@@ -0,0 +1,260 @@
package http
import (
"math"
"net/http"
"time"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
)
// ResponseBuilder 响应构建器实现
type ResponseBuilder struct{}
// NewResponseBuilder 创建响应构建器
func NewResponseBuilder() interfaces.ResponseBuilder {
return &ResponseBuilder{}
}
// Success 成功响应
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
msg := "Success"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: true,
Message: msg,
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusOK, response)
}
// Created 创建成功响应
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
msg := "Created successfully"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: true,
Message: msg,
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusCreated, response)
}
// Error 错误响应
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
// 根据错误类型确定状态码
statusCode := http.StatusInternalServerError
message := "Internal server error"
errorDetail := err.Error()
// 这里可以根据不同的错误类型设置不同的状态码
// 例如ValidationError -> 400, NotFoundError -> 404, etc.
response := interfaces.APIResponse{
Success: false,
Message: message,
Errors: errorDetail,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(statusCode, response)
}
// BadRequest 400错误响应
func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...interface{}) {
response := interfaces.APIResponse{
Success: false,
Message: message,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
if len(errors) > 0 {
response.Errors = errors[0]
}
c.JSON(http.StatusBadRequest, response)
}
// Unauthorized 401错误响应
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
msg := "Unauthorized"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusUnauthorized, response)
}
// Forbidden 403错误响应
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
msg := "Forbidden"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusForbidden, response)
}
// NotFound 404错误响应
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
msg := "Resource not found"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusNotFound, response)
}
// Conflict 409错误响应
func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
response := interfaces.APIResponse{
Success: false,
Message: message,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusConflict, response)
}
// InternalError 500错误响应
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
msg := "Internal server error"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusInternalServerError, response)
}
// Paginated 分页响应
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
response := interfaces.APIResponse{
Success: true,
Message: "Success",
Data: data,
Pagination: &pagination,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusOK, response)
}
// getRequestID 从上下文获取请求ID
func (r *ResponseBuilder) getRequestID(c *gin.Context) string {
if requestID, exists := c.Get("request_id"); exists {
if id, ok := requestID.(string); ok {
return id
}
}
return ""
}
// BuildPagination 构建分页元数据
func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta {
totalPages := int(math.Ceil(float64(total) / float64(pageSize)))
if totalPages < 1 {
totalPages = 1
}
return interfaces.PaginationMeta{
Page: page,
PageSize: pageSize,
Total: total,
TotalPages: totalPages,
HasNext: page < totalPages,
HasPrev: page > 1,
}
}
// CustomResponse 自定义响应
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
response := interfaces.APIResponse{
Success: statusCode >= 200 && statusCode < 300,
Message: http.StatusText(statusCode),
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(statusCode, response)
}
// ValidationError 验证错误响应
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
response := interfaces.APIResponse{
Success: false,
Message: "Validation failed",
Errors: errors,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
}
c.JSON(http.StatusUnprocessableEntity, response)
}
// TooManyRequests 限流错误响应
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
msg := "Too many requests"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
response := interfaces.APIResponse{
Success: false,
Message: msg,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
Meta: map[string]interface{}{
"retry_after": "60s",
},
}
c.JSON(http.StatusTooManyRequests, response)
}

View File

@@ -0,0 +1,258 @@
package http
import (
"context"
"fmt"
"net/http"
"sort"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"tyapi-server/internal/config"
"tyapi-server/internal/shared/interfaces"
)
// GinRouter Gin路由器实现
type GinRouter struct {
engine *gin.Engine
config *config.Config
logger *zap.Logger
middlewares []interfaces.Middleware
server *http.Server
}
// NewGinRouter 创建Gin路由器
func NewGinRouter(cfg *config.Config, logger *zap.Logger) *GinRouter {
// 设置Gin模式
if cfg.App.IsProduction() {
gin.SetMode(gin.ReleaseMode)
} else {
gin.SetMode(gin.DebugMode)
}
// 创建Gin引擎
engine := gin.New()
return &GinRouter{
engine: engine,
config: cfg,
logger: logger,
middlewares: make([]interfaces.Middleware, 0),
}
}
// RegisterHandler 注册处理器
func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
// 应用处理器中间件
middlewares := handler.GetMiddlewares()
// 注册路由
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
r.logger.Info("Registered HTTP handler",
zap.String("method", handler.GetMethod()),
zap.String("path", handler.GetPath()))
return nil
}
// RegisterMiddleware 注册中间件
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
r.middlewares = append(r.middlewares, middleware)
r.logger.Info("Registered middleware",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
return nil
}
// RegisterGroup 注册路由组
func (r *GinRouter) RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes {
return r.engine.Group(prefix, middlewares...)
}
// GetRoutes 获取路由信息
func (r *GinRouter) GetRoutes() gin.RoutesInfo {
return r.engine.Routes()
}
// Start 启动路由器
func (r *GinRouter) Start(addr string) error {
// 应用中间件(按优先级排序)
r.applyMiddlewares()
// 创建HTTP服务器
r.server = &http.Server{
Addr: addr,
Handler: r.engine,
ReadTimeout: r.config.Server.ReadTimeout,
WriteTimeout: r.config.Server.WriteTimeout,
IdleTimeout: r.config.Server.IdleTimeout,
}
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
// 启动服务器
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
return fmt.Errorf("failed to start server: %w", err)
}
return nil
}
// Stop 停止路由器
func (r *GinRouter) Stop(ctx context.Context) error {
if r.server == nil {
return nil
}
r.logger.Info("Stopping HTTP server...")
// 优雅关闭服务器
if err := r.server.Shutdown(ctx); err != nil {
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
return err
}
r.logger.Info("HTTP server stopped")
return nil
}
// GetEngine 获取Gin引擎
func (r *GinRouter) GetEngine() *gin.Engine {
return r.engine
}
// applyMiddlewares 应用中间件
func (r *GinRouter) applyMiddlewares() {
// 按优先级排序中间件
sort.Slice(r.middlewares, func(i, j int) bool {
return r.middlewares[i].GetPriority() > r.middlewares[j].GetPriority()
})
// 应用全局中间件
for _, middleware := range r.middlewares {
if middleware.IsGlobal() {
r.engine.Use(middleware.Handle())
r.logger.Debug("Applied global middleware",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
}
}
}
// SetupDefaultRoutes 设置默认路由
func (r *GinRouter) SetupDefaultRoutes() {
// 健康检查
r.engine.GET("/health", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"timestamp": time.Now().Unix(),
"service": r.config.App.Name,
"version": r.config.App.Version,
})
})
// API信息
r.engine.GET("/info", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"name": r.config.App.Name,
"version": r.config.App.Version,
"environment": r.config.App.Env,
"timestamp": time.Now().Unix(),
})
})
// 404处理
r.engine.NoRoute(func(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Route not found",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),
})
})
// 405处理
r.engine.NoMethod(func(c *gin.Context) {
c.JSON(http.StatusMethodNotAllowed, gin.H{
"success": false,
"message": "Method not allowed",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),
})
})
}
// PrintRoutes 打印路由信息
func (r *GinRouter) PrintRoutes() {
routes := r.GetRoutes()
r.logger.Info("Registered routes:")
for _, route := range routes {
r.logger.Info("Route",
zap.String("method", route.Method),
zap.String("path", route.Path),
zap.String("handler", route.Handler))
}
}
// GetStats 获取路由器统计信息
func (r *GinRouter) GetStats() map[string]interface{} {
routes := r.GetRoutes()
stats := map[string]interface{}{
"total_routes": len(routes),
"total_middlewares": len(r.middlewares),
"server_config": map[string]interface{}{
"read_timeout": r.config.Server.ReadTimeout,
"write_timeout": r.config.Server.WriteTimeout,
"idle_timeout": r.config.Server.IdleTimeout,
},
}
// 按方法统计路由数量
methodStats := make(map[string]int)
for _, route := range routes {
methodStats[route.Method]++
}
stats["routes_by_method"] = methodStats
// 中间件统计
middlewareStats := make([]map[string]interface{}, 0, len(r.middlewares))
for _, middleware := range r.middlewares {
middlewareStats = append(middlewareStats, map[string]interface{}{
"name": middleware.GetName(),
"priority": middleware.GetPriority(),
"global": middleware.IsGlobal(),
})
}
stats["middlewares"] = middlewareStats
return stats
}
// EnableMetrics 启用指标收集
func (r *GinRouter) EnableMetrics(collector interfaces.MetricsCollector) {
r.engine.Use(func(c *gin.Context) {
start := time.Now()
c.Next()
duration := time.Since(start).Seconds()
collector.RecordHTTPRequest(c.Request.Method, c.FullPath(), c.Writer.Status(), duration)
})
}
// EnableProfiling 启用性能分析
func (r *GinRouter) EnableProfiling() {
if r.config.Development.EnableProfiler {
// 这里可以集成pprof
r.logger.Info("Profiling enabled")
}
}

View File

@@ -0,0 +1,273 @@
package http
import (
"fmt"
"strings"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/go-playground/validator/v10"
)
// RequestValidator 请求验证器实现
type RequestValidator struct {
validator *validator.Validate
response interfaces.ResponseBuilder
}
// NewRequestValidator 创建请求验证器
func NewRequestValidator(response interfaces.ResponseBuilder) interfaces.RequestValidator {
v := validator.New()
// 注册自定义验证器
registerCustomValidators(v)
return &RequestValidator{
validator: v,
response: response,
}
}
// Validate 验证请求体
func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
return err
}
return nil
}
// ValidateQuery 验证查询参数
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindQuery(dto); err != nil {
v.response.BadRequest(c, "Invalid query parameters", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
return err
}
return nil
}
// ValidateParam 验证路径参数
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindUri(dto); err != nil {
v.response.BadRequest(c, "Invalid path parameters", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
return err
}
return nil
}
// BindAndValidate 绑定并验证请求
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
// 绑定请求体
if err := c.ShouldBindJSON(dto); err != nil {
v.response.BadRequest(c, "Invalid request body", err.Error())
return err
}
// 验证数据
return v.Validate(c, dto)
}
// formatValidationErrors 格式化验证错误
func (v *RequestValidator) formatValidationErrors(err error) map[string][]string {
errors := make(map[string][]string)
if validationErrors, ok := err.(validator.ValidationErrors); ok {
for _, fieldError := range validationErrors {
fieldName := v.getFieldName(fieldError)
errorMessage := v.getErrorMessage(fieldError)
if _, exists := errors[fieldName]; !exists {
errors[fieldName] = []string{}
}
errors[fieldName] = append(errors[fieldName], errorMessage)
}
}
return errors
}
// getFieldName 获取字段名JSON标签优先
func (v *RequestValidator) getFieldName(fieldError validator.FieldError) string {
// 可以通过反射获取JSON标签这里简化处理
fieldName := fieldError.Field()
// 转换为snake_case可选
return v.toSnakeCase(fieldName)
}
// getErrorMessage 获取错误消息
func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) string {
field := fieldError.Field()
tag := fieldError.Tag()
param := fieldError.Param()
switch tag {
case "required":
return fmt.Sprintf("%s is required", field)
case "email":
return fmt.Sprintf("%s must be a valid email address", field)
case "min":
return fmt.Sprintf("%s must be at least %s characters", field, param)
case "max":
return fmt.Sprintf("%s must be at most %s characters", field, param)
case "len":
return fmt.Sprintf("%s must be exactly %s characters", field, param)
case "gt":
return fmt.Sprintf("%s must be greater than %s", field, param)
case "gte":
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
case "lt":
return fmt.Sprintf("%s must be less than %s", field, param)
case "lte":
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
case "oneof":
return fmt.Sprintf("%s must be one of [%s]", field, param)
case "url":
return fmt.Sprintf("%s must be a valid URL", field)
case "alpha":
return fmt.Sprintf("%s must contain only alphabetic characters", field)
case "alphanum":
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
case "numeric":
return fmt.Sprintf("%s must be numeric", field)
case "phone":
return fmt.Sprintf("%s must be a valid phone number", field)
case "username":
return fmt.Sprintf("%s must be a valid username", field)
default:
return fmt.Sprintf("%s is invalid", field)
}
}
// toSnakeCase 转换为snake_case
func (v *RequestValidator) toSnakeCase(str string) string {
var result strings.Builder
for i, r := range str {
if i > 0 && (r >= 'A' && r <= 'Z') {
result.WriteRune('_')
}
result.WriteRune(r)
}
return strings.ToLower(result.String())
}
// registerCustomValidators 注册自定义验证器
func registerCustomValidators(v *validator.Validate) {
// 注册手机号验证器
v.RegisterValidation("phone", validatePhone)
// 注册用户名验证器
v.RegisterValidation("username", validateUsername)
// 注册密码强度验证器
v.RegisterValidation("strong_password", validateStrongPassword)
}
// validatePhone 验证手机号
func validatePhone(fl validator.FieldLevel) bool {
phone := fl.Field().String()
if phone == "" {
return true // 空值由required标签处理
}
// 简单的手机号验证(可根据需要完善)
if len(phone) < 10 || len(phone) > 15 {
return false
}
// 检查是否以+开头或全是数字
if strings.HasPrefix(phone, "+") {
phone = phone[1:]
}
for _, r := range phone {
if r < '0' || r > '9' {
if r != '-' && r != ' ' && r != '(' && r != ')' {
return false
}
}
}
return true
}
// validateUsername 验证用户名
func validateUsername(fl validator.FieldLevel) bool {
username := fl.Field().String()
if username == "" {
return true // 空值由required标签处理
}
// 用户名规则3-30个字符只能包含字母、数字、下划线不能以数字开头
if len(username) < 3 || len(username) > 30 {
return false
}
// 不能以数字开头
if username[0] >= '0' && username[0] <= '9' {
return false
}
// 只能包含字母、数字、下划线
for _, r := range username {
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
return false
}
}
return true
}
// validateStrongPassword 验证密码强度
func validateStrongPassword(fl validator.FieldLevel) bool {
password := fl.Field().String()
if password == "" {
return true // 空值由required标签处理
}
// 密码强度规则至少8个字符包含大小写字母、数字
if len(password) < 8 {
return false
}
hasUpper := false
hasLower := false
hasDigit := false
for _, r := range password {
switch {
case r >= 'A' && r <= 'Z':
hasUpper = true
case r >= 'a' && r <= 'z':
hasLower = true
case r >= '0' && r <= '9':
hasDigit = true
}
}
return hasUpper && hasLower && hasDigit
}
// ValidateStruct 直接验证结构体不通过HTTP上下文
func (v *RequestValidator) ValidateStruct(dto interface{}) error {
return v.validator.Struct(dto)
}
// GetValidator 获取原始验证器(用于特殊情况)
func (v *RequestValidator) GetValidator() *validator.Validate {
return v.validator
}

View File

@@ -0,0 +1,92 @@
package interfaces
import (
"context"
"time"
)
// Event 事件接口
type Event interface {
// 事件基础信息
GetID() string
GetType() string
GetVersion() string
GetTimestamp() time.Time
// 事件数据
GetPayload() interface{}
GetMetadata() map[string]interface{}
// 事件来源
GetSource() string
GetAggregateID() string
GetAggregateType() string
// 序列化
Marshal() ([]byte, error)
Unmarshal(data []byte) error
}
// EventHandler 事件处理器接口
type EventHandler interface {
// 处理器标识
GetName() string
GetEventTypes() []string
// 事件处理
Handle(ctx context.Context, event Event) error
// 处理器配置
IsAsync() bool
GetRetryConfig() RetryConfig
}
// DomainEvent 领域事件基础接口
type DomainEvent interface {
Event
// 领域特定信息
GetDomainVersion() string
GetCausationID() string
GetCorrelationID() string
}
// RetryConfig 重试配置
type RetryConfig struct {
MaxRetries int `json:"max_retries"`
RetryDelay time.Duration `json:"retry_delay"`
BackoffFactor float64 `json:"backoff_factor"`
MaxDelay time.Duration `json:"max_delay"`
}
// EventStore 事件存储接口
type EventStore interface {
// 事件存储
SaveEvent(ctx context.Context, event Event) error
SaveEvents(ctx context.Context, events []Event) error
// 事件查询
GetEvents(ctx context.Context, aggregateID string, fromVersion int) ([]Event, error)
GetEventsByType(ctx context.Context, eventType string, limit int) ([]Event, error)
GetEventsSince(ctx context.Context, timestamp time.Time, limit int) ([]Event, error)
// 快照支持
SaveSnapshot(ctx context.Context, aggregateID string, snapshot interface{}) error
GetSnapshot(ctx context.Context, aggregateID string) (interface{}, error)
}
// EventBus 事件总线接口
type EventBus interface {
// 事件发布
Publish(ctx context.Context, event Event) error
PublishBatch(ctx context.Context, events []Event) error
// 事件订阅
Subscribe(eventType string, handler EventHandler) error
Unsubscribe(eventType string, handler EventHandler) error
// 订阅管理
GetSubscribers(eventType string) []EventHandler
Start(ctx context.Context) error
Stop(ctx context.Context) error
}

View File

@@ -0,0 +1,152 @@
package interfaces
import (
"context"
"net/http"
"github.com/gin-gonic/gin"
)
// HTTPHandler HTTP处理器接口
type HTTPHandler interface {
// 处理器信息
GetPath() string
GetMethod() string
GetMiddlewares() []gin.HandlerFunc
// 处理函数
Handle(c *gin.Context)
// 权限验证
RequiresAuth() bool
GetPermissions() []string
}
// RESTHandler REST风格处理器接口
type RESTHandler interface {
HTTPHandler
// CRUD操作
Create(c *gin.Context)
GetByID(c *gin.Context)
Update(c *gin.Context)
Delete(c *gin.Context)
List(c *gin.Context)
}
// Middleware 中间件接口
type Middleware interface {
// 中间件名称
GetName() string
// 中间件优先级
GetPriority() int
// 中间件处理函数
Handle() gin.HandlerFunc
// 是否全局中间件
IsGlobal() bool
}
// Router 路由器接口
type Router interface {
// 路由注册
RegisterHandler(handler HTTPHandler) error
RegisterMiddleware(middleware Middleware) error
RegisterGroup(prefix string, middlewares ...gin.HandlerFunc) gin.IRoutes
// 路由管理
GetRoutes() gin.RoutesInfo
Start(addr string) error
Stop(ctx context.Context) error
// 引擎获取
GetEngine() *gin.Engine
}
// ResponseBuilder 响应构建器接口
type ResponseBuilder interface {
// 成功响应
Success(c *gin.Context, data interface{}, message ...string)
Created(c *gin.Context, data interface{}, message ...string)
// 错误响应
Error(c *gin.Context, err error)
BadRequest(c *gin.Context, message string, errors ...interface{})
Unauthorized(c *gin.Context, message ...string)
Forbidden(c *gin.Context, message ...string)
NotFound(c *gin.Context, message ...string)
Conflict(c *gin.Context, message string)
InternalError(c *gin.Context, message ...string)
// 分页响应
Paginated(c *gin.Context, data interface{}, pagination PaginationMeta)
}
// RequestValidator 请求验证器接口
type RequestValidator interface {
// 验证请求
Validate(c *gin.Context, dto interface{}) error
ValidateQuery(c *gin.Context, dto interface{}) error
ValidateParam(c *gin.Context, dto interface{}) error
// 绑定和验证
BindAndValidate(c *gin.Context, dto interface{}) error
}
// PaginationMeta 分页元数据
type PaginationMeta struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Total int64 `json:"total"`
TotalPages int `json:"total_pages"`
HasNext bool `json:"has_next"`
HasPrev bool `json:"has_prev"`
}
// APIResponse 标准API响应结构
type APIResponse struct {
Success bool `json:"success"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
Errors interface{} `json:"errors,omitempty"`
Pagination *PaginationMeta `json:"pagination,omitempty"`
Meta map[string]interface{} `json:"meta,omitempty"`
RequestID string `json:"request_id"`
Timestamp int64 `json:"timestamp"`
}
// HealthChecker 健康检查器接口
type HealthChecker interface {
// 健康检查
CheckHealth(ctx context.Context) HealthStatus
GetName() string
GetDependencies() []string
}
// HealthStatus 健康状态
type HealthStatus struct {
Status string `json:"status"` // UP, DOWN, DEGRADED
Message string `json:"message"`
Details map[string]interface{} `json:"details"`
CheckedAt int64 `json:"checked_at"`
ResponseTime int64 `json:"response_time_ms"`
}
// MetricsCollector 指标收集器接口
type MetricsCollector interface {
// HTTP指标
RecordHTTPRequest(method, path string, status int, duration float64)
RecordHTTPDuration(method, path string, duration float64)
// 业务指标
IncrementCounter(name string, labels map[string]string)
RecordGauge(name string, value float64, labels map[string]string)
RecordHistogram(name string, value float64, labels map[string]string)
// 自定义指标
RegisterCounter(name, help string, labels []string) error
RegisterGauge(name, help string, labels []string) error
RegisterHistogram(name, help string, labels []string, buckets []float64) error
// 指标导出
GetHandler() http.Handler
}

View File

@@ -0,0 +1,74 @@
package interfaces
import (
"context"
"time"
)
// Entity 通用实体接口
type Entity interface {
GetID() string
GetCreatedAt() time.Time
GetUpdatedAt() time.Time
}
// BaseRepository 基础仓储接口
type BaseRepository interface {
// 基础操作
Delete(ctx context.Context, id string) error
Count(ctx context.Context, options CountOptions) (int64, error)
Exists(ctx context.Context, id string) (bool, error)
// 软删除支持
SoftDelete(ctx context.Context, id string) error
Restore(ctx context.Context, id string) error
}
// Repository 通用仓储接口,支持泛型
type Repository[T any] interface {
BaseRepository
// 基础CRUD操作
Create(ctx context.Context, entity T) error
GetByID(ctx context.Context, id string) (T, error)
Update(ctx context.Context, entity T) error
// 批量操作
CreateBatch(ctx context.Context, entities []T) error
GetByIDs(ctx context.Context, ids []string) ([]T, error)
UpdateBatch(ctx context.Context, entities []T) error
DeleteBatch(ctx context.Context, ids []string) error
// 查询操作
List(ctx context.Context, options ListOptions) ([]T, error)
// 事务支持
WithTx(tx interface{}) Repository[T]
}
// ListOptions 列表查询选项
type ListOptions struct {
Page int `json:"page"`
PageSize int `json:"page_size"`
Sort string `json:"sort"`
Order string `json:"order"`
Filters map[string]interface{} `json:"filters"`
Search string `json:"search"`
Include []string `json:"include"`
}
// CountOptions 计数查询选项
type CountOptions struct {
Filters map[string]interface{} `json:"filters"`
Search string `json:"search"`
}
// CachedRepository 支持缓存的仓储接口
type CachedRepository[T Entity] interface {
Repository[T]
// 缓存操作
InvalidateCache(ctx context.Context, keys ...string) error
WarmupCache(ctx context.Context) error
GetCacheKey(id string) string
}

View File

@@ -0,0 +1,101 @@
package interfaces
import (
"context"
)
// Service 通用服务接口
type Service interface {
// 服务名称
Name() string
// 服务初始化
Initialize(ctx context.Context) error
// 服务健康检查
HealthCheck(ctx context.Context) error
// 服务关闭
Shutdown(ctx context.Context) error
}
// DomainService 领域服务接口,支持泛型
type DomainService[T Entity] interface {
Service
// 基础业务操作
Create(ctx context.Context, dto interface{}) (*T, error)
GetByID(ctx context.Context, id string) (*T, error)
Update(ctx context.Context, id string, dto interface{}) (*T, error)
Delete(ctx context.Context, id string) error
// 列表和查询
List(ctx context.Context, options ListOptions) ([]*T, error)
Search(ctx context.Context, query string, options ListOptions) ([]*T, error)
Count(ctx context.Context, options CountOptions) (int64, error)
// 业务规则验证
Validate(ctx context.Context, entity *T) error
ValidateCreate(ctx context.Context, dto interface{}) error
ValidateUpdate(ctx context.Context, id string, dto interface{}) error
}
// EventService 事件服务接口
type EventService interface {
Service
// 事件发布
Publish(ctx context.Context, event Event) error
PublishBatch(ctx context.Context, events []Event) error
// 事件订阅
Subscribe(eventType string, handler EventHandler) error
Unsubscribe(eventType string, handler EventHandler) error
// 异步处理
PublishAsync(ctx context.Context, event Event) error
}
// CacheService 缓存服务接口
type CacheService interface {
Service
// 基础缓存操作
Get(ctx context.Context, key string, dest interface{}) error
Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error
Delete(ctx context.Context, keys ...string) error
Exists(ctx context.Context, key string) (bool, error)
// 批量操作
GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error)
SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error
// 模式操作
DeletePattern(ctx context.Context, pattern string) error
Keys(ctx context.Context, pattern string) ([]string, error)
// 缓存统计
Stats(ctx context.Context) (CacheStats, error)
}
// CacheStats 缓存统计信息
type CacheStats struct {
Hits int64 `json:"hits"`
Misses int64 `json:"misses"`
Keys int64 `json:"keys"`
Memory int64 `json:"memory"`
Connections int64 `json:"connections"`
}
// TransactionService 事务服务接口
type TransactionService interface {
Service
// 事务操作
Begin(ctx context.Context) (Transaction, error)
RunInTransaction(ctx context.Context, fn func(Transaction) error) error
}
// Transaction 事务接口
type Transaction interface {
Commit() error
Rollback() error
GetDB() interface{}
}

View File

@@ -0,0 +1,241 @@
package logger
import (
"context"
"fmt"
"os"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Logger 日志接口
type Logger interface {
Debug(msg string, fields ...zapcore.Field)
Info(msg string, fields ...zapcore.Field)
Warn(msg string, fields ...zapcore.Field)
Error(msg string, fields ...zapcore.Field)
Fatal(msg string, fields ...zapcore.Field)
Panic(msg string, fields ...zapcore.Field)
With(fields ...zapcore.Field) Logger
WithContext(ctx context.Context) Logger
Sync() error
}
// ZapLogger Zap日志实现
type ZapLogger struct {
logger *zap.Logger
}
// Config 日志配置
type Config struct {
Level string
Format string
Output string
FilePath string
MaxSize int
MaxBackups int
MaxAge int
Compress bool
}
// NewLogger 创建新的日志实例
func NewLogger(config Config) (Logger, error) {
// 设置日志级别
level, err := zapcore.ParseLevel(config.Level)
if err != nil {
return nil, fmt.Errorf("无效的日志级别: %w", err)
}
// 配置编码器
var encoder zapcore.Encoder
encoderConfig := getEncoderConfig()
switch config.Format {
case "json":
encoder = zapcore.NewJSONEncoder(encoderConfig)
case "console":
encoder = zapcore.NewConsoleEncoder(encoderConfig)
default:
encoder = zapcore.NewJSONEncoder(encoderConfig)
}
// 配置输出
var writeSyncer zapcore.WriteSyncer
switch config.Output {
case "stdout":
writeSyncer = zapcore.AddSync(os.Stdout)
case "stderr":
writeSyncer = zapcore.AddSync(os.Stderr)
case "file":
if config.FilePath == "" {
config.FilePath = "logs/app.log"
}
// 确保目录存在
if err := os.MkdirAll("logs", 0755); err != nil {
return nil, fmt.Errorf("创建日志目录失败: %w", err)
}
file, err := os.OpenFile(config.FilePath, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0666)
if err != nil {
return nil, fmt.Errorf("打开日志文件失败: %w", err)
}
writeSyncer = zapcore.AddSync(file)
default:
writeSyncer = zapcore.AddSync(os.Stdout)
}
// 创建核心
core := zapcore.NewCore(encoder, writeSyncer, level)
// 创建logger
logger := zap.New(core, zap.AddCaller(), zap.AddStacktrace(zapcore.ErrorLevel))
return &ZapLogger{
logger: logger,
}, nil
}
// getEncoderConfig 获取编码器配置
func getEncoderConfig() zapcore.EncoderConfig {
return zapcore.EncoderConfig{
TimeKey: "timestamp",
LevelKey: "level",
NameKey: "logger",
CallerKey: "caller",
FunctionKey: zapcore.OmitKey,
MessageKey: "message",
StacktraceKey: "stacktrace",
LineEnding: zapcore.DefaultLineEnding,
EncodeLevel: zapcore.LowercaseLevelEncoder,
EncodeTime: zapcore.ISO8601TimeEncoder,
EncodeDuration: zapcore.StringDurationEncoder,
EncodeCaller: zapcore.ShortCallerEncoder,
}
}
// Debug 调试日志
func (l *ZapLogger) Debug(msg string, fields ...zapcore.Field) {
l.logger.Debug(msg, fields...)
}
// Info 信息日志
func (l *ZapLogger) Info(msg string, fields ...zapcore.Field) {
l.logger.Info(msg, fields...)
}
// Warn 警告日志
func (l *ZapLogger) Warn(msg string, fields ...zapcore.Field) {
l.logger.Warn(msg, fields...)
}
// Error 错误日志
func (l *ZapLogger) Error(msg string, fields ...zapcore.Field) {
l.logger.Error(msg, fields...)
}
// Fatal 致命错误日志
func (l *ZapLogger) Fatal(msg string, fields ...zapcore.Field) {
l.logger.Fatal(msg, fields...)
}
// Panic 恐慌日志
func (l *ZapLogger) Panic(msg string, fields ...zapcore.Field) {
l.logger.Panic(msg, fields...)
}
// With 添加字段
func (l *ZapLogger) With(fields ...zapcore.Field) Logger {
return &ZapLogger{
logger: l.logger.With(fields...),
}
}
// WithContext 从上下文添加字段
func (l *ZapLogger) WithContext(ctx context.Context) Logger {
// 从上下文中提取常用字段
fields := []zapcore.Field{}
if traceID := getTraceIDFromContext(ctx); traceID != "" {
fields = append(fields, zap.String("trace_id", traceID))
}
if userID := getUserIDFromContext(ctx); userID != "" {
fields = append(fields, zap.String("user_id", userID))
}
if requestID := getRequestIDFromContext(ctx); requestID != "" {
fields = append(fields, zap.String("request_id", requestID))
}
return l.With(fields...)
}
// Sync 同步日志
func (l *ZapLogger) Sync() error {
return l.logger.Sync()
}
// getTraceIDFromContext 从上下文获取追踪ID
func getTraceIDFromContext(ctx context.Context) string {
if traceID := ctx.Value("trace_id"); traceID != nil {
if id, ok := traceID.(string); ok {
return id
}
}
return ""
}
// getUserIDFromContext 从上下文获取用户ID
func getUserIDFromContext(ctx context.Context) string {
if userID := ctx.Value("user_id"); userID != nil {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}
// getRequestIDFromContext 从上下文获取请求ID
func getRequestIDFromContext(ctx context.Context) string {
if requestID := ctx.Value("request_id"); requestID != nil {
if id, ok := requestID.(string); ok {
return id
}
}
return ""
}
// Field 创建日志字段的便捷函数
func String(key, val string) zapcore.Field {
return zap.String(key, val)
}
func Int(key string, val int) zapcore.Field {
return zap.Int(key, val)
}
func Int64(key string, val int64) zapcore.Field {
return zap.Int64(key, val)
}
func Float64(key string, val float64) zapcore.Field {
return zap.Float64(key, val)
}
func Bool(key string, val bool) zapcore.Field {
return zap.Bool(key, val)
}
func Error(err error) zapcore.Field {
return zap.Error(err)
}
func Any(key string, val interface{}) zapcore.Field {
return zap.Any(key, val)
}
func Duration(key string, val interface{}) zapcore.Field {
return zap.Any(key, val)
}

View File

@@ -0,0 +1,261 @@
package middleware
import (
"net/http"
"strings"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"go.uber.org/zap"
)
// JWTAuthMiddleware JWT认证中间件
type JWTAuthMiddleware struct {
config *config.Config
logger *zap.Logger
}
// NewJWTAuthMiddleware 创建JWT认证中间件
func NewJWTAuthMiddleware(cfg *config.Config, logger *zap.Logger) *JWTAuthMiddleware {
return &JWTAuthMiddleware{
config: cfg,
logger: logger,
}
}
// GetName 返回中间件名称
func (m *JWTAuthMiddleware) GetName() string {
return "jwt_auth"
}
// GetPriority 返回中间件优先级
func (m *JWTAuthMiddleware) GetPriority() int {
return 60 // 中等优先级,在日志之后,业务处理之前
}
// Handle 返回中间件处理函数
func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
m.respondUnauthorized(c, "Missing authorization header")
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
m.respondUnauthorized(c, "Invalid authorization header format")
return
}
// 提取token
tokenString := authHeader[len(bearerPrefix):]
if tokenString == "" {
m.respondUnauthorized(c, "Missing token")
return
}
// 验证token
claims, err := m.validateToken(tokenString)
if err != nil {
m.logger.Warn("Invalid token",
zap.Error(err),
zap.String("request_id", c.GetString("request_id")))
m.respondUnauthorized(c, "Invalid token")
return
}
// 将用户信息添加到上下文
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("token_claims", claims)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *JWTAuthMiddleware) IsGlobal() bool {
return false // 不是全局中间件,需要手动应用到需要认证的路由
}
// JWTClaims JWT声明结构
type JWTClaims struct {
UserID string `json:"user_id"`
Username string `json:"username"`
Email string `json:"email"`
jwt.RegisteredClaims
}
// validateToken 验证JWT token
func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &JWTClaims{}, func(token *jwt.Token) (interface{}, error) {
// 验证签名方法
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(m.config.JWT.Secret), nil
})
if err != nil {
return nil, err
}
claims, ok := token.Claims.(*JWTClaims)
if !ok || !token.Valid {
return nil, jwt.ErrSignatureInvalid
}
return claims, nil
}
// respondUnauthorized 返回未授权响应
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "Unauthorized",
"error": message,
"request_id": c.GetString("request_id"),
"timestamp": time.Now().Unix(),
})
c.Abort()
}
// GenerateToken 生成JWT token
func (m *JWTAuthMiddleware) GenerateToken(userID, username, email string) (string, error) {
now := time.Now()
claims := &JWTClaims{
UserID: userID,
Username: username,
Email: email,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "tyapi-server",
Subject: userID,
Audience: []string{"tyapi-client"},
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.ExpiresIn)),
NotBefore: jwt.NewNumericDate(now),
IssuedAt: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(m.config.JWT.Secret))
}
// GenerateRefreshToken 生成刷新token
func (m *JWTAuthMiddleware) GenerateRefreshToken(userID string) (string, error) {
now := time.Now()
claims := &jwt.RegisteredClaims{
Issuer: "tyapi-server",
Subject: userID,
Audience: []string{"tyapi-refresh"},
ExpiresAt: jwt.NewNumericDate(now.Add(m.config.JWT.RefreshExpiresIn)),
NotBefore: jwt.NewNumericDate(now),
IssuedAt: jwt.NewNumericDate(now),
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(m.config.JWT.Secret))
}
// ValidateRefreshToken 验证刷新token
func (m *JWTAuthMiddleware) ValidateRefreshToken(tokenString string) (string, error) {
token, err := jwt.ParseWithClaims(tokenString, &jwt.RegisteredClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, jwt.ErrSignatureInvalid
}
return []byte(m.config.JWT.Secret), nil
})
if err != nil {
return "", err
}
claims, ok := token.Claims.(*jwt.RegisteredClaims)
if !ok || !token.Valid {
return "", jwt.ErrSignatureInvalid
}
// 检查是否为刷新token
if len(claims.Audience) == 0 || claims.Audience[0] != "tyapi-refresh" {
return "", jwt.ErrSignatureInvalid
}
return claims.Subject, nil
}
// OptionalAuthMiddleware 可选认证中间件(用户可能登录也可能未登录)
type OptionalAuthMiddleware struct {
jwtAuth *JWTAuthMiddleware
}
// NewOptionalAuthMiddleware 创建可选认证中间件
func NewOptionalAuthMiddleware(jwtAuth *JWTAuthMiddleware) *OptionalAuthMiddleware {
return &OptionalAuthMiddleware{
jwtAuth: jwtAuth,
}
}
// GetName 返回中间件名称
func (m *OptionalAuthMiddleware) GetName() string {
return "optional_auth"
}
// GetPriority 返回中间件优先级
func (m *OptionalAuthMiddleware) GetPriority() int {
return 60 // 与JWT认证中间件相同
}
// Handle 返回中间件处理函数
func (m *OptionalAuthMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// 没有认证头部,设置匿名用户标识
c.Set("is_authenticated", false)
c.Next()
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
c.Set("is_authenticated", false)
c.Next()
return
}
// 提取并验证token
tokenString := authHeader[len(bearerPrefix):]
claims, err := m.jwtAuth.validateToken(tokenString)
if err != nil {
// token无效但不返回错误设置为未认证
c.Set("is_authenticated", false)
c.Next()
return
}
// token有效设置用户信息
c.Set("is_authenticated", true)
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("email", claims.Email)
c.Set("token_claims", claims)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *OptionalAuthMiddleware) IsGlobal() bool {
return false
}

View File

@@ -0,0 +1,104 @@
package middleware
import (
"tyapi-server/internal/config"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
)
// CORSMiddleware CORS中间件
type CORSMiddleware struct {
config *config.Config
}
// NewCORSMiddleware 创建CORS中间件
func NewCORSMiddleware(cfg *config.Config) *CORSMiddleware {
return &CORSMiddleware{
config: cfg,
}
}
// GetName 返回中间件名称
func (m *CORSMiddleware) GetName() string {
return "cors"
}
// GetPriority 返回中间件优先级
func (m *CORSMiddleware) GetPriority() int {
return 100 // 高优先级,最先执行
}
// Handle 返回中间件处理函数
func (m *CORSMiddleware) Handle() gin.HandlerFunc {
if !m.config.Development.EnableCors {
// 如果没有启用CORS返回空处理函数
return func(c *gin.Context) {
c.Next()
}
}
config := cors.Config{
AllowAllOrigins: false,
AllowOrigins: m.getAllowedOrigins(),
AllowMethods: m.getAllowedMethods(),
AllowHeaders: m.getAllowedHeaders(),
ExposeHeaders: []string{
"Content-Length",
"Content-Type",
"X-Request-ID",
"X-Response-Time",
},
AllowCredentials: true,
MaxAge: 86400, // 24小时
}
return cors.New(config)
}
// IsGlobal 是否为全局中间件
func (m *CORSMiddleware) IsGlobal() bool {
return true
}
// getAllowedOrigins 获取允许的来源
func (m *CORSMiddleware) getAllowedOrigins() []string {
if m.config.Development.CorsOrigins == "" {
return []string{"http://localhost:3000", "http://localhost:8080"}
}
// TODO: 解析配置中的origins字符串
return []string{m.config.Development.CorsOrigins}
}
// getAllowedMethods 获取允许的方法
func (m *CORSMiddleware) getAllowedMethods() []string {
if m.config.Development.CorsMethods == "" {
return []string{
"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS",
}
}
// TODO: 解析配置中的methods字符串
return []string{m.config.Development.CorsMethods}
}
// getAllowedHeaders 获取允许的头部
func (m *CORSMiddleware) getAllowedHeaders() []string {
if m.config.Development.CorsHeaders == "" {
return []string{
"Origin",
"Content-Length",
"Content-Type",
"Authorization",
"X-Requested-With",
"Accept",
"Accept-Encoding",
"Accept-Language",
"X-Request-ID",
}
}
// TODO: 解析配置中的headers字符串
return []string{m.config.Development.CorsHeaders}
}

View File

@@ -0,0 +1,166 @@
package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"tyapi-server/internal/config"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
)
// RateLimitMiddleware 限流中间件
type RateLimitMiddleware struct {
config *config.Config
limiters map[string]*rate.Limiter
mutex sync.RWMutex
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
return &RateLimitMiddleware{
config: cfg,
limiters: make(map[string]*rate.Limiter),
}
}
// GetName 返回中间件名称
func (m *RateLimitMiddleware) GetName() string {
return "ratelimit"
}
// GetPriority 返回中间件优先级
func (m *RateLimitMiddleware) GetPriority() int {
return 90 // 高优先级
}
// Handle 返回中间件处理函数
func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取客户端标识IP地址
clientID := m.getClientID(c)
// 获取或创建限流器
limiter := m.getLimiter(clientID)
// 检查是否允许请求
if !limiter.Allow() {
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Header("Retry-After", "60")
c.JSON(http.StatusTooManyRequests, gin.H{
"success": false,
"message": "Rate limit exceeded",
"error": "Too many requests",
})
c.Abort()
return
}
// 添加限流头部信息
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *RateLimitMiddleware) IsGlobal() bool {
return true
}
// getClientID 获取客户端标识
func (m *RateLimitMiddleware) getClientID(c *gin.Context) string {
// 优先使用X-Forwarded-For头部
if xff := c.GetHeader("X-Forwarded-For"); xff != "" {
return xff
}
// 使用X-Real-IP头部
if xri := c.GetHeader("X-Real-IP"); xri != "" {
return xri
}
// 使用RemoteAddr
return c.ClientIP()
}
// getLimiter 获取或创建限流器
func (m *RateLimitMiddleware) getLimiter(clientID string) *rate.Limiter {
m.mutex.RLock()
limiter, exists := m.limiters[clientID]
m.mutex.RUnlock()
if exists {
return limiter
}
m.mutex.Lock()
defer m.mutex.Unlock()
// 双重检查
if limiter, exists := m.limiters[clientID]; exists {
return limiter
}
// 创建新的限流器
// rate.Every计算每个请求之间的间隔
rateLimit := rate.Every(m.config.RateLimit.Window / time.Duration(m.config.RateLimit.Requests))
limiter = rate.NewLimiter(rateLimit, m.config.RateLimit.Burst)
m.limiters[clientID] = limiter
// 启动清理协程(仅第一次创建时)
if len(m.limiters) == 1 {
go m.cleanupRoutine()
}
return limiter
}
// cleanupRoutine 定期清理不活跃的限流器
func (m *RateLimitMiddleware) cleanupRoutine() {
ticker := time.NewTicker(10 * time.Minute) // 每10分钟清理一次
defer ticker.Stop()
for {
select {
case <-ticker.C:
m.cleanup()
}
}
}
// cleanup 清理不活跃的限流器
func (m *RateLimitMiddleware) cleanup() {
m.mutex.Lock()
defer m.mutex.Unlock()
now := time.Now()
for clientID, limiter := range m.limiters {
// 如果限流器在过去1小时内没有被使用则删除它
if limiter.Reserve().Delay() == 0 && now.Sub(time.Now()) > time.Hour {
delete(m.limiters, clientID)
}
}
}
// GetStats 获取限流统计
func (m *RateLimitMiddleware) GetStats() map[string]interface{} {
m.mutex.RLock()
defer m.mutex.RUnlock()
return map[string]interface{}{
"active_limiters": len(m.limiters),
"rate_limit": map[string]interface{}{
"requests": m.config.RateLimit.Requests,
"window": m.config.RateLimit.Window,
"burst": m.config.RateLimit.Burst,
},
}
}

View File

@@ -0,0 +1,241 @@
package middleware
import (
"bytes"
"io"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// RequestLoggerMiddleware 请求日志中间件
type RequestLoggerMiddleware struct {
logger *zap.Logger
}
// NewRequestLoggerMiddleware 创建请求日志中间件
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
return &RequestLoggerMiddleware{
logger: logger,
}
}
// GetName 返回中间件名称
func (m *RequestLoggerMiddleware) GetName() string {
return "request_logger"
}
// GetPriority 返回中间件优先级
func (m *RequestLoggerMiddleware) GetPriority() int {
return 80 // 中等优先级
}
// Handle 返回中间件处理函数
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// 使用zap logger记录请求信息
m.logger.Info("HTTP Request",
zap.String("client_ip", param.ClientIP),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.String("protocol", param.Request.Proto),
zap.Int("status_code", param.StatusCode),
zap.Duration("latency", param.Latency),
zap.String("user_agent", param.Request.UserAgent()),
zap.Int("body_size", param.BodySize),
zap.String("referer", param.Request.Referer()),
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
)
// 返回空字符串因为我们已经用zap记录了
return ""
})
}
// IsGlobal 是否为全局中间件
func (m *RequestLoggerMiddleware) IsGlobal() bool {
return true
}
// RequestIDMiddleware 请求ID中间件
type RequestIDMiddleware struct{}
// NewRequestIDMiddleware 创建请求ID中间件
func NewRequestIDMiddleware() *RequestIDMiddleware {
return &RequestIDMiddleware{}
}
// GetName 返回中间件名称
func (m *RequestIDMiddleware) GetName() string {
return "request_id"
}
// GetPriority 返回中间件优先级
func (m *RequestIDMiddleware) GetPriority() int {
return 95 // 最高优先级,第一个执行
}
// Handle 返回中间件处理函数
func (m *RequestIDMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 获取或生成请求ID
requestID := c.GetHeader("X-Request-ID")
if requestID == "" {
requestID = uuid.New().String()
}
// 设置请求ID到上下文和响应头
c.Set("request_id", requestID)
c.Header("X-Request-ID", requestID)
// 添加到响应头,方便客户端追踪
c.Writer.Header().Set("X-Request-ID", requestID)
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *RequestIDMiddleware) IsGlobal() bool {
return true
}
// SecurityHeadersMiddleware 安全头部中间件
type SecurityHeadersMiddleware struct{}
// NewSecurityHeadersMiddleware 创建安全头部中间件
func NewSecurityHeadersMiddleware() *SecurityHeadersMiddleware {
return &SecurityHeadersMiddleware{}
}
// GetName 返回中间件名称
func (m *SecurityHeadersMiddleware) GetName() string {
return "security_headers"
}
// GetPriority 返回中间件优先级
func (m *SecurityHeadersMiddleware) GetPriority() int {
return 85 // 高优先级
}
// Handle 返回中间件处理函数
func (m *SecurityHeadersMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
// 设置安全头部
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("X-XSS-Protection", "1; mode=block")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
c.Header("Content-Security-Policy", "default-src 'self'")
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *SecurityHeadersMiddleware) IsGlobal() bool {
return true
}
// ResponseTimeMiddleware 响应时间中间件
type ResponseTimeMiddleware struct{}
// NewResponseTimeMiddleware 创建响应时间中间件
func NewResponseTimeMiddleware() *ResponseTimeMiddleware {
return &ResponseTimeMiddleware{}
}
// GetName 返回中间件名称
func (m *ResponseTimeMiddleware) GetName() string {
return "response_time"
}
// GetPriority 返回中间件优先级
func (m *ResponseTimeMiddleware) GetPriority() int {
return 75 // 中等优先级
}
// Handle 返回中间件处理函数
func (m *ResponseTimeMiddleware) Handle() gin.HandlerFunc {
return func(c *gin.Context) {
start := time.Now()
c.Next()
// 计算响应时间并添加到头部
duration := time.Since(start)
c.Header("X-Response-Time", duration.String())
// 记录到上下文中,供其他中间件使用
c.Set("response_time", duration)
}
}
// IsGlobal 是否为全局中间件
func (m *ResponseTimeMiddleware) IsGlobal() bool {
return true
}
// RequestBodyLoggerMiddleware 请求体日志中间件(用于调试)
type RequestBodyLoggerMiddleware struct {
logger *zap.Logger
enable bool
}
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
return &RequestBodyLoggerMiddleware{
logger: logger,
enable: enable,
}
}
// GetName 返回中间件名称
func (m *RequestBodyLoggerMiddleware) GetName() string {
return "request_body_logger"
}
// GetPriority 返回中间件优先级
func (m *RequestBodyLoggerMiddleware) GetPriority() int {
return 70 // 较低优先级
}
// Handle 返回中间件处理函数
func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
if !m.enable {
return func(c *gin.Context) {
c.Next()
}
}
return func(c *gin.Context) {
// 只记录POST, PUT, PATCH请求的body
if c.Request.Method == "POST" || c.Request.Method == "PUT" || c.Request.Method == "PATCH" {
if c.Request.Body != nil {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err == nil {
// 重新设置body供后续处理使用
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
m.logger.Debug("Request Body",
zap.String("method", c.Request.Method),
zap.String("path", c.Request.URL.Path),
zap.String("body", string(bodyBytes)),
zap.String("request_id", c.GetString("request_id")),
)
}
}
}
c.Next()
}
}
// IsGlobal 是否为全局中间件
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
return false // 可选中间件,不是全局的
}