feat(架构): 完善基础架构设计
This commit is contained in:
@@ -36,7 +36,7 @@ func (h *HealthChecker) RegisterService(service interfaces.Service) {
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.services[service.Name()] = service
|
||||
h.logger.Info("Registered service for health check", zap.String("service", service.Name()))
|
||||
h.logger.Info("服务已注册健康检查", zap.String("service", service.Name()))
|
||||
}
|
||||
|
||||
// CheckHealth 检查单个服务健康状态
|
||||
@@ -47,8 +47,8 @@ func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *in
|
||||
h.mutex.RUnlock()
|
||||
return &interfaces.HealthStatus{
|
||||
Status: "DOWN",
|
||||
Message: "Service not found",
|
||||
Details: map[string]interface{}{"error": "service not registered"},
|
||||
Message: "服务未找到",
|
||||
Details: map[string]interface{}{"error": "服务未注册"},
|
||||
CheckedAt: time.Now().Unix(),
|
||||
ResponseTime: 0,
|
||||
}
|
||||
@@ -79,24 +79,24 @@ func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *in
|
||||
|
||||
if err != nil {
|
||||
status.Status = "DOWN"
|
||||
status.Message = "Health check failed"
|
||||
status.Message = "健康检查失败"
|
||||
status.Details = map[string]interface{}{
|
||||
"error": err.Error(),
|
||||
"service_name": serviceName,
|
||||
"check_time": start.Format(time.RFC3339),
|
||||
}
|
||||
h.logger.Warn("Service health check failed",
|
||||
h.logger.Warn("服务健康检查失败",
|
||||
zap.String("service", serviceName),
|
||||
zap.Error(err),
|
||||
zap.Int64("response_time_ms", responseTime))
|
||||
} else {
|
||||
status.Status = "UP"
|
||||
status.Message = "Service is healthy"
|
||||
status.Message = "服务运行正常"
|
||||
status.Details = map[string]interface{}{
|
||||
"service_name": serviceName,
|
||||
"check_time": start.Format(time.RFC3339),
|
||||
}
|
||||
h.logger.Debug("Service health check passed",
|
||||
h.logger.Debug("服务健康检查通过",
|
||||
zap.String("service", serviceName),
|
||||
zap.Int64("response_time_ms", responseTime))
|
||||
}
|
||||
@@ -173,13 +173,13 @@ func (h *HealthChecker) GetOverallStatus(ctx context.Context) *interfaces.Health
|
||||
// 确定整体状态
|
||||
if healthyCount == totalCount {
|
||||
overall.Status = "UP"
|
||||
overall.Message = "All services are healthy"
|
||||
overall.Message = "所有服务运行正常"
|
||||
} else if healthyCount == 0 {
|
||||
overall.Status = "DOWN"
|
||||
overall.Message = "All services are down"
|
||||
overall.Message = "所有服务均已下线"
|
||||
} else {
|
||||
overall.Status = "DEGRADED"
|
||||
overall.Message = fmt.Sprintf("%d of %d services are healthy", healthyCount, totalCount)
|
||||
overall.Message = fmt.Sprintf("%d/%d 个服务运行正常", healthyCount, totalCount)
|
||||
}
|
||||
|
||||
return overall
|
||||
@@ -205,7 +205,7 @@ func (h *HealthChecker) RemoveService(serviceName string) {
|
||||
delete(h.services, serviceName)
|
||||
delete(h.cache, serviceName)
|
||||
|
||||
h.logger.Info("Removed service from health check", zap.String("service", serviceName))
|
||||
h.logger.Info("服务已从健康检查中移除", zap.String("service", serviceName))
|
||||
}
|
||||
|
||||
// ClearCache 清除缓存
|
||||
@@ -214,7 +214,7 @@ func (h *HealthChecker) ClearCache() {
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.cache = make(map[string]*interfaces.HealthStatus)
|
||||
h.logger.Debug("Health check cache cleared")
|
||||
h.logger.Debug("健康检查缓存已清除")
|
||||
}
|
||||
|
||||
// GetCacheStats 获取缓存统计
|
||||
@@ -243,7 +243,7 @@ func (h *HealthChecker) SetCacheTTL(ttl time.Duration) {
|
||||
defer h.mutex.Unlock()
|
||||
|
||||
h.cacheTTL = ttl
|
||||
h.logger.Info("Updated health check cache TTL", zap.Duration("ttl", ttl))
|
||||
h.logger.Info("健康检查缓存TTL已更新", zap.Duration("ttl", ttl))
|
||||
}
|
||||
|
||||
// StartPeriodicCheck 启动定期健康检查
|
||||
@@ -251,12 +251,12 @@ func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Du
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
h.logger.Info("Started periodic health check", zap.Duration("interval", interval))
|
||||
h.logger.Info("已启动定期健康检查", zap.Duration("interval", interval))
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.logger.Info("Stopped periodic health check")
|
||||
h.logger.Info("已停止定期健康检查")
|
||||
return
|
||||
case <-ticker.C:
|
||||
h.performPeriodicCheck(ctx)
|
||||
@@ -268,14 +268,14 @@ func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Du
|
||||
func (h *HealthChecker) performPeriodicCheck(ctx context.Context) {
|
||||
overall := h.GetOverallStatus(ctx)
|
||||
|
||||
h.logger.Info("Periodic health check completed",
|
||||
h.logger.Info("定期健康检查已完成",
|
||||
zap.String("overall_status", overall.Status),
|
||||
zap.String("message", overall.Message),
|
||||
zap.Int64("response_time_ms", overall.ResponseTime))
|
||||
|
||||
// 如果有服务下线,记录警告
|
||||
if overall.Status != "UP" {
|
||||
h.logger.Warn("Some services are not healthy",
|
||||
h.logger.Warn("部分服务不健康",
|
||||
zap.String("status", overall.Status),
|
||||
zap.Any("details", overall.Details))
|
||||
}
|
||||
|
||||
587
internal/shared/hooks/hook_system.go
Normal file
587
internal/shared/hooks/hook_system.go
Normal file
@@ -0,0 +1,587 @@
|
||||
package hooks
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// HookPriority 钩子优先级
|
||||
type HookPriority int
|
||||
|
||||
const (
|
||||
// PriorityLowest 最低优先级
|
||||
PriorityLowest HookPriority = 0
|
||||
// PriorityLow 低优先级
|
||||
PriorityLow HookPriority = 25
|
||||
// PriorityNormal 普通优先级
|
||||
PriorityNormal HookPriority = 50
|
||||
// PriorityHigh 高优先级
|
||||
PriorityHigh HookPriority = 75
|
||||
// PriorityHighest 最高优先级
|
||||
PriorityHighest HookPriority = 100
|
||||
)
|
||||
|
||||
// HookFunc 钩子函数类型
|
||||
type HookFunc func(ctx context.Context, data interface{}) error
|
||||
|
||||
// Hook 钩子定义
|
||||
type Hook struct {
|
||||
Name string
|
||||
Func HookFunc
|
||||
Priority HookPriority
|
||||
Async bool
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// HookResult 钩子执行结果
|
||||
type HookResult struct {
|
||||
HookName string `json:"hook_name"`
|
||||
Success bool `json:"success"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// HookConfig 钩子配置
|
||||
type HookConfig struct {
|
||||
// 默认超时时间
|
||||
DefaultTimeout time.Duration
|
||||
// 是否记录执行时间
|
||||
TrackDuration bool
|
||||
// 错误处理策略
|
||||
ErrorStrategy ErrorStrategy
|
||||
}
|
||||
|
||||
// ErrorStrategy 错误处理策略
|
||||
type ErrorStrategy int
|
||||
|
||||
const (
|
||||
// ContinueOnError 遇到错误继续执行
|
||||
ContinueOnError ErrorStrategy = iota
|
||||
// StopOnError 遇到错误停止执行
|
||||
StopOnError
|
||||
// CollectErrors 收集所有错误
|
||||
CollectErrors
|
||||
)
|
||||
|
||||
// DefaultHookConfig 默认钩子配置
|
||||
func DefaultHookConfig() HookConfig {
|
||||
return HookConfig{
|
||||
DefaultTimeout: 30 * time.Second,
|
||||
TrackDuration: true,
|
||||
ErrorStrategy: ContinueOnError,
|
||||
}
|
||||
}
|
||||
|
||||
// HookSystem 钩子系统
|
||||
type HookSystem struct {
|
||||
hooks map[string][]*Hook
|
||||
config HookConfig
|
||||
logger *zap.Logger
|
||||
mutex sync.RWMutex
|
||||
stats map[string]*HookStats
|
||||
}
|
||||
|
||||
// HookStats 钩子统计
|
||||
type HookStats struct {
|
||||
TotalExecutions int `json:"total_executions"`
|
||||
Successes int `json:"successes"`
|
||||
Failures int `json:"failures"`
|
||||
TotalDuration time.Duration `json:"total_duration"`
|
||||
AverageDuration time.Duration `json:"average_duration"`
|
||||
LastExecution time.Time `json:"last_execution"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
}
|
||||
|
||||
// NewHookSystem 创建钩子系统
|
||||
func NewHookSystem(config HookConfig, logger *zap.Logger) *HookSystem {
|
||||
return &HookSystem{
|
||||
hooks: make(map[string][]*Hook),
|
||||
config: config,
|
||||
logger: logger,
|
||||
stats: make(map[string]*HookStats),
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册钩子
|
||||
func (hs *HookSystem) Register(event string, hook *Hook) error {
|
||||
if hook.Name == "" {
|
||||
return fmt.Errorf("hook name cannot be empty")
|
||||
}
|
||||
|
||||
if hook.Func == nil {
|
||||
return fmt.Errorf("hook function cannot be nil")
|
||||
}
|
||||
|
||||
if hook.Timeout == 0 {
|
||||
hook.Timeout = hs.config.DefaultTimeout
|
||||
}
|
||||
|
||||
hs.mutex.Lock()
|
||||
defer hs.mutex.Unlock()
|
||||
|
||||
// 检查是否已经注册了同名钩子
|
||||
for _, existingHook := range hs.hooks[event] {
|
||||
if existingHook.Name == hook.Name {
|
||||
return fmt.Errorf("hook %s already registered for event %s", hook.Name, event)
|
||||
}
|
||||
}
|
||||
|
||||
hs.hooks[event] = append(hs.hooks[event], hook)
|
||||
|
||||
// 按优先级排序
|
||||
sort.Slice(hs.hooks[event], func(i, j int) bool {
|
||||
return hs.hooks[event][i].Priority > hs.hooks[event][j].Priority
|
||||
})
|
||||
|
||||
// 初始化统计
|
||||
hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
|
||||
hs.stats[hookKey] = &HookStats{}
|
||||
|
||||
hs.logger.Info("Registered hook",
|
||||
zap.String("event", event),
|
||||
zap.String("hook_name", hook.Name),
|
||||
zap.Int("priority", int(hook.Priority)),
|
||||
zap.Bool("async", hook.Async))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterFunc 注册钩子函数(简化版)
|
||||
func (hs *HookSystem) RegisterFunc(event, name string, priority HookPriority, fn HookFunc) error {
|
||||
hook := &Hook{
|
||||
Name: name,
|
||||
Func: fn,
|
||||
Priority: priority,
|
||||
Async: false,
|
||||
Timeout: hs.config.DefaultTimeout,
|
||||
}
|
||||
|
||||
return hs.Register(event, hook)
|
||||
}
|
||||
|
||||
// RegisterAsyncFunc 注册异步钩子函数
|
||||
func (hs *HookSystem) RegisterAsyncFunc(event, name string, priority HookPriority, fn HookFunc) error {
|
||||
hook := &Hook{
|
||||
Name: name,
|
||||
Func: fn,
|
||||
Priority: priority,
|
||||
Async: true,
|
||||
Timeout: hs.config.DefaultTimeout,
|
||||
}
|
||||
|
||||
return hs.Register(event, hook)
|
||||
}
|
||||
|
||||
// Unregister 取消注册钩子
|
||||
func (hs *HookSystem) Unregister(event, hookName string) error {
|
||||
hs.mutex.Lock()
|
||||
defer hs.mutex.Unlock()
|
||||
|
||||
hooks := hs.hooks[event]
|
||||
for i, hook := range hooks {
|
||||
if hook.Name == hookName {
|
||||
// 删除钩子
|
||||
hs.hooks[event] = append(hooks[:i], hooks[i+1:]...)
|
||||
|
||||
// 删除统计
|
||||
hookKey := fmt.Sprintf("%s.%s", event, hookName)
|
||||
delete(hs.stats, hookKey)
|
||||
|
||||
hs.logger.Info("Unregistered hook",
|
||||
zap.String("event", event),
|
||||
zap.String("hook_name", hookName))
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("hook %s not found for event %s", hookName, event)
|
||||
}
|
||||
|
||||
// Trigger 触发事件
|
||||
func (hs *HookSystem) Trigger(ctx context.Context, event string, data interface{}) ([]HookResult, error) {
|
||||
hs.mutex.RLock()
|
||||
hooks := make([]*Hook, len(hs.hooks[event]))
|
||||
copy(hooks, hs.hooks[event])
|
||||
hs.mutex.RUnlock()
|
||||
|
||||
if len(hooks) == 0 {
|
||||
hs.logger.Debug("No hooks registered for event", zap.String("event", event))
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
hs.logger.Debug("Triggering event",
|
||||
zap.String("event", event),
|
||||
zap.Int("hook_count", len(hooks)))
|
||||
|
||||
results := make([]HookResult, 0, len(hooks))
|
||||
var errors []error
|
||||
|
||||
for _, hook := range hooks {
|
||||
result := hs.executeHook(ctx, event, hook, data)
|
||||
results = append(results, result)
|
||||
|
||||
if !result.Success {
|
||||
err := fmt.Errorf("hook %s failed: %s", hook.Name, result.Error)
|
||||
errors = append(errors, err)
|
||||
|
||||
// 根据错误策略决定是否继续
|
||||
if hs.config.ErrorStrategy == StopOnError {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理错误
|
||||
if len(errors) > 0 {
|
||||
switch hs.config.ErrorStrategy {
|
||||
case StopOnError:
|
||||
return results, errors[0]
|
||||
case CollectErrors:
|
||||
return results, fmt.Errorf("multiple hook errors: %v", errors)
|
||||
case ContinueOnError:
|
||||
// 继续执行,但记录错误
|
||||
hs.logger.Warn("Some hooks failed but continuing execution",
|
||||
zap.String("event", event),
|
||||
zap.Int("error_count", len(errors)))
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// executeHook 执行单个钩子
|
||||
func (hs *HookSystem) executeHook(ctx context.Context, event string, hook *Hook, data interface{}) HookResult {
|
||||
hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
|
||||
start := time.Now()
|
||||
|
||||
result := HookResult{
|
||||
HookName: hook.Name,
|
||||
Success: false,
|
||||
}
|
||||
|
||||
// 更新统计
|
||||
defer func() {
|
||||
result.Duration = time.Since(start)
|
||||
hs.updateStats(hookKey, result)
|
||||
}()
|
||||
|
||||
if hook.Async {
|
||||
// 异步执行
|
||||
go func() {
|
||||
hs.doExecuteHook(ctx, hook, data)
|
||||
}()
|
||||
result.Success = true // 异步执行总是认为成功
|
||||
return result
|
||||
}
|
||||
|
||||
// 同步执行
|
||||
err := hs.doExecuteHook(ctx, hook, data)
|
||||
if err != nil {
|
||||
result.Error = err.Error()
|
||||
hs.logger.Error("Hook execution failed",
|
||||
zap.String("event", event),
|
||||
zap.String("hook_name", hook.Name),
|
||||
zap.Error(err))
|
||||
} else {
|
||||
result.Success = true
|
||||
hs.logger.Debug("Hook executed successfully",
|
||||
zap.String("event", event),
|
||||
zap.String("hook_name", hook.Name))
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// doExecuteHook 实际执行钩子
|
||||
func (hs *HookSystem) doExecuteHook(ctx context.Context, hook *Hook, data interface{}) error {
|
||||
// 设置超时上下文
|
||||
hookCtx, cancel := context.WithTimeout(ctx, hook.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// 在goroutine中执行,以便处理超时
|
||||
errChan := make(chan error, 1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
errChan <- fmt.Errorf("hook panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
errChan <- hook.Func(hookCtx, data)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-hookCtx.Done():
|
||||
return fmt.Errorf("hook execution timeout after %v", hook.Timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// updateStats 更新统计信息
|
||||
func (hs *HookSystem) updateStats(hookKey string, result HookResult) {
|
||||
hs.mutex.Lock()
|
||||
defer hs.mutex.Unlock()
|
||||
|
||||
stats, exists := hs.stats[hookKey]
|
||||
if !exists {
|
||||
stats = &HookStats{}
|
||||
hs.stats[hookKey] = stats
|
||||
}
|
||||
|
||||
stats.TotalExecutions++
|
||||
stats.LastExecution = time.Now()
|
||||
|
||||
if result.Success {
|
||||
stats.Successes++
|
||||
} else {
|
||||
stats.Failures++
|
||||
stats.LastError = result.Error
|
||||
}
|
||||
|
||||
if hs.config.TrackDuration {
|
||||
stats.TotalDuration += result.Duration
|
||||
stats.AverageDuration = stats.TotalDuration / time.Duration(stats.TotalExecutions)
|
||||
}
|
||||
}
|
||||
|
||||
// GetHooks 获取事件的所有钩子
|
||||
func (hs *HookSystem) GetHooks(event string) []*Hook {
|
||||
hs.mutex.RLock()
|
||||
defer hs.mutex.RUnlock()
|
||||
|
||||
hooks := make([]*Hook, len(hs.hooks[event]))
|
||||
copy(hooks, hs.hooks[event])
|
||||
return hooks
|
||||
}
|
||||
|
||||
// GetEvents 获取所有注册的事件
|
||||
func (hs *HookSystem) GetEvents() []string {
|
||||
hs.mutex.RLock()
|
||||
defer hs.mutex.RUnlock()
|
||||
|
||||
events := make([]string, 0, len(hs.hooks))
|
||||
for event := range hs.hooks {
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
sort.Strings(events)
|
||||
return events
|
||||
}
|
||||
|
||||
// GetStats 获取钩子统计信息
|
||||
func (hs *HookSystem) GetStats() map[string]*HookStats {
|
||||
hs.mutex.RLock()
|
||||
defer hs.mutex.RUnlock()
|
||||
|
||||
stats := make(map[string]*HookStats)
|
||||
for key, stat := range hs.stats {
|
||||
statCopy := *stat
|
||||
stats[key] = &statCopy
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetEventStats 获取特定事件的统计信息
|
||||
func (hs *HookSystem) GetEventStats(event string) map[string]*HookStats {
|
||||
allStats := hs.GetStats()
|
||||
eventStats := make(map[string]*HookStats)
|
||||
|
||||
prefix := event + "."
|
||||
for key, stat := range allStats {
|
||||
if len(key) > len(prefix) && key[:len(prefix)] == prefix {
|
||||
hookName := key[len(prefix):]
|
||||
eventStats[hookName] = stat
|
||||
}
|
||||
}
|
||||
|
||||
return eventStats
|
||||
}
|
||||
|
||||
// Clear 清除所有钩子
|
||||
func (hs *HookSystem) Clear() {
|
||||
hs.mutex.Lock()
|
||||
defer hs.mutex.Unlock()
|
||||
|
||||
hs.hooks = make(map[string][]*Hook)
|
||||
hs.stats = make(map[string]*HookStats)
|
||||
|
||||
hs.logger.Info("Cleared all hooks")
|
||||
}
|
||||
|
||||
// ClearEvent 清除特定事件的所有钩子
|
||||
func (hs *HookSystem) ClearEvent(event string) {
|
||||
hs.mutex.Lock()
|
||||
defer hs.mutex.Unlock()
|
||||
|
||||
// 删除钩子
|
||||
delete(hs.hooks, event)
|
||||
|
||||
// 删除统计
|
||||
prefix := event + "."
|
||||
for key := range hs.stats {
|
||||
if len(key) > len(prefix) && key[:len(prefix)] == prefix {
|
||||
delete(hs.stats, key)
|
||||
}
|
||||
}
|
||||
|
||||
hs.logger.Info("Cleared hooks for event", zap.String("event", event))
|
||||
}
|
||||
|
||||
// Count 获取钩子总数
|
||||
func (hs *HookSystem) Count() int {
|
||||
hs.mutex.RLock()
|
||||
defer hs.mutex.RUnlock()
|
||||
|
||||
total := 0
|
||||
for _, hooks := range hs.hooks {
|
||||
total += len(hooks)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
// EventCount 获取特定事件的钩子数量
|
||||
func (hs *HookSystem) EventCount(event string) int {
|
||||
hs.mutex.RLock()
|
||||
defer hs.mutex.RUnlock()
|
||||
|
||||
return len(hs.hooks[event])
|
||||
}
|
||||
|
||||
// 实现Service接口
|
||||
|
||||
// Name 返回服务名称
|
||||
func (hs *HookSystem) Name() string {
|
||||
return "hook-system"
|
||||
}
|
||||
|
||||
// Initialize 初始化钩子系统
|
||||
func (hs *HookSystem) Initialize(ctx context.Context) error {
|
||||
hs.logger.Info("Hook system initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Start 启动钩子系统
|
||||
func (hs *HookSystem) Start(ctx context.Context) error {
|
||||
hs.logger.Info("Hook system started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (hs *HookSystem) HealthCheck(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown 关闭钩子系统
|
||||
func (hs *HookSystem) Shutdown(ctx context.Context) error {
|
||||
hs.logger.Info("Hook system shutdown")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 便捷方法
|
||||
|
||||
// OnUserCreated 用户创建事件钩子
|
||||
func (hs *HookSystem) OnUserCreated(name string, priority HookPriority, fn HookFunc) error {
|
||||
return hs.RegisterFunc("user.created", name, priority, fn)
|
||||
}
|
||||
|
||||
// OnUserUpdated 用户更新事件钩子
|
||||
func (hs *HookSystem) OnUserUpdated(name string, priority HookPriority, fn HookFunc) error {
|
||||
return hs.RegisterFunc("user.updated", name, priority, fn)
|
||||
}
|
||||
|
||||
// OnUserDeleted 用户删除事件钩子
|
||||
func (hs *HookSystem) OnUserDeleted(name string, priority HookPriority, fn HookFunc) error {
|
||||
return hs.RegisterFunc("user.deleted", name, priority, fn)
|
||||
}
|
||||
|
||||
// OnOrderCreated 订单创建事件钩子
|
||||
func (hs *HookSystem) OnOrderCreated(name string, priority HookPriority, fn HookFunc) error {
|
||||
return hs.RegisterFunc("order.created", name, priority, fn)
|
||||
}
|
||||
|
||||
// OnOrderCompleted 订单完成事件钩子
|
||||
func (hs *HookSystem) OnOrderCompleted(name string, priority HookPriority, fn HookFunc) error {
|
||||
return hs.RegisterFunc("order.completed", name, priority, fn)
|
||||
}
|
||||
|
||||
// TriggerUserCreated 触发用户创建事件
|
||||
func (hs *HookSystem) TriggerUserCreated(ctx context.Context, user interface{}) ([]HookResult, error) {
|
||||
return hs.Trigger(ctx, "user.created", user)
|
||||
}
|
||||
|
||||
// TriggerUserUpdated 触发用户更新事件
|
||||
func (hs *HookSystem) TriggerUserUpdated(ctx context.Context, user interface{}) ([]HookResult, error) {
|
||||
return hs.Trigger(ctx, "user.updated", user)
|
||||
}
|
||||
|
||||
// TriggerUserDeleted 触发用户删除事件
|
||||
func (hs *HookSystem) TriggerUserDeleted(ctx context.Context, user interface{}) ([]HookResult, error) {
|
||||
return hs.Trigger(ctx, "user.deleted", user)
|
||||
}
|
||||
|
||||
// HookBuilder 钩子构建器
|
||||
type HookBuilder struct {
|
||||
hook *Hook
|
||||
}
|
||||
|
||||
// NewHookBuilder 创建钩子构建器
|
||||
func NewHookBuilder(name string, fn HookFunc) *HookBuilder {
|
||||
return &HookBuilder{
|
||||
hook: &Hook{
|
||||
Name: name,
|
||||
Func: fn,
|
||||
Priority: PriorityNormal,
|
||||
Async: false,
|
||||
Timeout: 30 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// WithPriority 设置优先级
|
||||
func (hb *HookBuilder) WithPriority(priority HookPriority) *HookBuilder {
|
||||
hb.hook.Priority = priority
|
||||
return hb
|
||||
}
|
||||
|
||||
// WithTimeout 设置超时时间
|
||||
func (hb *HookBuilder) WithTimeout(timeout time.Duration) *HookBuilder {
|
||||
hb.hook.Timeout = timeout
|
||||
return hb
|
||||
}
|
||||
|
||||
// Async 设置为异步执行
|
||||
func (hb *HookBuilder) Async() *HookBuilder {
|
||||
hb.hook.Async = true
|
||||
return hb
|
||||
}
|
||||
|
||||
// Build 构建钩子
|
||||
func (hb *HookBuilder) Build() *Hook {
|
||||
return hb.hook
|
||||
}
|
||||
|
||||
// TypedHookFunc 类型化钩子函数
|
||||
type TypedHookFunc[T any] func(ctx context.Context, data T) error
|
||||
|
||||
// RegisterTypedFunc 注册类型化钩子函数
|
||||
func RegisterTypedFunc[T any](hs *HookSystem, event, name string, priority HookPriority, fn TypedHookFunc[T]) error {
|
||||
hookFunc := func(ctx context.Context, data interface{}) error {
|
||||
typedData, ok := data.(T)
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid data type for hook %s, expected %s", name, reflect.TypeOf((*T)(nil)).Elem().Name())
|
||||
}
|
||||
return fn(ctx, typedData)
|
||||
}
|
||||
|
||||
return hs.RegisterFunc(event, name, priority, hookFunc)
|
||||
}
|
||||
@@ -20,7 +20,7 @@ func NewResponseBuilder() interfaces.ResponseBuilder {
|
||||
|
||||
// Success 成功响应
|
||||
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Success"
|
||||
msg := "操作成功"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
@@ -38,7 +38,7 @@ func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...s
|
||||
|
||||
// Created 创建成功响应
|
||||
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
|
||||
msg := "Created successfully"
|
||||
msg := "创建成功"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
@@ -58,7 +58,7 @@ func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...s
|
||||
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
|
||||
// 根据错误类型确定状态码
|
||||
statusCode := http.StatusInternalServerError
|
||||
message := "Internal server error"
|
||||
message := "服务器内部错误"
|
||||
errorDetail := err.Error()
|
||||
|
||||
// 这里可以根据不同的错误类型设置不同的状态码
|
||||
@@ -93,7 +93,7 @@ func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...i
|
||||
|
||||
// Unauthorized 401错误响应
|
||||
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
|
||||
msg := "Unauthorized"
|
||||
msg := "用户未登录或认证已过期"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
@@ -110,7 +110,7 @@ func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
|
||||
|
||||
// Forbidden 403错误响应
|
||||
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
|
||||
msg := "Forbidden"
|
||||
msg := "权限不足,无法访问此资源"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
@@ -127,7 +127,7 @@ func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
|
||||
|
||||
// NotFound 404错误响应
|
||||
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
|
||||
msg := "Resource not found"
|
||||
msg := "请求的资源不存在"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
@@ -156,7 +156,7 @@ func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
|
||||
|
||||
// InternalError 500错误响应
|
||||
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
|
||||
msg := "Internal server error"
|
||||
msg := "服务器内部错误"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
@@ -175,7 +175,7 @@ func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
|
||||
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: true,
|
||||
Message: "Success",
|
||||
Message: "查询成功",
|
||||
Data: data,
|
||||
Pagination: &pagination,
|
||||
RequestID: r.getRequestID(c),
|
||||
@@ -215,9 +215,35 @@ func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta
|
||||
|
||||
// CustomResponse 自定义响应
|
||||
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
|
||||
var message string
|
||||
switch statusCode {
|
||||
case http.StatusOK:
|
||||
message = "请求成功"
|
||||
case http.StatusCreated:
|
||||
message = "创建成功"
|
||||
case http.StatusNoContent:
|
||||
message = "无内容"
|
||||
case http.StatusBadRequest:
|
||||
message = "请求参数错误"
|
||||
case http.StatusUnauthorized:
|
||||
message = "认证失败"
|
||||
case http.StatusForbidden:
|
||||
message = "权限不足"
|
||||
case http.StatusNotFound:
|
||||
message = "资源不存在"
|
||||
case http.StatusConflict:
|
||||
message = "资源冲突"
|
||||
case http.StatusTooManyRequests:
|
||||
message = "请求过于频繁"
|
||||
case http.StatusInternalServerError:
|
||||
message = "服务器内部错误"
|
||||
default:
|
||||
message = "未知状态"
|
||||
}
|
||||
|
||||
response := interfaces.APIResponse{
|
||||
Success: statusCode >= 200 && statusCode < 300,
|
||||
Message: http.StatusText(statusCode),
|
||||
Message: message,
|
||||
Data: data,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
@@ -230,7 +256,7 @@ func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data in
|
||||
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
|
||||
response := interfaces.APIResponse{
|
||||
Success: false,
|
||||
Message: "Validation failed",
|
||||
Message: "请求参数验证失败",
|
||||
Errors: errors,
|
||||
RequestID: r.getRequestID(c),
|
||||
Timestamp: time.Now().Unix(),
|
||||
@@ -241,7 +267,7 @@ func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
|
||||
|
||||
// TooManyRequests 限流错误响应
|
||||
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
|
||||
msg := "Too many requests"
|
||||
msg := "请求过于频繁,请稍后再试"
|
||||
if len(message) > 0 && message[0] != "" {
|
||||
msg = message[0]
|
||||
}
|
||||
|
||||
@@ -8,6 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
swaggerFiles "github.com/swaggo/files"
|
||||
ginSwagger "github.com/swaggo/gin-swagger"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
@@ -51,7 +53,7 @@ func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
|
||||
// 注册路由
|
||||
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
|
||||
|
||||
r.logger.Info("Registered HTTP handler",
|
||||
r.logger.Info("已注册HTTP处理器",
|
||||
zap.String("method", handler.GetMethod()),
|
||||
zap.String("path", handler.GetPath()))
|
||||
|
||||
@@ -62,7 +64,7 @@ func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
|
||||
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
|
||||
r.middlewares = append(r.middlewares, middleware)
|
||||
|
||||
r.logger.Info("Registered middleware",
|
||||
r.logger.Info("已注册中间件",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
|
||||
@@ -93,7 +95,7 @@ func (r *GinRouter) Start(addr string) error {
|
||||
IdleTimeout: r.config.Server.IdleTimeout,
|
||||
}
|
||||
|
||||
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
|
||||
r.logger.Info("正在启动HTTP服务器", zap.String("addr", addr))
|
||||
|
||||
// 启动服务器
|
||||
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||
@@ -109,15 +111,15 @@ func (r *GinRouter) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
r.logger.Info("Stopping HTTP server...")
|
||||
r.logger.Info("正在关闭HTTP服务器...")
|
||||
|
||||
// 优雅关闭服务器
|
||||
if err := r.server.Shutdown(ctx); err != nil {
|
||||
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
|
||||
r.logger.Error("优雅关闭服务器失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("HTTP server stopped")
|
||||
r.logger.Info("HTTP服务器已关闭")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -137,7 +139,7 @@ func (r *GinRouter) applyMiddlewares() {
|
||||
for _, middleware := range r.middlewares {
|
||||
if middleware.IsGlobal() {
|
||||
r.engine.Use(middleware.Handle())
|
||||
r.logger.Debug("Applied global middleware",
|
||||
r.logger.Debug("已应用全局中间件",
|
||||
zap.String("name", middleware.GetName()),
|
||||
zap.Int("priority", middleware.GetPriority()))
|
||||
}
|
||||
@@ -156,6 +158,18 @@ func (r *GinRouter) SetupDefaultRoutes() {
|
||||
})
|
||||
})
|
||||
|
||||
// 详细健康检查
|
||||
r.engine.GET("/health/detailed", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "healthy",
|
||||
"timestamp": time.Now().Unix(),
|
||||
"service": r.config.App.Name,
|
||||
"version": r.config.App.Version,
|
||||
"uptime": time.Now().Unix(),
|
||||
"environment": r.config.App.Env,
|
||||
})
|
||||
})
|
||||
|
||||
// API信息
|
||||
r.engine.GET("/info", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
@@ -166,11 +180,37 @@ func (r *GinRouter) SetupDefaultRoutes() {
|
||||
})
|
||||
})
|
||||
|
||||
// Swagger文档路由 (仅在开发环境启用)
|
||||
if !r.config.App.IsProduction() {
|
||||
// Swagger UI
|
||||
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
|
||||
|
||||
// API文档重定向
|
||||
r.engine.GET("/docs", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusMovedPermanently, "/swagger/index.html")
|
||||
})
|
||||
|
||||
// API文档信息
|
||||
r.engine.GET("/api/docs", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"swagger_ui": fmt.Sprintf("http://%s/swagger/index.html", c.Request.Host),
|
||||
"openapi_json": fmt.Sprintf("http://%s/swagger/doc.json", c.Request.Host),
|
||||
"redoc": fmt.Sprintf("http://%s/redoc", c.Request.Host),
|
||||
"message": "API文档已可用",
|
||||
})
|
||||
})
|
||||
|
||||
r.logger.Info("Swagger documentation enabled",
|
||||
zap.String("swagger_url", "/swagger/index.html"),
|
||||
zap.String("docs_url", "/docs"),
|
||||
zap.String("api_docs_url", "/api/docs"))
|
||||
}
|
||||
|
||||
// 404处理
|
||||
r.engine.NoRoute(func(c *gin.Context) {
|
||||
c.JSON(http.StatusNotFound, gin.H{
|
||||
"success": false,
|
||||
"message": "Route not found",
|
||||
"message": "路由未找到",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
@@ -181,7 +221,7 @@ func (r *GinRouter) SetupDefaultRoutes() {
|
||||
r.engine.NoMethod(func(c *gin.Context) {
|
||||
c.JSON(http.StatusMethodNotAllowed, gin.H{
|
||||
"success": false,
|
||||
"message": "Method not allowed",
|
||||
"message": "请求方法不允许",
|
||||
"path": c.Request.URL.Path,
|
||||
"method": c.Request.Method,
|
||||
"timestamp": time.Now().Unix(),
|
||||
|
||||
@@ -42,13 +42,13 @@ func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
|
||||
// ValidateQuery 验证查询参数
|
||||
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindQuery(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid query parameters", err.Error())
|
||||
v.response.BadRequest(c, "查询参数格式错误", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
v.response.ValidationError(c, validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -57,13 +57,13 @@ func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error
|
||||
// ValidateParam 验证路径参数
|
||||
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindUri(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid path parameters", err.Error())
|
||||
v.response.BadRequest(c, "路径参数格式错误", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrors(err)
|
||||
v.response.BadRequest(c, "Validation failed", validationErrors)
|
||||
v.response.ValidationError(c, validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -73,7 +73,7 @@ func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error
|
||||
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
|
||||
// 绑定请求体
|
||||
if err := c.ShouldBindJSON(dto); err != nil {
|
||||
v.response.BadRequest(c, "Invalid request body", err.Error())
|
||||
v.response.BadRequest(c, "请求体格式错误", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -115,44 +115,74 @@ func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) stri
|
||||
tag := fieldError.Tag()
|
||||
param := fieldError.Param()
|
||||
|
||||
fieldDisplayName := v.getFieldDisplayName(field)
|
||||
|
||||
switch tag {
|
||||
case "required":
|
||||
return fmt.Sprintf("%s is required", field)
|
||||
return fmt.Sprintf("%s 不能为空", fieldDisplayName)
|
||||
case "email":
|
||||
return fmt.Sprintf("%s must be a valid email address", field)
|
||||
return fmt.Sprintf("%s 必须是有效的邮箱地址", fieldDisplayName)
|
||||
case "min":
|
||||
return fmt.Sprintf("%s must be at least %s characters", field, param)
|
||||
return fmt.Sprintf("%s 长度不能少于 %s 位", fieldDisplayName, param)
|
||||
case "max":
|
||||
return fmt.Sprintf("%s must be at most %s characters", field, param)
|
||||
return fmt.Sprintf("%s 长度不能超过 %s 位", fieldDisplayName, param)
|
||||
case "len":
|
||||
return fmt.Sprintf("%s must be exactly %s characters", field, param)
|
||||
return fmt.Sprintf("%s 长度必须为 %s 位", fieldDisplayName, param)
|
||||
case "gt":
|
||||
return fmt.Sprintf("%s must be greater than %s", field, param)
|
||||
return fmt.Sprintf("%s 必须大于 %s", fieldDisplayName, param)
|
||||
case "gte":
|
||||
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
|
||||
return fmt.Sprintf("%s 必须大于等于 %s", fieldDisplayName, param)
|
||||
case "lt":
|
||||
return fmt.Sprintf("%s must be less than %s", field, param)
|
||||
return fmt.Sprintf("%s 必须小于 %s", fieldDisplayName, param)
|
||||
case "lte":
|
||||
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
|
||||
return fmt.Sprintf("%s 必须小于等于 %s", fieldDisplayName, param)
|
||||
case "oneof":
|
||||
return fmt.Sprintf("%s must be one of [%s]", field, param)
|
||||
return fmt.Sprintf("%s 必须是以下值之一:[%s]", fieldDisplayName, param)
|
||||
case "url":
|
||||
return fmt.Sprintf("%s must be a valid URL", field)
|
||||
return fmt.Sprintf("%s 必须是有效的URL地址", fieldDisplayName)
|
||||
case "alpha":
|
||||
return fmt.Sprintf("%s must contain only alphabetic characters", field)
|
||||
return fmt.Sprintf("%s 只能包含字母", fieldDisplayName)
|
||||
case "alphanum":
|
||||
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
|
||||
return fmt.Sprintf("%s 只能包含字母和数字", fieldDisplayName)
|
||||
case "numeric":
|
||||
return fmt.Sprintf("%s must be numeric", field)
|
||||
return fmt.Sprintf("%s 必须是数字", fieldDisplayName)
|
||||
case "phone":
|
||||
return fmt.Sprintf("%s must be a valid phone number", field)
|
||||
return fmt.Sprintf("%s 必须是有效的手机号", fieldDisplayName)
|
||||
case "username":
|
||||
return fmt.Sprintf("%s must be a valid username", field)
|
||||
return fmt.Sprintf("%s 格式不正确,只能包含字母、数字、下划线,且不能以数字开头", fieldDisplayName)
|
||||
case "strong_password":
|
||||
return fmt.Sprintf("%s 强度不足,必须包含大小写字母和数字,且不少于8位", fieldDisplayName)
|
||||
case "eqfield":
|
||||
return fmt.Sprintf("%s 必须与 %s 一致", fieldDisplayName, v.getFieldDisplayName(param))
|
||||
default:
|
||||
return fmt.Sprintf("%s is invalid", field)
|
||||
return fmt.Sprintf("%s 格式不正确", fieldDisplayName)
|
||||
}
|
||||
}
|
||||
|
||||
// getFieldDisplayName 获取字段显示名称(中文)
|
||||
func (v *RequestValidator) getFieldDisplayName(field string) string {
|
||||
fieldNames := map[string]string{
|
||||
"phone": "手机号",
|
||||
"password": "密码",
|
||||
"confirm_password": "确认密码",
|
||||
"old_password": "原密码",
|
||||
"new_password": "新密码",
|
||||
"confirm_new_password": "确认新密码",
|
||||
"code": "验证码",
|
||||
"username": "用户名",
|
||||
"email": "邮箱",
|
||||
"display_name": "显示名称",
|
||||
"scene": "使用场景",
|
||||
"Password": "密码",
|
||||
"NewPassword": "新密码",
|
||||
}
|
||||
|
||||
if displayName, exists := fieldNames[field]; exists {
|
||||
return displayName
|
||||
}
|
||||
return field
|
||||
}
|
||||
|
||||
// toSnakeCase 转换为snake_case
|
||||
func (v *RequestValidator) toSnakeCase(str string) string {
|
||||
var result strings.Builder
|
||||
|
||||
294
internal/shared/http/validator_zh.go
Normal file
294
internal/shared/http/validator_zh.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-playground/locales/zh"
|
||||
ut "github.com/go-playground/universal-translator"
|
||||
"github.com/go-playground/validator/v10"
|
||||
zh_translations "github.com/go-playground/validator/v10/translations/zh"
|
||||
)
|
||||
|
||||
// RequestValidatorZh 中文验证器实现
|
||||
type RequestValidatorZh struct {
|
||||
validator *validator.Validate
|
||||
translator ut.Translator
|
||||
response interfaces.ResponseBuilder
|
||||
}
|
||||
|
||||
// NewRequestValidatorZh 创建支持中文翻译的请求验证器
|
||||
func NewRequestValidatorZh(response interfaces.ResponseBuilder) interfaces.RequestValidator {
|
||||
// 创建验证器实例
|
||||
validate := validator.New()
|
||||
|
||||
// 创建中文locale
|
||||
zhLocale := zh.New()
|
||||
uni := ut.New(zhLocale, zhLocale)
|
||||
|
||||
// 获取中文翻译器
|
||||
trans, _ := uni.GetTranslator("zh")
|
||||
|
||||
// 注册中文翻译
|
||||
zh_translations.RegisterDefaultTranslations(validate, trans)
|
||||
|
||||
// 注册自定义验证器
|
||||
registerCustomValidatorsZh(validate, trans)
|
||||
|
||||
return &RequestValidatorZh{
|
||||
validator: validate,
|
||||
translator: trans,
|
||||
response: response,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate 验证请求体
|
||||
func (v *RequestValidatorZh) Validate(c *gin.Context, dto interface{}) error {
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrorsZh(err)
|
||||
v.response.ValidationError(c, validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateQuery 验证查询参数
|
||||
func (v *RequestValidatorZh) ValidateQuery(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindQuery(dto); err != nil {
|
||||
v.response.BadRequest(c, "查询参数格式错误", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrorsZh(err)
|
||||
v.response.ValidationError(c, validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateParam 验证路径参数
|
||||
func (v *RequestValidatorZh) ValidateParam(c *gin.Context, dto interface{}) error {
|
||||
if err := c.ShouldBindUri(dto); err != nil {
|
||||
v.response.BadRequest(c, "路径参数格式错误", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if err := v.validator.Struct(dto); err != nil {
|
||||
validationErrors := v.formatValidationErrorsZh(err)
|
||||
v.response.ValidationError(c, validationErrors)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BindAndValidate 绑定并验证请求
|
||||
func (v *RequestValidatorZh) BindAndValidate(c *gin.Context, dto interface{}) error {
|
||||
// 绑定请求体
|
||||
if err := c.ShouldBindJSON(dto); err != nil {
|
||||
v.response.BadRequest(c, "请求体格式错误", err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
// 验证数据
|
||||
return v.Validate(c, dto)
|
||||
}
|
||||
|
||||
// formatValidationErrorsZh 格式化验证错误(中文翻译版)
|
||||
func (v *RequestValidatorZh) formatValidationErrorsZh(err error) map[string][]string {
|
||||
errors := make(map[string][]string)
|
||||
|
||||
if validationErrors, ok := err.(validator.ValidationErrors); ok {
|
||||
for _, fieldError := range validationErrors {
|
||||
fieldName := v.getFieldNameZh(fieldError)
|
||||
|
||||
// 首先尝试使用翻译器获取翻译后的错误消息
|
||||
errorMessage := fieldError.Translate(v.translator)
|
||||
|
||||
// 如果翻译后的消息包含英文字段名,则替换为中文字段名
|
||||
fieldDisplayName := v.getFieldDisplayName(fieldError.Field())
|
||||
if fieldDisplayName != fieldError.Field() {
|
||||
// 替换字段名为中文
|
||||
errorMessage = strings.ReplaceAll(errorMessage, fieldError.Field(), fieldDisplayName)
|
||||
}
|
||||
|
||||
if _, exists := errors[fieldName]; !exists {
|
||||
errors[fieldName] = []string{}
|
||||
}
|
||||
errors[fieldName] = append(errors[fieldName], errorMessage)
|
||||
}
|
||||
}
|
||||
|
||||
return errors
|
||||
}
|
||||
|
||||
// getFieldNameZh 获取字段名(JSON标签优先)
|
||||
func (v *RequestValidatorZh) getFieldNameZh(fieldError validator.FieldError) string {
|
||||
fieldName := fieldError.Field()
|
||||
return v.toSnakeCase(fieldName)
|
||||
}
|
||||
|
||||
// getFieldDisplayName 获取字段显示名称(中文)
|
||||
func (v *RequestValidatorZh) getFieldDisplayName(field string) string {
|
||||
fieldNames := map[string]string{
|
||||
"phone": "手机号",
|
||||
"password": "密码",
|
||||
"confirm_password": "确认密码",
|
||||
"old_password": "原密码",
|
||||
"new_password": "新密码",
|
||||
"confirm_new_password": "确认新密码",
|
||||
"code": "验证码",
|
||||
"username": "用户名",
|
||||
"email": "邮箱",
|
||||
"display_name": "显示名称",
|
||||
"scene": "使用场景",
|
||||
"Password": "密码",
|
||||
"NewPassword": "新密码",
|
||||
"ConfirmPassword": "确认密码",
|
||||
}
|
||||
|
||||
if displayName, exists := fieldNames[field]; exists {
|
||||
return displayName
|
||||
}
|
||||
return field
|
||||
}
|
||||
|
||||
// toSnakeCase 转换为snake_case
|
||||
func (v *RequestValidatorZh) toSnakeCase(str string) string {
|
||||
var result strings.Builder
|
||||
for i, r := range str {
|
||||
if i > 0 && (r >= 'A' && r <= 'Z') {
|
||||
result.WriteRune('_')
|
||||
}
|
||||
result.WriteRune(r)
|
||||
}
|
||||
return strings.ToLower(result.String())
|
||||
}
|
||||
|
||||
// registerCustomValidatorsZh 注册自定义验证器和中文翻译
|
||||
func registerCustomValidatorsZh(v *validator.Validate, trans ut.Translator) {
|
||||
// 注册手机号验证器
|
||||
v.RegisterValidation("phone", validatePhoneZh)
|
||||
v.RegisterTranslation("phone", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("phone", "{0}必须是有效的手机号", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("phone", fe.Field())
|
||||
return t
|
||||
})
|
||||
|
||||
// 注册用户名验证器
|
||||
v.RegisterValidation("username", validateUsernameZh)
|
||||
v.RegisterTranslation("username", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("username", "{0}格式不正确,只能包含字母、数字、下划线,且不能以数字开头", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("username", fe.Field())
|
||||
return t
|
||||
})
|
||||
|
||||
// 注册密码强度验证器
|
||||
v.RegisterValidation("strong_password", validateStrongPasswordZh)
|
||||
v.RegisterTranslation("strong_password", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("strong_password", "{0}强度不足,必须包含大小写字母和数字,且不少于8位", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("strong_password", fe.Field())
|
||||
return t
|
||||
})
|
||||
|
||||
// 自定义eqfield翻译
|
||||
v.RegisterTranslation("eqfield", trans, func(ut ut.Translator) error {
|
||||
return ut.Add("eqfield", "{0}必须等于{1}", true)
|
||||
}, func(ut ut.Translator, fe validator.FieldError) string {
|
||||
t, _ := ut.T("eqfield", fe.Field(), fe.Param())
|
||||
return t
|
||||
})
|
||||
}
|
||||
|
||||
// validatePhoneZh 验证手机号
|
||||
func validatePhoneZh(fl validator.FieldLevel) bool {
|
||||
phone := fl.Field().String()
|
||||
if phone == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 中国手机号验证:11位,以1开头
|
||||
if len(phone) != 11 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(phone, "1") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否全是数字
|
||||
for _, r := range phone {
|
||||
if r < '0' || r > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateUsernameZh 验证用户名
|
||||
func validateUsernameZh(fl validator.FieldLevel) bool {
|
||||
username := fl.Field().String()
|
||||
if username == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 用户名规则:3-30个字符,只能包含字母、数字、下划线,不能以数字开头
|
||||
if len(username) < 3 || len(username) > 30 {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不能以数字开头
|
||||
if username[0] >= '0' && username[0] <= '9' {
|
||||
return false
|
||||
}
|
||||
|
||||
// 只能包含字母、数字、下划线
|
||||
for _, r := range username {
|
||||
if !((r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// validateStrongPasswordZh 验证密码强度
|
||||
func validateStrongPasswordZh(fl validator.FieldLevel) bool {
|
||||
password := fl.Field().String()
|
||||
if password == "" {
|
||||
return true // 空值由required标签处理
|
||||
}
|
||||
|
||||
// 密码强度规则:至少8个字符,包含大小写字母、数字
|
||||
if len(password) < 8 {
|
||||
return false
|
||||
}
|
||||
|
||||
hasUpper := false
|
||||
hasLower := false
|
||||
hasDigit := false
|
||||
|
||||
for _, r := range password {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
hasUpper = true
|
||||
case r >= 'a' && r <= 'z':
|
||||
hasLower = true
|
||||
case r >= '0' && r <= '9':
|
||||
hasDigit = true
|
||||
}
|
||||
}
|
||||
|
||||
return hasUpper && hasLower && hasDigit
|
||||
}
|
||||
|
||||
// ValidateStruct 直接验证结构体
|
||||
func (v *RequestValidatorZh) ValidateStruct(dto interface{}) error {
|
||||
return v.validator.Struct(dto)
|
||||
}
|
||||
@@ -76,9 +76,14 @@ type ResponseBuilder interface {
|
||||
NotFound(c *gin.Context, message ...string)
|
||||
Conflict(c *gin.Context, message string)
|
||||
InternalError(c *gin.Context, message ...string)
|
||||
ValidationError(c *gin.Context, errors interface{})
|
||||
TooManyRequests(c *gin.Context, message ...string)
|
||||
|
||||
// 分页响应
|
||||
Paginated(c *gin.Context, data interface{}, pagination PaginationMeta)
|
||||
|
||||
// 自定义响应
|
||||
CustomResponse(c *gin.Context, statusCode int, data interface{})
|
||||
}
|
||||
|
||||
// RequestValidator 请求验证器接口
|
||||
@@ -90,6 +95,9 @@ type RequestValidator interface {
|
||||
|
||||
// 绑定和验证
|
||||
BindAndValidate(c *gin.Context, dto interface{}) error
|
||||
|
||||
// 直接验证结构体
|
||||
ValidateStruct(dto interface{}) error
|
||||
}
|
||||
|
||||
// PaginationMeta 分页元数据
|
||||
|
||||
@@ -2,6 +2,15 @@ package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"tyapi-server/internal/domains/user/dto"
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
)
|
||||
|
||||
// 常见错误定义
|
||||
var (
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
)
|
||||
|
||||
// Service 通用服务接口
|
||||
@@ -16,6 +25,22 @@ type Service interface {
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
|
||||
// UserService 用户服务接口
|
||||
type UserService interface {
|
||||
Service
|
||||
|
||||
// 用户注册
|
||||
Register(ctx context.Context, req *dto.RegisterRequest) (*entities.User, error)
|
||||
// 密码登录
|
||||
LoginWithPassword(ctx context.Context, req *dto.LoginWithPasswordRequest) (*entities.User, error)
|
||||
// 短信验证码登录
|
||||
LoginWithSMS(ctx context.Context, req *dto.LoginWithSMSRequest) (*entities.User, error)
|
||||
// 修改密码
|
||||
ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error
|
||||
// 根据ID获取用户
|
||||
GetByID(ctx context.Context, id string) (*entities.User, error)
|
||||
}
|
||||
|
||||
// DomainService 领域服务接口,支持泛型
|
||||
type DomainService[T Entity] interface {
|
||||
Service
|
||||
|
||||
214
internal/shared/logger/enhanced_logger.go
Normal file
214
internal/shared/logger/enhanced_logger.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// LogLevel 日志级别
|
||||
type LogLevel string
|
||||
|
||||
const (
|
||||
DebugLevel LogLevel = "debug"
|
||||
InfoLevel LogLevel = "info"
|
||||
WarnLevel LogLevel = "warn"
|
||||
ErrorLevel LogLevel = "error"
|
||||
)
|
||||
|
||||
// LogContext 日志上下文
|
||||
type LogContext struct {
|
||||
RequestID string
|
||||
UserID string
|
||||
TraceID string
|
||||
OperationName string
|
||||
Layer string // repository/service/handler
|
||||
Component string
|
||||
}
|
||||
|
||||
// ContextualLogger 上下文感知的日志器
|
||||
type ContextualLogger struct {
|
||||
logger *zap.Logger
|
||||
ctx LogContext
|
||||
}
|
||||
|
||||
// NewContextualLogger 创建上下文日志器
|
||||
func NewContextualLogger(logger *zap.Logger) *ContextualLogger {
|
||||
return &ContextualLogger{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext 添加上下文信息
|
||||
func (l *ContextualLogger) WithContext(ctx context.Context) *ContextualLogger {
|
||||
logCtx := LogContext{}
|
||||
|
||||
// 从context中提取常用字段
|
||||
if requestID := getStringFromContext(ctx, "request_id"); requestID != "" {
|
||||
logCtx.RequestID = requestID
|
||||
}
|
||||
if userID := getStringFromContext(ctx, "user_id"); userID != "" {
|
||||
logCtx.UserID = userID
|
||||
}
|
||||
if traceID := getStringFromContext(ctx, "trace_id"); traceID != "" {
|
||||
logCtx.TraceID = traceID
|
||||
}
|
||||
|
||||
return &ContextualLogger{
|
||||
logger: l.logger,
|
||||
ctx: logCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// WithLayer 设置层级信息
|
||||
func (l *ContextualLogger) WithLayer(layer string) *ContextualLogger {
|
||||
newCtx := l.ctx
|
||||
newCtx.Layer = layer
|
||||
return &ContextualLogger{
|
||||
logger: l.logger,
|
||||
ctx: newCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// WithComponent 设置组件信息
|
||||
func (l *ContextualLogger) WithComponent(component string) *ContextualLogger {
|
||||
newCtx := l.ctx
|
||||
newCtx.Component = component
|
||||
return &ContextualLogger{
|
||||
logger: l.logger,
|
||||
ctx: newCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// WithOperation 设置操作名称
|
||||
func (l *ContextualLogger) WithOperation(operation string) *ContextualLogger {
|
||||
newCtx := l.ctx
|
||||
newCtx.OperationName = operation
|
||||
return &ContextualLogger{
|
||||
logger: l.logger,
|
||||
ctx: newCtx,
|
||||
}
|
||||
}
|
||||
|
||||
// 构建基础字段
|
||||
func (l *ContextualLogger) buildBaseFields() []zapcore.Field {
|
||||
fields := []zapcore.Field{}
|
||||
|
||||
if l.ctx.RequestID != "" {
|
||||
fields = append(fields, zap.String("request_id", l.ctx.RequestID))
|
||||
}
|
||||
if l.ctx.UserID != "" {
|
||||
fields = append(fields, zap.String("user_id", l.ctx.UserID))
|
||||
}
|
||||
if l.ctx.TraceID != "" {
|
||||
fields = append(fields, zap.String("trace_id", l.ctx.TraceID))
|
||||
}
|
||||
if l.ctx.Layer != "" {
|
||||
fields = append(fields, zap.String("layer", l.ctx.Layer))
|
||||
}
|
||||
if l.ctx.Component != "" {
|
||||
fields = append(fields, zap.String("component", l.ctx.Component))
|
||||
}
|
||||
if l.ctx.OperationName != "" {
|
||||
fields = append(fields, zap.String("operation", l.ctx.OperationName))
|
||||
}
|
||||
|
||||
return fields
|
||||
}
|
||||
|
||||
// LogTechnicalError 记录技术性错误(Repository层)
|
||||
func (l *ContextualLogger) LogTechnicalError(msg string, err error, fields ...zapcore.Field) {
|
||||
allFields := l.buildBaseFields()
|
||||
allFields = append(allFields, zap.Error(err))
|
||||
allFields = append(allFields, zap.String("error_type", "technical"))
|
||||
allFields = append(allFields, fields...)
|
||||
|
||||
l.logger.Error(msg, allFields...)
|
||||
}
|
||||
|
||||
// LogBusinessWarn 记录业务警告(Service层)
|
||||
func (l *ContextualLogger) LogBusinessWarn(msg string, fields ...zapcore.Field) {
|
||||
allFields := l.buildBaseFields()
|
||||
allFields = append(allFields, zap.String("log_type", "business"))
|
||||
allFields = append(allFields, fields...)
|
||||
|
||||
l.logger.Warn(msg, allFields...)
|
||||
}
|
||||
|
||||
// LogBusinessInfo 记录业务信息(Service层)
|
||||
func (l *ContextualLogger) LogBusinessInfo(msg string, fields ...zapcore.Field) {
|
||||
allFields := l.buildBaseFields()
|
||||
allFields = append(allFields, zap.String("log_type", "business"))
|
||||
allFields = append(allFields, fields...)
|
||||
|
||||
l.logger.Info(msg, allFields...)
|
||||
}
|
||||
|
||||
// LogUserAction 记录用户行为(Handler层)
|
||||
func (l *ContextualLogger) LogUserAction(msg string, fields ...zapcore.Field) {
|
||||
allFields := l.buildBaseFields()
|
||||
allFields = append(allFields, zap.String("log_type", "user_action"))
|
||||
allFields = append(allFields, fields...)
|
||||
|
||||
l.logger.Info(msg, allFields...)
|
||||
}
|
||||
|
||||
// LogRequestFailed 记录请求失败(Handler层)
|
||||
func (l *ContextualLogger) LogRequestFailed(msg string, errorType string, fields ...zapcore.Field) {
|
||||
allFields := l.buildBaseFields()
|
||||
allFields = append(allFields, zap.String("log_type", "request_failed"))
|
||||
allFields = append(allFields, zap.String("error_category", errorType))
|
||||
allFields = append(allFields, fields...)
|
||||
|
||||
l.logger.Info(msg, allFields...)
|
||||
}
|
||||
|
||||
// getStringFromContext 从上下文获取字符串值
|
||||
func getStringFromContext(ctx context.Context, key string) string {
|
||||
if value := ctx.Value(key); value != nil {
|
||||
if str, ok := value.(string); ok {
|
||||
return str
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ErrorCategory 错误分类
|
||||
type ErrorCategory string
|
||||
|
||||
const (
|
||||
DatabaseError ErrorCategory = "database"
|
||||
NetworkError ErrorCategory = "network"
|
||||
ValidationError ErrorCategory = "validation"
|
||||
BusinessError ErrorCategory = "business"
|
||||
AuthError ErrorCategory = "auth"
|
||||
ExternalAPIError ErrorCategory = "external_api"
|
||||
)
|
||||
|
||||
// CategorizeError 错误分类
|
||||
func CategorizeError(err error) ErrorCategory {
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
|
||||
switch {
|
||||
case strings.Contains(errMsg, "database") ||
|
||||
strings.Contains(errMsg, "sql") ||
|
||||
strings.Contains(errMsg, "gorm"):
|
||||
return DatabaseError
|
||||
case strings.Contains(errMsg, "network") ||
|
||||
strings.Contains(errMsg, "connection") ||
|
||||
strings.Contains(errMsg, "timeout"):
|
||||
return NetworkError
|
||||
case strings.Contains(errMsg, "validation") ||
|
||||
strings.Contains(errMsg, "invalid") ||
|
||||
strings.Contains(errMsg, "format"):
|
||||
return ValidationError
|
||||
case strings.Contains(errMsg, "unauthorized") ||
|
||||
strings.Contains(errMsg, "forbidden") ||
|
||||
strings.Contains(errMsg, "token"):
|
||||
return AuthError
|
||||
default:
|
||||
return BusinessError
|
||||
}
|
||||
}
|
||||
263
internal/shared/metrics/business_metrics.go
Normal file
263
internal/shared/metrics/business_metrics.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// BusinessMetrics 业务指标收集器
|
||||
type BusinessMetrics struct {
|
||||
metrics interfaces.MetricsCollector
|
||||
logger *zap.Logger
|
||||
mutex sync.RWMutex
|
||||
|
||||
// 业务指标缓存
|
||||
userMetrics map[string]int64
|
||||
orderMetrics map[string]int64
|
||||
}
|
||||
|
||||
// NewBusinessMetrics 创建业务指标收集器
|
||||
func NewBusinessMetrics(metrics interfaces.MetricsCollector, logger *zap.Logger) *BusinessMetrics {
|
||||
bm := &BusinessMetrics{
|
||||
metrics: metrics,
|
||||
logger: logger,
|
||||
userMetrics: make(map[string]int64),
|
||||
orderMetrics: make(map[string]int64),
|
||||
}
|
||||
|
||||
// 注册业务指标
|
||||
bm.registerBusinessMetrics()
|
||||
|
||||
return bm
|
||||
}
|
||||
|
||||
// registerBusinessMetrics 注册业务指标
|
||||
func (bm *BusinessMetrics) registerBusinessMetrics() {
|
||||
// 用户相关指标
|
||||
bm.metrics.RegisterCounter("users_created_total", "Total number of users created", []string{"source"})
|
||||
bm.metrics.RegisterCounter("users_login_total", "Total number of user logins", []string{"method", "status"})
|
||||
bm.metrics.RegisterGauge("users_active_sessions", "Current number of active user sessions", nil)
|
||||
|
||||
// 订单相关指标
|
||||
bm.metrics.RegisterCounter("orders_created_total", "Total number of orders created", []string{"status"})
|
||||
bm.metrics.RegisterCounter("orders_amount_total", "Total order amount in cents", []string{"currency"})
|
||||
bm.metrics.RegisterHistogram("orders_processing_duration_seconds", "Order processing duration", []string{"status"}, []float64{0.1, 0.5, 1, 2, 5, 10, 30})
|
||||
|
||||
// API相关指标
|
||||
bm.metrics.RegisterCounter("api_errors_total", "Total number of API errors", []string{"endpoint", "error_type"})
|
||||
bm.metrics.RegisterHistogram("api_response_size_bytes", "API response size in bytes", []string{"endpoint"}, []float64{100, 1000, 10000, 100000})
|
||||
|
||||
// 缓存相关指标
|
||||
bm.metrics.RegisterCounter("cache_operations_total", "Total number of cache operations", []string{"operation", "result"})
|
||||
bm.metrics.RegisterGauge("cache_memory_usage_bytes", "Cache memory usage in bytes", []string{"cache_type"})
|
||||
|
||||
// 数据库相关指标
|
||||
bm.metrics.RegisterHistogram("database_query_duration_seconds", "Database query duration", []string{"operation", "table"}, []float64{0.001, 0.01, 0.1, 1, 10})
|
||||
bm.metrics.RegisterCounter("database_errors_total", "Total number of database errors", []string{"operation", "error_type"})
|
||||
|
||||
bm.logger.Info("Business metrics registered successfully")
|
||||
}
|
||||
|
||||
// User相关指标
|
||||
|
||||
// RecordUserCreated 记录用户创建
|
||||
func (bm *BusinessMetrics) RecordUserCreated(source string) {
|
||||
bm.metrics.IncrementCounter("users_created_total", map[string]string{
|
||||
"source": source,
|
||||
})
|
||||
|
||||
bm.mutex.Lock()
|
||||
bm.userMetrics["created"]++
|
||||
bm.mutex.Unlock()
|
||||
|
||||
bm.logger.Debug("Recorded user created", zap.String("source", source))
|
||||
}
|
||||
|
||||
// RecordUserLogin 记录用户登录
|
||||
func (bm *BusinessMetrics) RecordUserLogin(method, status string) {
|
||||
bm.metrics.IncrementCounter("users_login_total", map[string]string{
|
||||
"method": method,
|
||||
"status": status,
|
||||
})
|
||||
|
||||
bm.logger.Debug("Recorded user login", zap.String("method", method), zap.String("status", status))
|
||||
}
|
||||
|
||||
// UpdateActiveUserSessions 更新活跃用户会话数
|
||||
func (bm *BusinessMetrics) UpdateActiveUserSessions(count float64) {
|
||||
bm.metrics.RecordGauge("users_active_sessions", count, nil)
|
||||
}
|
||||
|
||||
// Order相关指标
|
||||
|
||||
// RecordOrderCreated 记录订单创建
|
||||
func (bm *BusinessMetrics) RecordOrderCreated(status string, amount float64, currency string) {
|
||||
bm.metrics.IncrementCounter("orders_created_total", map[string]string{
|
||||
"status": status,
|
||||
})
|
||||
|
||||
// 记录订单金额(以分为单位)
|
||||
amountCents := int64(amount * 100)
|
||||
bm.metrics.IncrementCounter("orders_amount_total", map[string]string{
|
||||
"currency": currency,
|
||||
})
|
||||
|
||||
bm.mutex.Lock()
|
||||
bm.orderMetrics["created"]++
|
||||
bm.orderMetrics["amount"] += amountCents
|
||||
bm.mutex.Unlock()
|
||||
|
||||
bm.logger.Debug("Recorded order created",
|
||||
zap.String("status", status),
|
||||
zap.Float64("amount", amount),
|
||||
zap.String("currency", currency))
|
||||
}
|
||||
|
||||
// RecordOrderProcessingDuration 记录订单处理时长
|
||||
func (bm *BusinessMetrics) RecordOrderProcessingDuration(status string, duration float64) {
|
||||
bm.metrics.RecordHistogram("orders_processing_duration_seconds", duration, map[string]string{
|
||||
"status": status,
|
||||
})
|
||||
}
|
||||
|
||||
// API相关指标
|
||||
|
||||
// RecordAPIError 记录API错误
|
||||
func (bm *BusinessMetrics) RecordAPIError(endpoint, errorType string) {
|
||||
bm.metrics.IncrementCounter("api_errors_total", map[string]string{
|
||||
"endpoint": endpoint,
|
||||
"error_type": errorType,
|
||||
})
|
||||
|
||||
bm.logger.Debug("Recorded API error",
|
||||
zap.String("endpoint", endpoint),
|
||||
zap.String("error_type", errorType))
|
||||
}
|
||||
|
||||
// RecordAPIResponseSize 记录API响应大小
|
||||
func (bm *BusinessMetrics) RecordAPIResponseSize(endpoint string, sizeBytes float64) {
|
||||
bm.metrics.RecordHistogram("api_response_size_bytes", sizeBytes, map[string]string{
|
||||
"endpoint": endpoint,
|
||||
})
|
||||
}
|
||||
|
||||
// Cache相关指标
|
||||
|
||||
// RecordCacheOperation 记录缓存操作
|
||||
func (bm *BusinessMetrics) RecordCacheOperation(operation, result string) {
|
||||
bm.metrics.IncrementCounter("cache_operations_total", map[string]string{
|
||||
"operation": operation,
|
||||
"result": result,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateCacheMemoryUsage 更新缓存内存使用量
|
||||
func (bm *BusinessMetrics) UpdateCacheMemoryUsage(cacheType string, usageBytes float64) {
|
||||
bm.metrics.RecordGauge("cache_memory_usage_bytes", usageBytes, map[string]string{
|
||||
"cache_type": cacheType,
|
||||
})
|
||||
}
|
||||
|
||||
// Database相关指标
|
||||
|
||||
// RecordDatabaseQuery 记录数据库查询
|
||||
func (bm *BusinessMetrics) RecordDatabaseQuery(operation, table string, duration float64) {
|
||||
bm.metrics.RecordHistogram("database_query_duration_seconds", duration, map[string]string{
|
||||
"operation": operation,
|
||||
"table": table,
|
||||
})
|
||||
}
|
||||
|
||||
// RecordDatabaseError 记录数据库错误
|
||||
func (bm *BusinessMetrics) RecordDatabaseError(operation, errorType string) {
|
||||
bm.metrics.IncrementCounter("database_errors_total", map[string]string{
|
||||
"operation": operation,
|
||||
"error_type": errorType,
|
||||
})
|
||||
|
||||
bm.logger.Debug("Recorded database error",
|
||||
zap.String("operation", operation),
|
||||
zap.String("error_type", errorType))
|
||||
}
|
||||
|
||||
// 获取统计信息
|
||||
|
||||
// GetUserStats 获取用户统计
|
||||
func (bm *BusinessMetrics) GetUserStats() map[string]int64 {
|
||||
bm.mutex.RLock()
|
||||
defer bm.mutex.RUnlock()
|
||||
|
||||
stats := make(map[string]int64)
|
||||
for k, v := range bm.userMetrics {
|
||||
stats[k] = v
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetOrderStats 获取订单统计
|
||||
func (bm *BusinessMetrics) GetOrderStats() map[string]int64 {
|
||||
bm.mutex.RLock()
|
||||
defer bm.mutex.RUnlock()
|
||||
|
||||
stats := make(map[string]int64)
|
||||
for k, v := range bm.orderMetrics {
|
||||
stats[k] = v
|
||||
}
|
||||
return stats
|
||||
}
|
||||
|
||||
// GetOverallStats 获取整体统计
|
||||
func (bm *BusinessMetrics) GetOverallStats() map[string]interface{} {
|
||||
return map[string]interface{}{
|
||||
"user_stats": bm.GetUserStats(),
|
||||
"order_stats": bm.GetOrderStats(),
|
||||
}
|
||||
}
|
||||
|
||||
// Reset 重置统计数据
|
||||
func (bm *BusinessMetrics) Reset() {
|
||||
bm.mutex.Lock()
|
||||
defer bm.mutex.Unlock()
|
||||
|
||||
bm.userMetrics = make(map[string]int64)
|
||||
bm.orderMetrics = make(map[string]int64)
|
||||
|
||||
bm.logger.Info("Business metrics reset")
|
||||
}
|
||||
|
||||
// Context相关方法
|
||||
|
||||
// WithContext 创建带上下文的业务指标收集器
|
||||
func (bm *BusinessMetrics) WithContext(ctx context.Context) *BusinessMetrics {
|
||||
// 这里可以从context中提取追踪信息,关联指标
|
||||
return bm
|
||||
}
|
||||
|
||||
// 实现Service接口(如果需要)
|
||||
|
||||
// Name 返回服务名称
|
||||
func (bm *BusinessMetrics) Name() string {
|
||||
return "business-metrics"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (bm *BusinessMetrics) Initialize(ctx context.Context) error {
|
||||
bm.logger.Info("Business metrics service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (bm *BusinessMetrics) HealthCheck(ctx context.Context) error {
|
||||
// 检查指标收集器是否正常
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (bm *BusinessMetrics) Shutdown(ctx context.Context) error {
|
||||
bm.logger.Info("Business metrics service shutdown")
|
||||
return nil
|
||||
}
|
||||
353
internal/shared/metrics/prometheus_metrics.go
Normal file
353
internal/shared/metrics/prometheus_metrics.go
Normal file
@@ -0,0 +1,353 @@
|
||||
package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PrometheusMetrics Prometheus指标收集器
|
||||
type PrometheusMetrics struct {
|
||||
logger *zap.Logger
|
||||
registry *prometheus.Registry
|
||||
mutex sync.RWMutex
|
||||
|
||||
// 预定义指标
|
||||
httpRequests *prometheus.CounterVec
|
||||
httpDuration *prometheus.HistogramVec
|
||||
activeUsers prometheus.Gauge
|
||||
dbConnections prometheus.Gauge
|
||||
cacheHits *prometheus.CounterVec
|
||||
businessMetrics map[string]prometheus.Collector
|
||||
}
|
||||
|
||||
// NewPrometheusMetrics 创建Prometheus指标收集器
|
||||
func NewPrometheusMetrics(logger *zap.Logger) *PrometheusMetrics {
|
||||
registry := prometheus.NewRegistry()
|
||||
|
||||
// HTTP请求计数器
|
||||
httpRequests := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "http_requests_total",
|
||||
Help: "Total number of HTTP requests",
|
||||
},
|
||||
[]string{"method", "path", "status"},
|
||||
)
|
||||
|
||||
// HTTP请求耗时直方图
|
||||
httpDuration := prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: "http_request_duration_seconds",
|
||||
Help: "HTTP request duration in seconds",
|
||||
Buckets: prometheus.DefBuckets,
|
||||
},
|
||||
[]string{"method", "path"},
|
||||
)
|
||||
|
||||
// 活跃用户数
|
||||
activeUsers := prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "active_users_total",
|
||||
Help: "Current number of active users",
|
||||
},
|
||||
)
|
||||
|
||||
// 数据库连接数
|
||||
dbConnections := prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Name: "database_connections_active",
|
||||
Help: "Current number of active database connections",
|
||||
},
|
||||
)
|
||||
|
||||
// 缓存命中率
|
||||
cacheHits := prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Name: "cache_operations_total",
|
||||
Help: "Total number of cache operations",
|
||||
},
|
||||
[]string{"operation", "result"},
|
||||
)
|
||||
|
||||
// 注册指标
|
||||
registry.MustRegister(httpRequests)
|
||||
registry.MustRegister(httpDuration)
|
||||
registry.MustRegister(activeUsers)
|
||||
registry.MustRegister(dbConnections)
|
||||
registry.MustRegister(cacheHits)
|
||||
|
||||
return &PrometheusMetrics{
|
||||
logger: logger,
|
||||
registry: registry,
|
||||
httpRequests: httpRequests,
|
||||
httpDuration: httpDuration,
|
||||
activeUsers: activeUsers,
|
||||
dbConnections: dbConnections,
|
||||
cacheHits: cacheHits,
|
||||
businessMetrics: make(map[string]prometheus.Collector),
|
||||
}
|
||||
}
|
||||
|
||||
// RecordHTTPRequest 记录HTTP请求指标
|
||||
func (m *PrometheusMetrics) RecordHTTPRequest(method, path string, statusCode int, duration float64) {
|
||||
status := strconv.Itoa(statusCode)
|
||||
|
||||
m.httpRequests.WithLabelValues(method, path, status).Inc()
|
||||
m.httpDuration.WithLabelValues(method, path).Observe(duration)
|
||||
|
||||
m.logger.Debug("Recorded HTTP request metric",
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.String("status", status),
|
||||
zap.Float64("duration", duration))
|
||||
}
|
||||
|
||||
// RecordHTTPDuration 记录HTTP请求耗时
|
||||
func (m *PrometheusMetrics) RecordHTTPDuration(method, path string, duration float64) {
|
||||
m.httpDuration.WithLabelValues(method, path).Observe(duration)
|
||||
|
||||
m.logger.Debug("Recorded HTTP duration metric",
|
||||
zap.String("method", method),
|
||||
zap.String("path", path),
|
||||
zap.Float64("duration", duration))
|
||||
}
|
||||
|
||||
// IncrementCounter 增加计数器
|
||||
func (m *PrometheusMetrics) IncrementCounter(name string, labels map[string]string) {
|
||||
if counter, exists := m.getOrCreateCounter(name, labels); exists {
|
||||
if vec, ok := counter.(*prometheus.CounterVec); ok {
|
||||
vec.With(labels).Inc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordGauge 记录仪表盘值
|
||||
func (m *PrometheusMetrics) RecordGauge(name string, value float64, labels map[string]string) {
|
||||
if gauge, exists := m.getOrCreateGauge(name, labels); exists {
|
||||
if vec, ok := gauge.(*prometheus.GaugeVec); ok {
|
||||
vec.With(labels).Set(value)
|
||||
} else if g, ok := gauge.(prometheus.Gauge); ok {
|
||||
g.Set(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordHistogram 记录直方图值
|
||||
func (m *PrometheusMetrics) RecordHistogram(name string, value float64, labels map[string]string) {
|
||||
if histogram, exists := m.getOrCreateHistogram(name, labels); exists {
|
||||
if vec, ok := histogram.(*prometheus.HistogramVec); ok {
|
||||
vec.With(labels).Observe(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterCounter 注册计数器
|
||||
func (m *PrometheusMetrics) RegisterCounter(name, help string, labels []string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if _, exists := m.businessMetrics[name]; exists {
|
||||
return nil // 已存在
|
||||
}
|
||||
|
||||
var counter prometheus.Collector
|
||||
if len(labels) > 0 {
|
||||
counter = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{Name: name, Help: help},
|
||||
labels,
|
||||
)
|
||||
} else {
|
||||
counter = prometheus.NewCounter(
|
||||
prometheus.CounterOpts{Name: name, Help: help},
|
||||
)
|
||||
}
|
||||
|
||||
if err := m.registry.Register(counter); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.businessMetrics[name] = counter
|
||||
m.logger.Info("Registered counter metric", zap.String("name", name))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterGauge 注册仪表盘
|
||||
func (m *PrometheusMetrics) RegisterGauge(name, help string, labels []string) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if _, exists := m.businessMetrics[name]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
var gauge prometheus.Collector
|
||||
if len(labels) > 0 {
|
||||
gauge = prometheus.NewGaugeVec(
|
||||
prometheus.GaugeOpts{Name: name, Help: help},
|
||||
labels,
|
||||
)
|
||||
} else {
|
||||
gauge = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{Name: name, Help: help},
|
||||
)
|
||||
}
|
||||
|
||||
if err := m.registry.Register(gauge); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.businessMetrics[name] = gauge
|
||||
m.logger.Info("Registered gauge metric", zap.String("name", name))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterHistogram 注册直方图
|
||||
func (m *PrometheusMetrics) RegisterHistogram(name, help string, labels []string, buckets []float64) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if _, exists := m.businessMetrics[name]; exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
if buckets == nil {
|
||||
buckets = prometheus.DefBuckets
|
||||
}
|
||||
|
||||
var histogram prometheus.Collector
|
||||
if len(labels) > 0 {
|
||||
histogram = prometheus.NewHistogramVec(
|
||||
prometheus.HistogramOpts{
|
||||
Name: name,
|
||||
Help: help,
|
||||
Buckets: buckets,
|
||||
},
|
||||
labels,
|
||||
)
|
||||
} else {
|
||||
histogram = prometheus.NewHistogram(
|
||||
prometheus.HistogramOpts{
|
||||
Name: name,
|
||||
Help: help,
|
||||
Buckets: buckets,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if err := m.registry.Register(histogram); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.businessMetrics[name] = histogram
|
||||
m.logger.Info("Registered histogram metric", zap.String("name", name))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetHandler 获取HTTP处理器
|
||||
func (m *PrometheusMetrics) GetHandler() http.Handler {
|
||||
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
|
||||
}
|
||||
|
||||
// 内部辅助方法
|
||||
|
||||
func (m *PrometheusMetrics) getOrCreateCounter(name string, labels map[string]string) (prometheus.Collector, bool) {
|
||||
m.mutex.RLock()
|
||||
counter, exists := m.businessMetrics[name]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// 自动创建计数器
|
||||
labelNames := make([]string, 0, len(labels))
|
||||
for k := range labels {
|
||||
labelNames = append(labelNames, k)
|
||||
}
|
||||
|
||||
if err := m.RegisterCounter(name, "Auto-created counter", labelNames); err != nil {
|
||||
m.logger.Error("Failed to auto-create counter", zap.String("name", name), zap.Error(err))
|
||||
return nil, false
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
counter, exists = m.businessMetrics[name]
|
||||
m.mutex.RUnlock()
|
||||
}
|
||||
|
||||
return counter, exists
|
||||
}
|
||||
|
||||
func (m *PrometheusMetrics) getOrCreateGauge(name string, labels map[string]string) (prometheus.Collector, bool) {
|
||||
m.mutex.RLock()
|
||||
gauge, exists := m.businessMetrics[name]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
labelNames := make([]string, 0, len(labels))
|
||||
for k := range labels {
|
||||
labelNames = append(labelNames, k)
|
||||
}
|
||||
|
||||
if err := m.RegisterGauge(name, "Auto-created gauge", labelNames); err != nil {
|
||||
m.logger.Error("Failed to auto-create gauge", zap.String("name", name), zap.Error(err))
|
||||
return nil, false
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
gauge, exists = m.businessMetrics[name]
|
||||
m.mutex.RUnlock()
|
||||
}
|
||||
|
||||
return gauge, exists
|
||||
}
|
||||
|
||||
func (m *PrometheusMetrics) getOrCreateHistogram(name string, labels map[string]string) (prometheus.Collector, bool) {
|
||||
m.mutex.RLock()
|
||||
histogram, exists := m.businessMetrics[name]
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !exists {
|
||||
labelNames := make([]string, 0, len(labels))
|
||||
for k := range labels {
|
||||
labelNames = append(labelNames, k)
|
||||
}
|
||||
|
||||
if err := m.RegisterHistogram(name, "Auto-created histogram", labelNames, nil); err != nil {
|
||||
m.logger.Error("Failed to auto-create histogram", zap.String("name", name), zap.Error(err))
|
||||
return nil, false
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
histogram, exists = m.businessMetrics[name]
|
||||
m.mutex.RUnlock()
|
||||
}
|
||||
|
||||
return histogram, exists
|
||||
}
|
||||
|
||||
// UpdateActiveUsers 更新活跃用户数
|
||||
func (m *PrometheusMetrics) UpdateActiveUsers(count float64) {
|
||||
m.activeUsers.Set(count)
|
||||
}
|
||||
|
||||
// UpdateDBConnections 更新数据库连接数
|
||||
func (m *PrometheusMetrics) UpdateDBConnections(count float64) {
|
||||
m.dbConnections.Set(count)
|
||||
}
|
||||
|
||||
// RecordCacheOperation 记录缓存操作
|
||||
func (m *PrometheusMetrics) RecordCacheOperation(operation, result string) {
|
||||
m.cacheHits.WithLabelValues(operation, result).Inc()
|
||||
}
|
||||
|
||||
// GetStats 获取指标统计
|
||||
func (m *PrometheusMetrics) GetStats() map[string]interface{} {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"registered_metrics": len(m.businessMetrics),
|
||||
}
|
||||
}
|
||||
@@ -42,31 +42,31 @@ func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
|
||||
// 获取Authorization头部
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
m.respondUnauthorized(c, "Missing authorization header")
|
||||
m.respondUnauthorized(c, "缺少认证头部")
|
||||
return
|
||||
}
|
||||
|
||||
// 检查Bearer前缀
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
m.respondUnauthorized(c, "Invalid authorization header format")
|
||||
m.respondUnauthorized(c, "认证头部格式无效")
|
||||
return
|
||||
}
|
||||
|
||||
// 提取token
|
||||
tokenString := authHeader[len(bearerPrefix):]
|
||||
if tokenString == "" {
|
||||
m.respondUnauthorized(c, "Missing token")
|
||||
m.respondUnauthorized(c, "缺少认证令牌")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证token
|
||||
claims, err := m.validateToken(tokenString)
|
||||
if err != nil {
|
||||
m.logger.Warn("Invalid token",
|
||||
m.logger.Warn("无效的认证令牌",
|
||||
zap.Error(err),
|
||||
zap.String("request_id", c.GetString("request_id")))
|
||||
m.respondUnauthorized(c, "Invalid token")
|
||||
m.respondUnauthorized(c, "认证令牌无效")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -119,7 +119,7 @@ func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error
|
||||
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"success": false,
|
||||
"message": "Unauthorized",
|
||||
"message": "认证失败",
|
||||
"error": message,
|
||||
"request_id": c.GetString("request_id"),
|
||||
"timestamp": time.Now().Unix(),
|
||||
|
||||
@@ -2,11 +2,11 @@ package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/time/rate"
|
||||
@@ -15,14 +15,16 @@ import (
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
config *config.Config
|
||||
response interfaces.ResponseBuilder
|
||||
limiters map[string]*rate.Limiter
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
|
||||
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
config: cfg,
|
||||
response: response,
|
||||
limiters: make(map[string]*rate.Limiter),
|
||||
}
|
||||
}
|
||||
@@ -48,15 +50,13 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
|
||||
|
||||
// 检查是否允许请求
|
||||
if !limiter.Allow() {
|
||||
// 添加限流头部信息
|
||||
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
|
||||
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
|
||||
c.Header("Retry-After", "60")
|
||||
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"success": false,
|
||||
"message": "Rate limit exceeded",
|
||||
"error": "Too many requests",
|
||||
})
|
||||
// 使用统一的响应格式
|
||||
m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
@@ -2,23 +2,35 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/tracing"
|
||||
)
|
||||
|
||||
// RequestLoggerMiddleware 请求日志中间件
|
||||
type RequestLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
logger *zap.Logger
|
||||
useColoredLog bool
|
||||
isDevelopment bool
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewRequestLoggerMiddleware 创建请求日志中间件
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
|
||||
func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
|
||||
return &RequestLoggerMiddleware{
|
||||
logger: logger,
|
||||
logger: logger,
|
||||
useColoredLog: isDevelopment, // 开发环境使用彩色日志
|
||||
isDevelopment: isDevelopment,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,24 +46,110 @@ func (m *RequestLoggerMiddleware) GetPriority() int {
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 使用zap logger记录请求信息
|
||||
m.logger.Info("HTTP Request",
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
)
|
||||
if m.useColoredLog {
|
||||
// 开发环境:使用Gin默认的彩色日志格式
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
var statusColor, methodColor, resetColor string
|
||||
if param.IsOutputColor() {
|
||||
statusColor = param.StatusCodeColor()
|
||||
methodColor = param.MethodColor()
|
||||
resetColor = param.ResetColor()
|
||||
}
|
||||
|
||||
// 返回空字符串,因为我们已经用zap记录了
|
||||
return ""
|
||||
})
|
||||
if param.Latency > time.Minute {
|
||||
param.Latency = param.Latency.Truncate(time.Second)
|
||||
}
|
||||
|
||||
// 获取TraceID
|
||||
traceID := param.Request.Header.Get("X-Trace-ID")
|
||||
if traceID == "" && m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(param.Request.Context())
|
||||
}
|
||||
|
||||
// 检查是否为错误响应
|
||||
if param.StatusCode >= 400 && m.tracer != nil {
|
||||
span := trace.SpanFromContext(param.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
traceInfo := ""
|
||||
if traceID != "" {
|
||||
traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
|
||||
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
|
||||
statusColor, param.StatusCode, resetColor,
|
||||
param.Latency,
|
||||
param.ClientIP,
|
||||
methodColor, param.Method, resetColor,
|
||||
param.Path,
|
||||
traceInfo,
|
||||
param.ErrorMessage,
|
||||
)
|
||||
})
|
||||
} else {
|
||||
// 生产环境:使用结构化JSON日志
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
// 获取TraceID
|
||||
traceID := param.Request.Header.Get("X-Trace-ID")
|
||||
if traceID == "" && m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(param.Request.Context())
|
||||
}
|
||||
|
||||
// 检查是否为错误响应
|
||||
if param.StatusCode >= 400 && m.tracer != nil {
|
||||
span := trace.SpanFromContext(param.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 对于服务器错误,记录更详细的日志
|
||||
if param.StatusCode >= 500 {
|
||||
m.logger.Error("服务器错误",
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("trace_id", traceID),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 记录请求日志
|
||||
logFields := []zap.Field{
|
||||
zap.String("client_ip", param.ClientIP),
|
||||
zap.String("method", param.Method),
|
||||
zap.String("path", param.Path),
|
||||
zap.String("protocol", param.Request.Proto),
|
||||
zap.Int("status_code", param.StatusCode),
|
||||
zap.Duration("latency", param.Latency),
|
||||
zap.String("user_agent", param.Request.UserAgent()),
|
||||
zap.Int("body_size", param.BodySize),
|
||||
zap.String("referer", param.Request.Referer()),
|
||||
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
|
||||
}
|
||||
|
||||
// 添加TraceID
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
m.logger.Info("HTTP请求", logFields...)
|
||||
return ""
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
@@ -102,6 +200,70 @@ func (m *RequestIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TraceIDMiddleware 追踪ID中间件
|
||||
type TraceIDMiddleware struct {
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewTraceIDMiddleware 创建追踪ID中间件
|
||||
func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
|
||||
return &TraceIDMiddleware{
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *TraceIDMiddleware) GetName() string {
|
||||
return "trace_id"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *TraceIDMiddleware) GetPriority() int {
|
||||
return 94 // 仅次于请求ID中间件
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// 获取或生成追踪ID
|
||||
traceID := m.tracer.GetTraceID(c.Request.Context())
|
||||
if traceID != "" {
|
||||
// 设置追踪ID到响应头
|
||||
c.Header("X-Trace-ID", traceID)
|
||||
// 添加到上下文
|
||||
c.Set("trace_id", traceID)
|
||||
}
|
||||
|
||||
// 检查是否为错误请求(例如URL不存在)
|
||||
c.Next()
|
||||
|
||||
// 请求完成后检查状态码
|
||||
if c.Writer.Status() >= 400 {
|
||||
// 获取当前span
|
||||
span := trace.SpanFromContext(c.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 设置错误上下文,以便后续span可以识别
|
||||
c.Request = c.Request.WithContext(context.WithValue(
|
||||
c.Request.Context(),
|
||||
"otel_error_request",
|
||||
true,
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *TraceIDMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SecurityHeadersMiddleware 安全头部中间件
|
||||
type SecurityHeadersMiddleware struct{}
|
||||
|
||||
@@ -183,13 +345,15 @@ func (m *ResponseTimeMiddleware) IsGlobal() bool {
|
||||
type RequestBodyLoggerMiddleware struct {
|
||||
logger *zap.Logger
|
||||
enable bool
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
|
||||
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
|
||||
return &RequestBodyLoggerMiddleware{
|
||||
logger: logger,
|
||||
enable: enable,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,13 +384,26 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
// 重新设置body供后续处理使用
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
// 获取追踪ID
|
||||
traceID := ""
|
||||
if m.tracer != nil {
|
||||
traceID = m.tracer.GetTraceID(c.Request.Context())
|
||||
}
|
||||
|
||||
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
|
||||
m.logger.Debug("Request Body",
|
||||
logFields := []zap.Field{
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.Request.URL.Path),
|
||||
zap.String("body", string(bodyBytes)),
|
||||
zap.String("request_id", c.GetString("request_id")),
|
||||
)
|
||||
}
|
||||
|
||||
// 添加追踪ID
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
|
||||
m.logger.Debug("请求体", logFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -239,3 +416,83 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
|
||||
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
|
||||
return false // 可选中间件,不是全局的
|
||||
}
|
||||
|
||||
// ErrorTrackingMiddleware 错误追踪中间件
|
||||
type ErrorTrackingMiddleware struct {
|
||||
logger *zap.Logger
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
// NewErrorTrackingMiddleware 创建错误追踪中间件
|
||||
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
|
||||
return &ErrorTrackingMiddleware{
|
||||
logger: logger,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 返回中间件名称
|
||||
func (m *ErrorTrackingMiddleware) GetName() string {
|
||||
return "error_tracking"
|
||||
}
|
||||
|
||||
// GetPriority 返回中间件优先级
|
||||
func (m *ErrorTrackingMiddleware) GetPriority() int {
|
||||
return 60 // 低优先级,在大多数中间件之后执行
|
||||
}
|
||||
|
||||
// Handle 返回中间件处理函数
|
||||
func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// 检查是否有错误
|
||||
if len(c.Errors) > 0 || c.Writer.Status() >= 400 {
|
||||
// 获取当前span
|
||||
span := trace.SpanFromContext(c.Request.Context())
|
||||
if span.IsRecording() {
|
||||
// 标记为错误操作,确保100%采样
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 记录错误日志
|
||||
traceID := m.tracer.GetTraceID(c.Request.Context())
|
||||
spanID := m.tracer.GetSpanID(c.Request.Context())
|
||||
|
||||
logFields := []zap.Field{
|
||||
zap.Int("status_code", c.Writer.Status()),
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("path", c.FullPath()),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
}
|
||||
|
||||
// 添加追踪信息
|
||||
if traceID != "" {
|
||||
logFields = append(logFields, zap.String("trace_id", traceID))
|
||||
}
|
||||
if spanID != "" {
|
||||
logFields = append(logFields, zap.String("span_id", spanID))
|
||||
}
|
||||
|
||||
// 添加错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
logFields = append(logFields, zap.String("errors", c.Errors.String()))
|
||||
}
|
||||
|
||||
// 根据状态码决定日志级别
|
||||
if c.Writer.Status() >= 500 {
|
||||
m.logger.Error("服务器错误", logFields...)
|
||||
} else {
|
||||
m.logger.Warn("客户端错误", logFields...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsGlobal 是否为全局中间件
|
||||
func (m *ErrorTrackingMiddleware) IsGlobal() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
389
internal/shared/resilience/circuit_breaker.go
Normal file
389
internal/shared/resilience/circuit_breaker.go
Normal file
@@ -0,0 +1,389 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CircuitState 熔断器状态
|
||||
type CircuitState int
|
||||
|
||||
const (
|
||||
// StateClosed 关闭状态(正常)
|
||||
StateClosed CircuitState = iota
|
||||
// StateOpen 开启状态(熔断)
|
||||
StateOpen
|
||||
// StateHalfOpen 半开状态(测试)
|
||||
StateHalfOpen
|
||||
)
|
||||
|
||||
func (s CircuitState) String() string {
|
||||
switch s {
|
||||
case StateClosed:
|
||||
return "CLOSED"
|
||||
case StateOpen:
|
||||
return "OPEN"
|
||||
case StateHalfOpen:
|
||||
return "HALF_OPEN"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreakerConfig 熔断器配置
|
||||
type CircuitBreakerConfig struct {
|
||||
// 故障阈值
|
||||
FailureThreshold int
|
||||
// 重置超时时间
|
||||
ResetTimeout time.Duration
|
||||
// 检测窗口大小
|
||||
WindowSize int
|
||||
// 半开状态允许的请求数
|
||||
HalfOpenMaxRequests int
|
||||
// 成功阈值(半开->关闭)
|
||||
SuccessThreshold int
|
||||
}
|
||||
|
||||
// DefaultCircuitBreakerConfig 默认熔断器配置
|
||||
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
|
||||
return CircuitBreakerConfig{
|
||||
FailureThreshold: 5,
|
||||
ResetTimeout: 60 * time.Second,
|
||||
WindowSize: 10,
|
||||
HalfOpenMaxRequests: 3,
|
||||
SuccessThreshold: 2,
|
||||
}
|
||||
}
|
||||
|
||||
// CircuitBreaker 熔断器
|
||||
type CircuitBreaker struct {
|
||||
config CircuitBreakerConfig
|
||||
logger *zap.Logger
|
||||
mutex sync.RWMutex
|
||||
|
||||
// 状态
|
||||
state CircuitState
|
||||
|
||||
// 计数器
|
||||
failures int
|
||||
successes int
|
||||
requests int
|
||||
consecutiveFailures int
|
||||
|
||||
// 时间记录
|
||||
lastFailTime time.Time
|
||||
lastStateChange time.Time
|
||||
|
||||
// 统计窗口
|
||||
window []bool // true=success, false=failure
|
||||
windowIndex int
|
||||
windowFull bool
|
||||
|
||||
// 事件回调
|
||||
onStateChange func(from, to CircuitState)
|
||||
}
|
||||
|
||||
// NewCircuitBreaker 创建熔断器
|
||||
func NewCircuitBreaker(config CircuitBreakerConfig, logger *zap.Logger) *CircuitBreaker {
|
||||
cb := &CircuitBreaker{
|
||||
config: config,
|
||||
logger: logger,
|
||||
state: StateClosed,
|
||||
window: make([]bool, config.WindowSize),
|
||||
lastStateChange: time.Now(),
|
||||
}
|
||||
|
||||
return cb
|
||||
}
|
||||
|
||||
// Execute 执行函数,如果熔断器开启则快速失败
|
||||
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
|
||||
// 检查是否允许执行
|
||||
if !cb.allowRequest() {
|
||||
return ErrCircuitBreakerOpen
|
||||
}
|
||||
|
||||
// 执行函数
|
||||
start := time.Now()
|
||||
err := fn()
|
||||
duration := time.Since(start)
|
||||
|
||||
// 记录结果
|
||||
cb.recordResult(err == nil, duration)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// allowRequest 检查是否允许请求
|
||||
func (cb *CircuitBreaker) allowRequest() bool {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
|
||||
switch cb.state {
|
||||
case StateClosed:
|
||||
return true
|
||||
|
||||
case StateOpen:
|
||||
// 检查是否到了重置时间
|
||||
if now.Sub(cb.lastStateChange) > cb.config.ResetTimeout {
|
||||
cb.setState(StateHalfOpen)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
|
||||
case StateHalfOpen:
|
||||
// 半开状态下限制请求数
|
||||
return cb.requests < cb.config.HalfOpenMaxRequests
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// recordResult 记录执行结果
|
||||
func (cb *CircuitBreaker) recordResult(success bool, duration time.Duration) {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.requests++
|
||||
|
||||
// 更新滑动窗口
|
||||
cb.updateWindow(success)
|
||||
|
||||
if success {
|
||||
cb.successes++
|
||||
cb.consecutiveFailures = 0
|
||||
cb.onSuccess()
|
||||
} else {
|
||||
cb.failures++
|
||||
cb.consecutiveFailures++
|
||||
cb.lastFailTime = time.Now()
|
||||
cb.onFailure()
|
||||
}
|
||||
|
||||
cb.logger.Debug("Circuit breaker recorded result",
|
||||
zap.Bool("success", success),
|
||||
zap.Duration("duration", duration),
|
||||
zap.String("state", cb.state.String()),
|
||||
zap.Int("failures", cb.failures),
|
||||
zap.Int("successes", cb.successes))
|
||||
}
|
||||
|
||||
// updateWindow 更新滑动窗口
|
||||
func (cb *CircuitBreaker) updateWindow(success bool) {
|
||||
cb.window[cb.windowIndex] = success
|
||||
cb.windowIndex = (cb.windowIndex + 1) % cb.config.WindowSize
|
||||
|
||||
if cb.windowIndex == 0 {
|
||||
cb.windowFull = true
|
||||
}
|
||||
}
|
||||
|
||||
// onSuccess 成功时的处理
|
||||
func (cb *CircuitBreaker) onSuccess() {
|
||||
if cb.state == StateHalfOpen {
|
||||
// 半开状态下,如果成功次数达到阈值,则关闭熔断器
|
||||
if cb.successes >= cb.config.SuccessThreshold {
|
||||
cb.setState(StateClosed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// onFailure 失败时的处理
|
||||
func (cb *CircuitBreaker) onFailure() {
|
||||
if cb.state == StateClosed {
|
||||
// 关闭状态下,检查是否需要开启熔断器
|
||||
if cb.shouldTrip() {
|
||||
cb.setState(StateOpen)
|
||||
}
|
||||
} else if cb.state == StateHalfOpen {
|
||||
// 半开状态下,如果失败则立即开启熔断器
|
||||
cb.setState(StateOpen)
|
||||
}
|
||||
}
|
||||
|
||||
// shouldTrip 检查是否应该触发熔断
|
||||
func (cb *CircuitBreaker) shouldTrip() bool {
|
||||
// 基于连续失败次数
|
||||
if cb.consecutiveFailures >= cb.config.FailureThreshold {
|
||||
return true
|
||||
}
|
||||
|
||||
// 基于滑动窗口的失败率
|
||||
if cb.windowFull {
|
||||
failures := 0
|
||||
for _, success := range cb.window {
|
||||
if !success {
|
||||
failures++
|
||||
}
|
||||
}
|
||||
|
||||
failureRate := float64(failures) / float64(cb.config.WindowSize)
|
||||
return failureRate >= 0.5 // 50%失败率
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// setState 设置状态
|
||||
func (cb *CircuitBreaker) setState(newState CircuitState) {
|
||||
if cb.state == newState {
|
||||
return
|
||||
}
|
||||
|
||||
oldState := cb.state
|
||||
cb.state = newState
|
||||
cb.lastStateChange = time.Now()
|
||||
|
||||
// 重置计数器
|
||||
if newState == StateClosed {
|
||||
cb.requests = 0
|
||||
cb.failures = 0
|
||||
cb.successes = 0
|
||||
cb.consecutiveFailures = 0
|
||||
} else if newState == StateHalfOpen {
|
||||
cb.requests = 0
|
||||
cb.successes = 0
|
||||
}
|
||||
|
||||
cb.logger.Info("Circuit breaker state changed",
|
||||
zap.String("from", oldState.String()),
|
||||
zap.String("to", newState.String()),
|
||||
zap.Int("failures", cb.failures),
|
||||
zap.Int("successes", cb.successes))
|
||||
|
||||
// 触发状态变更回调
|
||||
if cb.onStateChange != nil {
|
||||
cb.onStateChange(oldState, newState)
|
||||
}
|
||||
}
|
||||
|
||||
// GetState 获取当前状态
|
||||
func (cb *CircuitBreaker) GetState() CircuitState {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
return cb.state
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
|
||||
cb.mutex.RLock()
|
||||
defer cb.mutex.RUnlock()
|
||||
|
||||
return CircuitBreakerStats{
|
||||
State: cb.state.String(),
|
||||
Failures: cb.failures,
|
||||
Successes: cb.successes,
|
||||
Requests: cb.requests,
|
||||
ConsecutiveFailures: cb.consecutiveFailures,
|
||||
LastFailTime: cb.lastFailTime,
|
||||
LastStateChange: cb.lastStateChange,
|
||||
FailureThreshold: cb.config.FailureThreshold,
|
||||
ResetTimeout: cb.config.ResetTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// Reset 重置熔断器
|
||||
func (cb *CircuitBreaker) Reset() {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
|
||||
cb.setState(StateClosed)
|
||||
cb.window = make([]bool, cb.config.WindowSize)
|
||||
cb.windowIndex = 0
|
||||
cb.windowFull = false
|
||||
|
||||
cb.logger.Info("Circuit breaker reset")
|
||||
}
|
||||
|
||||
// SetStateChangeCallback 设置状态变更回调
|
||||
func (cb *CircuitBreaker) SetStateChangeCallback(callback func(from, to CircuitState)) {
|
||||
cb.mutex.Lock()
|
||||
defer cb.mutex.Unlock()
|
||||
cb.onStateChange = callback
|
||||
}
|
||||
|
||||
// CircuitBreakerStats 熔断器统计信息
|
||||
type CircuitBreakerStats struct {
|
||||
State string `json:"state"`
|
||||
Failures int `json:"failures"`
|
||||
Successes int `json:"successes"`
|
||||
Requests int `json:"requests"`
|
||||
ConsecutiveFailures int `json:"consecutive_failures"`
|
||||
LastFailTime time.Time `json:"last_fail_time"`
|
||||
LastStateChange time.Time `json:"last_state_change"`
|
||||
FailureThreshold int `json:"failure_threshold"`
|
||||
ResetTimeout time.Duration `json:"reset_timeout"`
|
||||
}
|
||||
|
||||
// 预定义错误
|
||||
var (
|
||||
ErrCircuitBreakerOpen = errors.New("circuit breaker is open")
|
||||
)
|
||||
|
||||
// Wrapper 熔断器包装器
|
||||
type Wrapper struct {
|
||||
breakers map[string]*CircuitBreaker
|
||||
logger *zap.Logger
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewWrapper 创建熔断器包装器
|
||||
func NewWrapper(logger *zap.Logger) *Wrapper {
|
||||
return &Wrapper{
|
||||
breakers: make(map[string]*CircuitBreaker),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreate 获取或创建熔断器
|
||||
func (w *Wrapper) GetOrCreate(name string, config CircuitBreakerConfig) *CircuitBreaker {
|
||||
w.mutex.Lock()
|
||||
defer w.mutex.Unlock()
|
||||
|
||||
if cb, exists := w.breakers[name]; exists {
|
||||
return cb
|
||||
}
|
||||
|
||||
cb := NewCircuitBreaker(config, w.logger.Named(name))
|
||||
w.breakers[name] = cb
|
||||
|
||||
w.logger.Info("Created circuit breaker", zap.String("name", name))
|
||||
return cb
|
||||
}
|
||||
|
||||
// Execute 执行带熔断器的函数
|
||||
func (w *Wrapper) Execute(ctx context.Context, name string, fn func() error) error {
|
||||
cb := w.GetOrCreate(name, DefaultCircuitBreakerConfig())
|
||||
return cb.Execute(ctx, fn)
|
||||
}
|
||||
|
||||
// GetStats 获取所有熔断器统计
|
||||
func (w *Wrapper) GetStats() map[string]CircuitBreakerStats {
|
||||
w.mutex.RLock()
|
||||
defer w.mutex.RUnlock()
|
||||
|
||||
stats := make(map[string]CircuitBreakerStats)
|
||||
for name, cb := range w.breakers {
|
||||
stats[name] = cb.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// ResetAll 重置所有熔断器
|
||||
func (w *Wrapper) ResetAll() {
|
||||
w.mutex.RLock()
|
||||
defer w.mutex.RUnlock()
|
||||
|
||||
for name, cb := range w.breakers {
|
||||
cb.Reset()
|
||||
w.logger.Info("Reset circuit breaker", zap.String("name", name))
|
||||
}
|
||||
}
|
||||
467
internal/shared/resilience/retry.go
Normal file
467
internal/shared/resilience/retry.go
Normal file
@@ -0,0 +1,467 @@
|
||||
package resilience
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// RetryConfig 重试配置
|
||||
type RetryConfig struct {
|
||||
// 最大重试次数
|
||||
MaxAttempts int
|
||||
// 初始延迟
|
||||
InitialDelay time.Duration
|
||||
// 最大延迟
|
||||
MaxDelay time.Duration
|
||||
// 退避倍数
|
||||
BackoffMultiplier float64
|
||||
// 抖动系数
|
||||
JitterFactor float64
|
||||
// 重试条件
|
||||
RetryCondition func(error) bool
|
||||
// 延迟函数
|
||||
DelayFunc func(attempt int, config RetryConfig) time.Duration
|
||||
}
|
||||
|
||||
// DefaultRetryConfig 默认重试配置
|
||||
func DefaultRetryConfig() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterFactor: 0.1,
|
||||
RetryCondition: DefaultRetryCondition,
|
||||
DelayFunc: ExponentialBackoffWithJitter,
|
||||
}
|
||||
}
|
||||
|
||||
// RetryableError 可重试错误接口
|
||||
type RetryableError interface {
|
||||
error
|
||||
IsRetryable() bool
|
||||
}
|
||||
|
||||
// DefaultRetryCondition 默认重试条件
|
||||
func DefaultRetryCondition(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否实现了RetryableError接口
|
||||
if retryable, ok := err.(RetryableError); ok {
|
||||
return retryable.IsRetryable()
|
||||
}
|
||||
|
||||
// 默认所有错误都重试
|
||||
return true
|
||||
}
|
||||
|
||||
// IsRetryableHTTPError HTTP错误重试条件
|
||||
func IsRetryableHTTPError(statusCode int) bool {
|
||||
// 5xx错误通常可以重试
|
||||
// 429(Too Many Requests)也可以重试
|
||||
return statusCode >= 500 || statusCode == 429
|
||||
}
|
||||
|
||||
// DelayFunction 延迟函数类型
|
||||
type DelayFunction func(attempt int, config RetryConfig) time.Duration
|
||||
|
||||
// FixedDelay 固定延迟
|
||||
func FixedDelay(attempt int, config RetryConfig) time.Duration {
|
||||
return config.InitialDelay
|
||||
}
|
||||
|
||||
// LinearBackoff 线性退避
|
||||
func LinearBackoff(attempt int, config RetryConfig) time.Duration {
|
||||
delay := time.Duration(attempt) * config.InitialDelay
|
||||
if delay > config.MaxDelay {
|
||||
delay = config.MaxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
// ExponentialBackoff 指数退避
|
||||
func ExponentialBackoff(attempt int, config RetryConfig) time.Duration {
|
||||
delay := config.InitialDelay
|
||||
for i := 0; i < attempt; i++ {
|
||||
delay = time.Duration(float64(delay) * config.BackoffMultiplier)
|
||||
}
|
||||
|
||||
if delay > config.MaxDelay {
|
||||
delay = config.MaxDelay
|
||||
}
|
||||
|
||||
return delay
|
||||
}
|
||||
|
||||
// ExponentialBackoffWithJitter 带抖动的指数退避
|
||||
func ExponentialBackoffWithJitter(attempt int, config RetryConfig) time.Duration {
|
||||
delay := ExponentialBackoff(attempt, config)
|
||||
|
||||
// 添加抖动
|
||||
jitter := config.JitterFactor
|
||||
if jitter > 0 {
|
||||
jitterRange := float64(delay) * jitter
|
||||
jitterOffset := (rand.Float64() - 0.5) * 2 * jitterRange
|
||||
delay = time.Duration(float64(delay) + jitterOffset)
|
||||
}
|
||||
|
||||
if delay < 0 {
|
||||
delay = config.InitialDelay
|
||||
}
|
||||
|
||||
return delay
|
||||
}
|
||||
|
||||
// RetryStats 重试统计
|
||||
type RetryStats struct {
|
||||
TotalAttempts int `json:"total_attempts"`
|
||||
Successes int `json:"successes"`
|
||||
Failures int `json:"failures"`
|
||||
TotalRetries int `json:"total_retries"`
|
||||
AverageAttempts float64 `json:"average_attempts"`
|
||||
TotalDelay time.Duration `json:"total_delay"`
|
||||
LastError string `json:"last_error,omitempty"`
|
||||
}
|
||||
|
||||
// Retryer 重试器
|
||||
type Retryer struct {
|
||||
config RetryConfig
|
||||
logger *zap.Logger
|
||||
stats RetryStats
|
||||
}
|
||||
|
||||
// NewRetryer 创建重试器
|
||||
func NewRetryer(config RetryConfig, logger *zap.Logger) *Retryer {
|
||||
if config.DelayFunc == nil {
|
||||
config.DelayFunc = ExponentialBackoffWithJitter
|
||||
}
|
||||
if config.RetryCondition == nil {
|
||||
config.RetryCondition = DefaultRetryCondition
|
||||
}
|
||||
|
||||
return &Retryer{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute 执行带重试的函数
|
||||
func (r *Retryer) Execute(ctx context.Context, operation func() error) error {
|
||||
return r.ExecuteWithResult(ctx, func() (interface{}, error) {
|
||||
return nil, operation()
|
||||
})
|
||||
}
|
||||
|
||||
// ExecuteWithResult 执行带重试和返回值的函数
|
||||
func (r *Retryer) ExecuteWithResult(ctx context.Context, operation func() (interface{}, error)) error {
|
||||
var lastErr error
|
||||
startTime := time.Now()
|
||||
|
||||
for attempt := 0; attempt < r.config.MaxAttempts; attempt++ {
|
||||
// 检查上下文是否被取消
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// 执行操作
|
||||
attemptStart := time.Now()
|
||||
_, err := operation()
|
||||
attemptDuration := time.Since(attemptStart)
|
||||
|
||||
// 更新统计
|
||||
r.stats.TotalAttempts++
|
||||
if err == nil {
|
||||
r.stats.Successes++
|
||||
r.logger.Debug("Operation succeeded",
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Duration("duration", attemptDuration))
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
r.stats.Failures++
|
||||
if attempt > 0 {
|
||||
r.stats.TotalRetries++
|
||||
}
|
||||
|
||||
// 检查是否应该重试
|
||||
if !r.config.RetryCondition(err) {
|
||||
r.logger.Debug("Error is not retryable",
|
||||
zap.Error(err),
|
||||
zap.Int("attempt", attempt+1))
|
||||
break
|
||||
}
|
||||
|
||||
// 如果这是最后一次尝试,不需要延迟
|
||||
if attempt == r.config.MaxAttempts-1 {
|
||||
r.logger.Debug("Reached max attempts",
|
||||
zap.Error(err),
|
||||
zap.Int("max_attempts", r.config.MaxAttempts))
|
||||
break
|
||||
}
|
||||
|
||||
// 计算延迟
|
||||
delay := r.config.DelayFunc(attempt, r.config)
|
||||
r.stats.TotalDelay += delay
|
||||
|
||||
r.logger.Debug("Operation failed, retrying",
|
||||
zap.Error(err),
|
||||
zap.Int("attempt", attempt+1),
|
||||
zap.Duration("delay", delay),
|
||||
zap.Duration("attempt_duration", attemptDuration))
|
||||
|
||||
// 等待重试
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
}
|
||||
}
|
||||
|
||||
// 更新最终统计
|
||||
totalDuration := time.Since(startTime)
|
||||
if r.stats.TotalAttempts > 0 {
|
||||
r.stats.AverageAttempts = float64(r.stats.TotalRetries) / float64(r.stats.Successes+r.stats.Failures)
|
||||
}
|
||||
if lastErr != nil {
|
||||
r.stats.LastError = lastErr.Error()
|
||||
}
|
||||
|
||||
r.logger.Warn("Operation failed after all retries",
|
||||
zap.Error(lastErr),
|
||||
zap.Int("total_attempts", r.stats.TotalAttempts),
|
||||
zap.Duration("total_duration", totalDuration))
|
||||
|
||||
return fmt.Errorf("operation failed after %d attempts: %w", r.config.MaxAttempts, lastErr)
|
||||
}
|
||||
|
||||
// GetStats 获取重试统计
|
||||
func (r *Retryer) GetStats() RetryStats {
|
||||
return r.stats
|
||||
}
|
||||
|
||||
// Reset 重置统计
|
||||
func (r *Retryer) Reset() {
|
||||
r.stats = RetryStats{}
|
||||
r.logger.Debug("Retry stats reset")
|
||||
}
|
||||
|
||||
// Retry 简单重试函数
|
||||
func Retry(ctx context.Context, config RetryConfig, operation func() error) error {
|
||||
retryer := NewRetryer(config, zap.NewNop())
|
||||
return retryer.Execute(ctx, operation)
|
||||
}
|
||||
|
||||
// RetryWithResult 带返回值的重试函数
|
||||
func RetryWithResult[T any](ctx context.Context, config RetryConfig, operation func() (T, error)) (T, error) {
|
||||
var result T
|
||||
var finalErr error
|
||||
|
||||
retryer := NewRetryer(config, zap.NewNop())
|
||||
err := retryer.ExecuteWithResult(ctx, func() (interface{}, error) {
|
||||
r, e := operation()
|
||||
result = r
|
||||
return r, e
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
finalErr = err
|
||||
}
|
||||
|
||||
return result, finalErr
|
||||
}
|
||||
|
||||
// 预定义的重试配置
|
||||
|
||||
// QuickRetry 快速重试(适用于轻量级操作)
|
||||
func QuickRetry() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 50 * time.Millisecond,
|
||||
MaxDelay: 500 * time.Millisecond,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterFactor: 0.1,
|
||||
RetryCondition: DefaultRetryCondition,
|
||||
DelayFunc: ExponentialBackoffWithJitter,
|
||||
}
|
||||
}
|
||||
|
||||
// StandardRetry 标准重试(适用于一般操作)
|
||||
func StandardRetry() RetryConfig {
|
||||
return DefaultRetryConfig()
|
||||
}
|
||||
|
||||
// PatientRetry 耐心重试(适用于重要操作)
|
||||
func PatientRetry() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 5,
|
||||
InitialDelay: 200 * time.Millisecond,
|
||||
MaxDelay: 10 * time.Second,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterFactor: 0.2,
|
||||
RetryCondition: DefaultRetryCondition,
|
||||
DelayFunc: ExponentialBackoffWithJitter,
|
||||
}
|
||||
}
|
||||
|
||||
// DatabaseRetry 数据库重试配置
|
||||
func DatabaseRetry() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 2 * time.Second,
|
||||
BackoffMultiplier: 1.5,
|
||||
JitterFactor: 0.1,
|
||||
RetryCondition: func(err error) bool {
|
||||
// 这里可以根据具体的数据库错误类型判断
|
||||
// 例如:连接超时、临时网络错误等
|
||||
return DefaultRetryCondition(err)
|
||||
},
|
||||
DelayFunc: ExponentialBackoffWithJitter,
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPRetry HTTP重试配置
|
||||
func HTTPRetry() RetryConfig {
|
||||
return RetryConfig{
|
||||
MaxAttempts: 3,
|
||||
InitialDelay: 200 * time.Millisecond,
|
||||
MaxDelay: 5 * time.Second,
|
||||
BackoffMultiplier: 2.0,
|
||||
JitterFactor: 0.15,
|
||||
RetryCondition: func(err error) bool {
|
||||
// HTTP相关的重试条件
|
||||
return DefaultRetryCondition(err)
|
||||
},
|
||||
DelayFunc: ExponentialBackoffWithJitter,
|
||||
}
|
||||
}
|
||||
|
||||
// RetryManager 重试管理器
|
||||
type RetryManager struct {
|
||||
retryers map[string]*Retryer
|
||||
logger *zap.Logger
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRetryManager 创建重试管理器
|
||||
func NewRetryManager(logger *zap.Logger) *RetryManager {
|
||||
return &RetryManager{
|
||||
retryers: make(map[string]*Retryer),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// GetOrCreate 获取或创建重试器
|
||||
func (rm *RetryManager) GetOrCreate(name string, config RetryConfig) *Retryer {
|
||||
rm.mutex.Lock()
|
||||
defer rm.mutex.Unlock()
|
||||
|
||||
if retryer, exists := rm.retryers[name]; exists {
|
||||
return retryer
|
||||
}
|
||||
|
||||
retryer := NewRetryer(config, rm.logger.Named(name))
|
||||
rm.retryers[name] = retryer
|
||||
|
||||
rm.logger.Info("Created retryer", zap.String("name", name))
|
||||
return retryer
|
||||
}
|
||||
|
||||
// Execute 执行带重试的操作
|
||||
func (rm *RetryManager) Execute(ctx context.Context, name string, operation func() error) error {
|
||||
retryer := rm.GetOrCreate(name, DefaultRetryConfig())
|
||||
return retryer.Execute(ctx, operation)
|
||||
}
|
||||
|
||||
// GetStats 获取所有重试器统计
|
||||
func (rm *RetryManager) GetStats() map[string]RetryStats {
|
||||
rm.mutex.RLock()
|
||||
defer rm.mutex.RUnlock()
|
||||
|
||||
stats := make(map[string]RetryStats)
|
||||
for name, retryer := range rm.retryers {
|
||||
stats[name] = retryer.GetStats()
|
||||
}
|
||||
|
||||
return stats
|
||||
}
|
||||
|
||||
// ResetAll 重置所有重试器统计
|
||||
func (rm *RetryManager) ResetAll() {
|
||||
rm.mutex.RLock()
|
||||
defer rm.mutex.RUnlock()
|
||||
|
||||
for name, retryer := range rm.retryers {
|
||||
retryer.Reset()
|
||||
rm.logger.Info("Reset retryer stats", zap.String("name", name))
|
||||
}
|
||||
}
|
||||
|
||||
// RetryerWrapper 重试器包装器
|
||||
type RetryerWrapper struct {
|
||||
manager *RetryManager
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewRetryerWrapper 创建重试器包装器
|
||||
func NewRetryerWrapper(logger *zap.Logger) *RetryerWrapper {
|
||||
return &RetryerWrapper{
|
||||
manager: NewRetryManager(logger),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteWithQuickRetry 执行快速重试
|
||||
func (rw *RetryerWrapper) ExecuteWithQuickRetry(ctx context.Context, name string, operation func() error) error {
|
||||
retryer := rw.manager.GetOrCreate(name+".quick", QuickRetry())
|
||||
return retryer.Execute(ctx, operation)
|
||||
}
|
||||
|
||||
// ExecuteWithStandardRetry 执行标准重试
|
||||
func (rw *RetryerWrapper) ExecuteWithStandardRetry(ctx context.Context, name string, operation func() error) error {
|
||||
retryer := rw.manager.GetOrCreate(name+".standard", StandardRetry())
|
||||
return retryer.Execute(ctx, operation)
|
||||
}
|
||||
|
||||
// ExecuteWithDatabaseRetry 执行数据库重试
|
||||
func (rw *RetryerWrapper) ExecuteWithDatabaseRetry(ctx context.Context, name string, operation func() error) error {
|
||||
retryer := rw.manager.GetOrCreate(name+".database", DatabaseRetry())
|
||||
return retryer.Execute(ctx, operation)
|
||||
}
|
||||
|
||||
// ExecuteWithHTTPRetry 执行HTTP重试
|
||||
func (rw *RetryerWrapper) ExecuteWithHTTPRetry(ctx context.Context, name string, operation func() error) error {
|
||||
retryer := rw.manager.GetOrCreate(name+".http", HTTPRetry())
|
||||
return retryer.Execute(ctx, operation)
|
||||
}
|
||||
|
||||
// ExecuteWithCustomRetry 执行自定义重试
|
||||
func (rw *RetryerWrapper) ExecuteWithCustomRetry(ctx context.Context, name string, config RetryConfig, operation func() error) error {
|
||||
retryer := rw.manager.GetOrCreate(name+".custom", config)
|
||||
return retryer.Execute(ctx, operation)
|
||||
}
|
||||
|
||||
// GetManager 获取重试管理器
|
||||
func (rw *RetryerWrapper) GetManager() *RetryManager {
|
||||
return rw.manager
|
||||
}
|
||||
|
||||
// GetAllStats 获取所有统计信息
|
||||
func (rw *RetryerWrapper) GetAllStats() map[string]RetryStats {
|
||||
return rw.manager.GetStats()
|
||||
}
|
||||
|
||||
// ResetAllStats 重置所有统计信息
|
||||
func (rw *RetryerWrapper) ResetAllStats() {
|
||||
rw.manager.ResetAll()
|
||||
}
|
||||
612
internal/shared/saga/saga.go
Normal file
612
internal/shared/saga/saga.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package saga
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// SagaStatus Saga状态
|
||||
type SagaStatus int
|
||||
|
||||
const (
|
||||
// StatusPending 等待中
|
||||
StatusPending SagaStatus = iota
|
||||
// StatusRunning 执行中
|
||||
StatusRunning
|
||||
// StatusCompleted 已完成
|
||||
StatusCompleted
|
||||
// StatusFailed 失败
|
||||
StatusFailed
|
||||
// StatusCompensating 补偿中
|
||||
StatusCompensating
|
||||
// StatusCompensated 已补偿
|
||||
StatusCompensated
|
||||
// StatusAborted 已中止
|
||||
StatusAborted
|
||||
)
|
||||
|
||||
func (s SagaStatus) String() string {
|
||||
switch s {
|
||||
case StatusPending:
|
||||
return "PENDING"
|
||||
case StatusRunning:
|
||||
return "RUNNING"
|
||||
case StatusCompleted:
|
||||
return "COMPLETED"
|
||||
case StatusFailed:
|
||||
return "FAILED"
|
||||
case StatusCompensating:
|
||||
return "COMPENSATING"
|
||||
case StatusCompensated:
|
||||
return "COMPENSATED"
|
||||
case StatusAborted:
|
||||
return "ABORTED"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// StepStatus 步骤状态
|
||||
type StepStatus int
|
||||
|
||||
const (
|
||||
// StepPending 等待执行
|
||||
StepPending StepStatus = iota
|
||||
// StepRunning 执行中
|
||||
StepRunning
|
||||
// StepCompleted 完成
|
||||
StepCompleted
|
||||
// StepFailed 失败
|
||||
StepFailed
|
||||
// StepCompensated 已补偿
|
||||
StepCompensated
|
||||
// StepSkipped 跳过
|
||||
StepSkipped
|
||||
)
|
||||
|
||||
func (s StepStatus) String() string {
|
||||
switch s {
|
||||
case StepPending:
|
||||
return "PENDING"
|
||||
case StepRunning:
|
||||
return "RUNNING"
|
||||
case StepCompleted:
|
||||
return "COMPLETED"
|
||||
case StepFailed:
|
||||
return "FAILED"
|
||||
case StepCompensated:
|
||||
return "COMPENSATED"
|
||||
case StepSkipped:
|
||||
return "SKIPPED"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// SagaStep Saga步骤
|
||||
type SagaStep struct {
|
||||
Name string
|
||||
Action func(ctx context.Context, data interface{}) error
|
||||
Compensate func(ctx context.Context, data interface{}) error
|
||||
Status StepStatus
|
||||
Error error
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
RetryCount int
|
||||
MaxRetries int
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// SagaConfig Saga配置
|
||||
type SagaConfig struct {
|
||||
// 默认超时时间
|
||||
DefaultTimeout time.Duration
|
||||
// 默认重试次数
|
||||
DefaultMaxRetries int
|
||||
// 是否并行执行(当前只支持串行)
|
||||
Parallel bool
|
||||
// 事件发布器
|
||||
EventBus interfaces.EventBus
|
||||
}
|
||||
|
||||
// DefaultSagaConfig 默认Saga配置
|
||||
func DefaultSagaConfig() SagaConfig {
|
||||
return SagaConfig{
|
||||
DefaultTimeout: 30 * time.Second,
|
||||
DefaultMaxRetries: 3,
|
||||
Parallel: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Saga 分布式事务
|
||||
type Saga struct {
|
||||
ID string
|
||||
Name string
|
||||
Steps []*SagaStep
|
||||
Status SagaStatus
|
||||
Data interface{}
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
Error error
|
||||
Config SagaConfig
|
||||
logger *zap.Logger
|
||||
mutex sync.RWMutex
|
||||
currentStep int
|
||||
result interface{}
|
||||
}
|
||||
|
||||
// NewSaga 创建新的Saga
|
||||
func NewSaga(id, name string, config SagaConfig, logger *zap.Logger) *Saga {
|
||||
return &Saga{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Steps: make([]*SagaStep, 0),
|
||||
Status: StatusPending,
|
||||
Config: config,
|
||||
logger: logger,
|
||||
currentStep: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// AddStep 添加步骤
|
||||
func (s *Saga) AddStep(name string, action, compensate func(ctx context.Context, data interface{}) error) *Saga {
|
||||
step := &SagaStep{
|
||||
Name: name,
|
||||
Action: action,
|
||||
Compensate: compensate,
|
||||
Status: StepPending,
|
||||
MaxRetries: s.Config.DefaultMaxRetries,
|
||||
Timeout: s.Config.DefaultTimeout,
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.Steps = append(s.Steps, step)
|
||||
s.mutex.Unlock()
|
||||
|
||||
s.logger.Debug("Added step to saga",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", name))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// AddStepWithConfig 添加带配置的步骤
|
||||
func (s *Saga) AddStepWithConfig(name string, action, compensate func(ctx context.Context, data interface{}) error, maxRetries int, timeout time.Duration) *Saga {
|
||||
step := &SagaStep{
|
||||
Name: name,
|
||||
Action: action,
|
||||
Compensate: compensate,
|
||||
Status: StepPending,
|
||||
MaxRetries: maxRetries,
|
||||
Timeout: timeout,
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
s.Steps = append(s.Steps, step)
|
||||
s.mutex.Unlock()
|
||||
|
||||
s.logger.Debug("Added step with config to saga",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", name),
|
||||
zap.Int("max_retries", maxRetries),
|
||||
zap.Duration("timeout", timeout))
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Execute 执行Saga
|
||||
func (s *Saga) Execute(ctx context.Context, data interface{}) error {
|
||||
s.mutex.Lock()
|
||||
if s.Status != StatusPending {
|
||||
s.mutex.Unlock()
|
||||
return fmt.Errorf("saga %s is not in pending status", s.ID)
|
||||
}
|
||||
|
||||
s.Status = StatusRunning
|
||||
s.Data = data
|
||||
s.StartTime = time.Now()
|
||||
s.mutex.Unlock()
|
||||
|
||||
s.logger.Info("Starting saga execution",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("saga_name", s.Name),
|
||||
zap.Int("total_steps", len(s.Steps)))
|
||||
|
||||
// 发布Saga开始事件
|
||||
s.publishEvent(ctx, "saga.started")
|
||||
|
||||
// 执行所有步骤
|
||||
for i, step := range s.Steps {
|
||||
s.mutex.Lock()
|
||||
s.currentStep = i
|
||||
s.mutex.Unlock()
|
||||
|
||||
if err := s.executeStep(ctx, step, data); err != nil {
|
||||
s.logger.Error("Step execution failed",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name),
|
||||
zap.Error(err))
|
||||
|
||||
// 执行补偿
|
||||
if compensateErr := s.compensate(ctx, i-1); compensateErr != nil {
|
||||
s.logger.Error("Compensation failed",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.Error(compensateErr))
|
||||
|
||||
s.setStatus(StatusAborted)
|
||||
s.publishEvent(ctx, "saga.aborted")
|
||||
return fmt.Errorf("saga execution failed and compensation failed: %w", compensateErr)
|
||||
}
|
||||
|
||||
s.setStatus(StatusCompensated)
|
||||
s.publishEvent(ctx, "saga.compensated")
|
||||
return fmt.Errorf("saga execution failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 所有步骤成功完成
|
||||
s.setStatus(StatusCompleted)
|
||||
s.EndTime = time.Now()
|
||||
|
||||
s.logger.Info("Saga completed successfully",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.Duration("duration", s.EndTime.Sub(s.StartTime)))
|
||||
|
||||
s.publishEvent(ctx, "saga.completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeStep 执行单个步骤
|
||||
func (s *Saga) executeStep(ctx context.Context, step *SagaStep, data interface{}) error {
|
||||
step.Status = StepRunning
|
||||
step.StartTime = time.Now()
|
||||
|
||||
s.logger.Debug("Executing step",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name))
|
||||
|
||||
// 设置超时上下文
|
||||
stepCtx, cancel := context.WithTimeout(ctx, step.Timeout)
|
||||
defer cancel()
|
||||
|
||||
// 重试逻辑
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= step.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
s.logger.Debug("Retrying step",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name),
|
||||
zap.Int("attempt", attempt))
|
||||
}
|
||||
|
||||
err := step.Action(stepCtx, data)
|
||||
if err == nil {
|
||||
step.Status = StepCompleted
|
||||
step.EndTime = time.Now()
|
||||
|
||||
s.logger.Debug("Step completed successfully",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name),
|
||||
zap.Duration("duration", step.EndTime.Sub(step.StartTime)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
step.RetryCount = attempt
|
||||
|
||||
// 检查是否应该重试
|
||||
if attempt < step.MaxRetries {
|
||||
select {
|
||||
case <-stepCtx.Done():
|
||||
// 上下文被取消,停止重试
|
||||
break
|
||||
case <-time.After(time.Duration(attempt+1) * 100 * time.Millisecond):
|
||||
// 等待一段时间后重试
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 所有重试都失败了
|
||||
step.Status = StepFailed
|
||||
step.Error = lastErr
|
||||
step.EndTime = time.Now()
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// compensate 执行补偿
|
||||
func (s *Saga) compensate(ctx context.Context, fromStep int) error {
|
||||
s.setStatus(StatusCompensating)
|
||||
|
||||
s.logger.Info("Starting compensation",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.Int("from_step", fromStep))
|
||||
|
||||
// 逆序执行补偿
|
||||
for i := fromStep; i >= 0; i-- {
|
||||
step := s.Steps[i]
|
||||
|
||||
// 只补偿已完成的步骤
|
||||
if step.Status != StepCompleted {
|
||||
step.Status = StepSkipped
|
||||
continue
|
||||
}
|
||||
|
||||
if step.Compensate == nil {
|
||||
s.logger.Warn("No compensation function for step",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name))
|
||||
continue
|
||||
}
|
||||
|
||||
s.logger.Debug("Compensating step",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name))
|
||||
|
||||
// 设置超时上下文
|
||||
compensateCtx, cancel := context.WithTimeout(ctx, step.Timeout)
|
||||
|
||||
err := step.Compensate(compensateCtx, s.Data)
|
||||
cancel()
|
||||
|
||||
if err != nil {
|
||||
s.logger.Error("Compensation failed for step",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
step.Status = StepCompensated
|
||||
|
||||
s.logger.Debug("Step compensated successfully",
|
||||
zap.String("saga_id", s.ID),
|
||||
zap.String("step_name", step.Name))
|
||||
}
|
||||
|
||||
s.logger.Info("Compensation completed",
|
||||
zap.String("saga_id", s.ID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setStatus 设置状态
|
||||
func (s *Saga) setStatus(status SagaStatus) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
s.Status = status
|
||||
}
|
||||
|
||||
// GetStatus 获取状态
|
||||
func (s *Saga) GetStatus() SagaStatus {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
return s.Status
|
||||
}
|
||||
|
||||
// GetProgress 获取进度
|
||||
func (s *Saga) GetProgress() SagaProgress {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
completed := 0
|
||||
for _, step := range s.Steps {
|
||||
if step.Status == StepCompleted {
|
||||
completed++
|
||||
}
|
||||
}
|
||||
|
||||
var percentage float64
|
||||
if len(s.Steps) > 0 {
|
||||
percentage = float64(completed) / float64(len(s.Steps)) * 100
|
||||
}
|
||||
|
||||
return SagaProgress{
|
||||
SagaID: s.ID,
|
||||
Status: s.Status.String(),
|
||||
TotalSteps: len(s.Steps),
|
||||
CompletedSteps: completed,
|
||||
CurrentStep: s.currentStep + 1,
|
||||
PercentComplete: percentage,
|
||||
StartTime: s.StartTime,
|
||||
Duration: time.Since(s.StartTime),
|
||||
}
|
||||
}
|
||||
|
||||
// GetStepStatus 获取所有步骤状态
|
||||
func (s *Saga) GetStepStatus() []StepProgress {
|
||||
s.mutex.RLock()
|
||||
defer s.mutex.RUnlock()
|
||||
|
||||
progress := make([]StepProgress, len(s.Steps))
|
||||
for i, step := range s.Steps {
|
||||
progress[i] = StepProgress{
|
||||
Name: step.Name,
|
||||
Status: step.Status.String(),
|
||||
RetryCount: step.RetryCount,
|
||||
StartTime: step.StartTime,
|
||||
EndTime: step.EndTime,
|
||||
Duration: step.EndTime.Sub(step.StartTime),
|
||||
Error: "",
|
||||
}
|
||||
|
||||
if step.Error != nil {
|
||||
progress[i].Error = step.Error.Error()
|
||||
}
|
||||
}
|
||||
|
||||
return progress
|
||||
}
|
||||
|
||||
// publishEvent 发布事件
|
||||
func (s *Saga) publishEvent(ctx context.Context, eventType string) {
|
||||
if s.Config.EventBus == nil {
|
||||
return
|
||||
}
|
||||
|
||||
event := &SagaEvent{
|
||||
SagaID: s.ID,
|
||||
SagaName: s.Name,
|
||||
EventType: eventType,
|
||||
Status: s.Status.String(),
|
||||
Timestamp: time.Now(),
|
||||
Data: s.Data,
|
||||
}
|
||||
|
||||
// 这里应该实现Event接口,简化处理
|
||||
_ = event
|
||||
}
|
||||
|
||||
// SagaProgress Saga进度
|
||||
type SagaProgress struct {
|
||||
SagaID string `json:"saga_id"`
|
||||
Status string `json:"status"`
|
||||
TotalSteps int `json:"total_steps"`
|
||||
CompletedSteps int `json:"completed_steps"`
|
||||
CurrentStep int `json:"current_step"`
|
||||
PercentComplete float64 `json:"percent_complete"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
}
|
||||
|
||||
// StepProgress 步骤进度
|
||||
type StepProgress struct {
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
RetryCount int `json:"retry_count"`
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
Duration time.Duration `json:"duration"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// SagaEvent Saga事件
|
||||
type SagaEvent struct {
|
||||
SagaID string `json:"saga_id"`
|
||||
SagaName string `json:"saga_name"`
|
||||
EventType string `json:"event_type"`
|
||||
Status string `json:"status"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// SagaManager Saga管理器
|
||||
type SagaManager struct {
|
||||
sagas map[string]*Saga
|
||||
logger *zap.Logger
|
||||
mutex sync.RWMutex
|
||||
config SagaConfig
|
||||
}
|
||||
|
||||
// NewSagaManager 创建Saga管理器
|
||||
func NewSagaManager(config SagaConfig, logger *zap.Logger) *SagaManager {
|
||||
return &SagaManager{
|
||||
sagas: make(map[string]*Saga),
|
||||
logger: logger,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateSaga 创建Saga
|
||||
func (sm *SagaManager) CreateSaga(id, name string) *Saga {
|
||||
saga := NewSaga(id, name, sm.config, sm.logger.Named("saga"))
|
||||
|
||||
sm.mutex.Lock()
|
||||
sm.sagas[id] = saga
|
||||
sm.mutex.Unlock()
|
||||
|
||||
sm.logger.Info("Created saga",
|
||||
zap.String("saga_id", id),
|
||||
zap.String("saga_name", name))
|
||||
|
||||
return saga
|
||||
}
|
||||
|
||||
// GetSaga 获取Saga
|
||||
func (sm *SagaManager) GetSaga(id string) (*Saga, bool) {
|
||||
sm.mutex.RLock()
|
||||
defer sm.mutex.RUnlock()
|
||||
|
||||
saga, exists := sm.sagas[id]
|
||||
return saga, exists
|
||||
}
|
||||
|
||||
// ListSagas 列出所有Saga
|
||||
func (sm *SagaManager) ListSagas() []*Saga {
|
||||
sm.mutex.RLock()
|
||||
defer sm.mutex.RUnlock()
|
||||
|
||||
sagas := make([]*Saga, 0, len(sm.sagas))
|
||||
for _, saga := range sm.sagas {
|
||||
sagas = append(sagas, saga)
|
||||
}
|
||||
|
||||
return sagas
|
||||
}
|
||||
|
||||
// GetSagaProgress 获取Saga进度
|
||||
func (sm *SagaManager) GetSagaProgress(id string) (SagaProgress, bool) {
|
||||
saga, exists := sm.GetSaga(id)
|
||||
if !exists {
|
||||
return SagaProgress{}, false
|
||||
}
|
||||
|
||||
return saga.GetProgress(), true
|
||||
}
|
||||
|
||||
// RemoveSaga 移除Saga
|
||||
func (sm *SagaManager) RemoveSaga(id string) {
|
||||
sm.mutex.Lock()
|
||||
defer sm.mutex.Unlock()
|
||||
|
||||
delete(sm.sagas, id)
|
||||
sm.logger.Debug("Removed saga", zap.String("saga_id", id))
|
||||
}
|
||||
|
||||
// GetStats 获取统计信息
|
||||
func (sm *SagaManager) GetStats() map[string]interface{} {
|
||||
sm.mutex.RLock()
|
||||
defer sm.mutex.RUnlock()
|
||||
|
||||
statusCount := make(map[string]int)
|
||||
for _, saga := range sm.sagas {
|
||||
status := saga.GetStatus().String()
|
||||
statusCount[status]++
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"total_sagas": len(sm.sagas),
|
||||
"status_count": statusCount,
|
||||
}
|
||||
}
|
||||
|
||||
// 实现Service接口
|
||||
|
||||
// Name 返回服务名称
|
||||
func (sm *SagaManager) Name() string {
|
||||
return "saga-manager"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (sm *SagaManager) Initialize(ctx context.Context) error {
|
||||
sm.logger.Info("Saga manager service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (sm *SagaManager) HealthCheck(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (sm *SagaManager) Shutdown(ctx context.Context) error {
|
||||
sm.logger.Info("Saga manager service shutdown")
|
||||
return nil
|
||||
}
|
||||
130
internal/shared/sms/sms_service.go
Normal file
130
internal/shared/sms/sms_service.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package sms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
)
|
||||
|
||||
// Service 短信服务接口
|
||||
type Service interface {
|
||||
SendVerificationCode(ctx context.Context, phone string, code string) error
|
||||
GenerateCode(length int) string
|
||||
}
|
||||
|
||||
// AliSMSService 阿里云短信服务实现
|
||||
type AliSMSService struct {
|
||||
client *dysmsapi.Client
|
||||
config config.SMSConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAliSMSService 创建阿里云短信服务
|
||||
func NewAliSMSService(cfg config.SMSConfig, logger *zap.Logger) (*AliSMSService, error) {
|
||||
client, err := dysmsapi.NewClientWithAccessKey("cn-hangzhou", cfg.AccessKeyID, cfg.AccessKeySecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建短信客户端失败: %w", err)
|
||||
}
|
||||
|
||||
return &AliSMSService{
|
||||
client: client,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
|
||||
request := dysmsapi.CreateSendSmsRequest()
|
||||
request.Scheme = "https"
|
||||
request.PhoneNumbers = phone
|
||||
request.SignName = s.config.SignName
|
||||
request.TemplateCode = s.config.TemplateCode
|
||||
request.TemplateParam = fmt.Sprintf(`{"code":"%s"}`, code)
|
||||
|
||||
response, err := s.client.SendSms(request)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to send SMS",
|
||||
zap.String("phone", phone),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("短信发送失败: %w", err)
|
||||
}
|
||||
|
||||
if response.Code != "OK" {
|
||||
s.logger.Error("SMS send failed",
|
||||
zap.String("phone", phone),
|
||||
zap.String("code", response.Code),
|
||||
zap.String("message", response.Message))
|
||||
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
|
||||
}
|
||||
|
||||
s.logger.Info("SMS sent successfully",
|
||||
zap.String("phone", phone),
|
||||
zap.String("bizId", response.BizId))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode 生成验证码
|
||||
func (s *AliSMSService) GenerateCode(length int) string {
|
||||
if length <= 0 {
|
||||
length = 6
|
||||
}
|
||||
|
||||
// 生成指定长度的数字验证码
|
||||
max := big.NewInt(int64(pow10(length)))
|
||||
n, _ := rand.Int(rand.Reader, max)
|
||||
|
||||
// 格式化为指定长度,不足时前面补0
|
||||
format := fmt.Sprintf("%%0%dd", length)
|
||||
return fmt.Sprintf(format, n.Int64())
|
||||
}
|
||||
|
||||
// pow10 计算10的n次方
|
||||
func pow10(n int) int {
|
||||
result := 1
|
||||
for i := 0; i < n; i++ {
|
||||
result *= 10
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MockSMSService 模拟短信服务(用于开发和测试)
|
||||
type MockSMSService struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewMockSMSService 创建模拟短信服务
|
||||
func NewMockSMSService(logger *zap.Logger) *MockSMSService {
|
||||
return &MockSMSService{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendVerificationCode 模拟发送验证码
|
||||
func (s *MockSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
|
||||
s.logger.Info("Mock SMS sent",
|
||||
zap.String("phone", phone),
|
||||
zap.String("code", code))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode 生成验证码
|
||||
func (s *MockSMSService) GenerateCode(length int) string {
|
||||
if length <= 0 {
|
||||
length = 6
|
||||
}
|
||||
|
||||
// 开发环境使用固定验证码便于测试
|
||||
result := ""
|
||||
for i := 0; i < length; i++ {
|
||||
result += "1"
|
||||
}
|
||||
return result
|
||||
}
|
||||
292
internal/shared/tracing/decorators.go
Normal file
292
internal/shared/tracing/decorators.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TracableService 可追踪的服务接口
|
||||
type TracableService interface {
|
||||
Name() string
|
||||
}
|
||||
|
||||
// ServiceDecorator 服务装饰器
|
||||
type ServiceDecorator struct {
|
||||
tracer *Tracer
|
||||
logger *zap.Logger
|
||||
config DecoratorConfig
|
||||
}
|
||||
|
||||
// DecoratorConfig 装饰器配置
|
||||
type DecoratorConfig struct {
|
||||
EnableMethodTracing bool
|
||||
ExcludePatterns []string
|
||||
IncludeArguments bool
|
||||
IncludeResults bool
|
||||
SlowMethodThreshold time.Duration
|
||||
}
|
||||
|
||||
// DefaultDecoratorConfig 默认装饰器配置
|
||||
func DefaultDecoratorConfig() DecoratorConfig {
|
||||
return DecoratorConfig{
|
||||
EnableMethodTracing: true,
|
||||
ExcludePatterns: []string{"Health", "Ping", "Name"},
|
||||
IncludeArguments: true,
|
||||
IncludeResults: false,
|
||||
SlowMethodThreshold: 100 * time.Millisecond,
|
||||
}
|
||||
}
|
||||
|
||||
// NewServiceDecorator 创建服务装饰器
|
||||
func NewServiceDecorator(tracer *Tracer, logger *zap.Logger) *ServiceDecorator {
|
||||
return &ServiceDecorator{
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
config: DefaultDecoratorConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// WrapService 自动包装服务,为所有方法添加链路追踪
|
||||
func (d *ServiceDecorator) WrapService(service interface{}) interface{} {
|
||||
serviceValue := reflect.ValueOf(service)
|
||||
serviceType := reflect.TypeOf(service)
|
||||
|
||||
if serviceType.Kind() == reflect.Ptr {
|
||||
serviceType = serviceType.Elem()
|
||||
serviceValue = serviceValue.Elem()
|
||||
}
|
||||
|
||||
// 创建代理结构
|
||||
proxyType := d.createProxyType(serviceType)
|
||||
proxyValue := reflect.New(proxyType).Elem()
|
||||
|
||||
// 设置原始服务字段
|
||||
proxyValue.FieldByName("target").Set(reflect.ValueOf(service))
|
||||
proxyValue.FieldByName("decorator").Set(reflect.ValueOf(d))
|
||||
|
||||
return proxyValue.Addr().Interface()
|
||||
}
|
||||
|
||||
// createProxyType 创建代理类型
|
||||
func (d *ServiceDecorator) createProxyType(serviceType reflect.Type) reflect.Type {
|
||||
// 获取服务名称
|
||||
serviceName := d.getServiceName(serviceType)
|
||||
|
||||
// 创建代理结构字段
|
||||
fields := []reflect.StructField{
|
||||
{
|
||||
Name: "target",
|
||||
Type: reflect.PtrTo(serviceType),
|
||||
},
|
||||
{
|
||||
Name: "decorator",
|
||||
Type: reflect.TypeOf(d),
|
||||
},
|
||||
}
|
||||
|
||||
// 为每个方法创建包装器方法
|
||||
for i := 0; i < serviceType.NumMethod(); i++ {
|
||||
method := serviceType.Method(i)
|
||||
if d.shouldTraceMethod(method.Name) {
|
||||
// 创建方法字段(用于存储方法实现)
|
||||
fields = append(fields, reflect.StructField{
|
||||
Name: method.Name,
|
||||
Type: method.Type,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// 创建新的结构类型
|
||||
proxyType := reflect.StructOf(fields)
|
||||
|
||||
// 实现接口方法
|
||||
d.implementMethods(proxyType, serviceType, serviceName)
|
||||
|
||||
return proxyType
|
||||
}
|
||||
|
||||
// shouldTraceMethod 判断是否应该追踪方法
|
||||
func (d *ServiceDecorator) shouldTraceMethod(methodName string) bool {
|
||||
if !d.config.EnableMethodTracing {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, pattern := range d.config.ExcludePatterns {
|
||||
if strings.Contains(methodName, pattern) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// getServiceName 获取服务名称
|
||||
func (d *ServiceDecorator) getServiceName(serviceType reflect.Type) string {
|
||||
serviceName := serviceType.Name()
|
||||
// 移除Service后缀
|
||||
if strings.HasSuffix(serviceName, "Service") {
|
||||
serviceName = strings.TrimSuffix(serviceName, "Service")
|
||||
}
|
||||
return strings.ToLower(serviceName)
|
||||
}
|
||||
|
||||
// TraceMethodCall 追踪方法调用
|
||||
func (d *ServiceDecorator) TraceMethodCall(
|
||||
ctx context.Context,
|
||||
serviceName, methodName string,
|
||||
fn func(context.Context) ([]reflect.Value, error),
|
||||
args []reflect.Value,
|
||||
) ([]reflect.Value, error) {
|
||||
// 创建span名称
|
||||
spanName := fmt.Sprintf("%s.%s", serviceName, methodName)
|
||||
|
||||
// 开始追踪
|
||||
ctx, span := d.tracer.StartSpan(ctx, spanName)
|
||||
defer span.End()
|
||||
|
||||
// 添加基础属性
|
||||
d.tracer.AddSpanAttributes(span,
|
||||
attribute.String("service.name", serviceName),
|
||||
attribute.String("service.method", methodName),
|
||||
attribute.String("service.type", "business"),
|
||||
)
|
||||
|
||||
// 添加参数信息(如果启用)
|
||||
if d.config.IncludeArguments {
|
||||
d.addArgumentAttributes(span, args)
|
||||
}
|
||||
|
||||
// 记录开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 执行原始方法
|
||||
results, err := fn(ctx)
|
||||
|
||||
// 计算执行时间
|
||||
duration := time.Since(startTime)
|
||||
d.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("service.duration_ms", duration.Milliseconds()),
|
||||
)
|
||||
|
||||
// 标记慢方法
|
||||
if duration > d.config.SlowMethodThreshold {
|
||||
d.tracer.AddSpanAttributes(span,
|
||||
attribute.Bool("service.slow_method", true),
|
||||
)
|
||||
d.logger.Warn("慢方法检测",
|
||||
zap.String("service", serviceName),
|
||||
zap.String("method", methodName),
|
||||
zap.Duration("duration", duration),
|
||||
zap.String("trace_id", d.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
}
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
d.tracer.SetSpanError(span, err)
|
||||
d.logger.Error("服务方法执行失败",
|
||||
zap.String("service", serviceName),
|
||||
zap.String("method", methodName),
|
||||
zap.Error(err),
|
||||
zap.String("trace_id", d.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
} else {
|
||||
d.tracer.SetSpanSuccess(span)
|
||||
|
||||
// 添加结果信息(如果启用)
|
||||
if d.config.IncludeResults {
|
||||
d.addResultAttributes(span, results)
|
||||
}
|
||||
}
|
||||
|
||||
return results, err
|
||||
}
|
||||
|
||||
// addArgumentAttributes 添加参数属性
|
||||
func (d *ServiceDecorator) addArgumentAttributes(span trace.Span, args []reflect.Value) {
|
||||
for i, arg := range args {
|
||||
if i == 0 && arg.Type().String() == "context.Context" {
|
||||
continue // 跳过context参数
|
||||
}
|
||||
|
||||
argName := fmt.Sprintf("service.arg_%d", i)
|
||||
argValue := d.extractValue(arg)
|
||||
|
||||
if argValue != "" && len(argValue) < 1000 { // 限制长度避免性能问题
|
||||
d.tracer.AddSpanAttributes(span,
|
||||
attribute.String(argName, argValue),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// addResultAttributes 添加结果属性
|
||||
func (d *ServiceDecorator) addResultAttributes(span trace.Span, results []reflect.Value) {
|
||||
for i, result := range results {
|
||||
if result.Type().String() == "error" {
|
||||
continue // 错误在其他地方处理
|
||||
}
|
||||
|
||||
resultName := fmt.Sprintf("service.result_%d", i)
|
||||
resultValue := d.extractValue(result)
|
||||
|
||||
if resultValue != "" && len(resultValue) < 1000 {
|
||||
d.tracer.AddSpanAttributes(span,
|
||||
attribute.String(resultName, resultValue),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractValue 提取值的字符串表示
|
||||
func (d *ServiceDecorator) extractValue(value reflect.Value) string {
|
||||
if !value.IsValid() {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.String:
|
||||
return value.String()
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return fmt.Sprintf("%d", value.Int())
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return fmt.Sprintf("%d", value.Uint())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return fmt.Sprintf("%.2f", value.Float())
|
||||
case reflect.Bool:
|
||||
return fmt.Sprintf("%t", value.Bool())
|
||||
case reflect.Ptr:
|
||||
if value.IsNil() {
|
||||
return "nil"
|
||||
}
|
||||
return d.extractValue(value.Elem())
|
||||
case reflect.Struct:
|
||||
// 对于结构体,只返回类型名
|
||||
return value.Type().Name()
|
||||
case reflect.Slice, reflect.Array:
|
||||
return fmt.Sprintf("[%d items]", value.Len())
|
||||
default:
|
||||
return value.Type().Name()
|
||||
}
|
||||
}
|
||||
|
||||
// implementMethods 实现接口方法(占位符,实际需要运行时代理)
|
||||
func (d *ServiceDecorator) implementMethods(proxyType, serviceType reflect.Type, serviceName string) {
|
||||
// 这里是运行时方法实现的占位符
|
||||
// 实际实现需要使用reflect.MakeFunc或其他运行时代理技术
|
||||
}
|
||||
|
||||
// GetFunctionName 获取函数名称
|
||||
func GetFunctionName(fn interface{}) string {
|
||||
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
||||
parts := strings.Split(name, ".")
|
||||
return parts[len(parts)-1]
|
||||
}
|
||||
320
internal/shared/tracing/gorm_plugin.go
Normal file
320
internal/shared/tracing/gorm_plugin.go
Normal file
@@ -0,0 +1,320 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
gormSpanKey = "otel:span"
|
||||
gormOperationKey = "otel:operation"
|
||||
gormTableNameKey = "otel:table_name"
|
||||
gormStartTimeKey = "otel:start_time"
|
||||
)
|
||||
|
||||
// GormTracingPlugin GORM链路追踪插件
|
||||
type GormTracingPlugin struct {
|
||||
tracer *Tracer
|
||||
logger *zap.Logger
|
||||
config GormPluginConfig
|
||||
}
|
||||
|
||||
// GormPluginConfig GORM插件配置
|
||||
type GormPluginConfig struct {
|
||||
IncludeSQL bool
|
||||
IncludeValues bool
|
||||
SlowThreshold time.Duration
|
||||
ExcludeTables []string
|
||||
SanitizeSQL bool
|
||||
}
|
||||
|
||||
// DefaultGormPluginConfig 默认GORM插件配置
|
||||
func DefaultGormPluginConfig() GormPluginConfig {
|
||||
return GormPluginConfig{
|
||||
IncludeSQL: true,
|
||||
IncludeValues: false, // 生产环境建议设为false避免记录敏感数据
|
||||
SlowThreshold: 200 * time.Millisecond,
|
||||
ExcludeTables: []string{"migrations", "schema_migrations"},
|
||||
SanitizeSQL: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewGormTracingPlugin 创建GORM追踪插件
|
||||
func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin {
|
||||
return &GormTracingPlugin{
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
config: DefaultGormPluginConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回插件名称
|
||||
func (p *GormTracingPlugin) Name() string {
|
||||
return "gorm-otel-tracing"
|
||||
}
|
||||
|
||||
// Initialize 初始化插件
|
||||
func (p *GormTracingPlugin) Initialize(db *gorm.DB) error {
|
||||
// 注册各种操作的回调
|
||||
callbacks := []string{"create", "query", "update", "delete", "raw"}
|
||||
|
||||
for _, operation := range callbacks {
|
||||
switch operation {
|
||||
case "create":
|
||||
err := db.Callback().Create().Before("gorm:create").
|
||||
Register(p.Name()+":before_create", p.beforeOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register before create callback: %w", err)
|
||||
}
|
||||
err = db.Callback().Create().After("gorm:create").
|
||||
Register(p.Name()+":after_create", p.afterOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register after create callback: %w", err)
|
||||
}
|
||||
case "query":
|
||||
err := db.Callback().Query().Before("gorm:query").
|
||||
Register(p.Name()+":before_query", p.beforeOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register before query callback: %w", err)
|
||||
}
|
||||
err = db.Callback().Query().After("gorm:query").
|
||||
Register(p.Name()+":after_query", p.afterOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register after query callback: %w", err)
|
||||
}
|
||||
case "update":
|
||||
err := db.Callback().Update().Before("gorm:update").
|
||||
Register(p.Name()+":before_update", p.beforeOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register before update callback: %w", err)
|
||||
}
|
||||
err = db.Callback().Update().After("gorm:update").
|
||||
Register(p.Name()+":after_update", p.afterOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register after update callback: %w", err)
|
||||
}
|
||||
case "delete":
|
||||
err := db.Callback().Delete().Before("gorm:delete").
|
||||
Register(p.Name()+":before_delete", p.beforeOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register before delete callback: %w", err)
|
||||
}
|
||||
err = db.Callback().Delete().After("gorm:delete").
|
||||
Register(p.Name()+":after_delete", p.afterOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register after delete callback: %w", err)
|
||||
}
|
||||
case "raw":
|
||||
err := db.Callback().Raw().Before("gorm:raw").
|
||||
Register(p.Name()+":before_raw", p.beforeOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register before raw callback: %w", err)
|
||||
}
|
||||
err = db.Callback().Raw().After("gorm:raw").
|
||||
Register(p.Name()+":after_raw", p.afterOperation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register after raw callback: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.logger.Info("GORM追踪插件已初始化")
|
||||
return nil
|
||||
}
|
||||
|
||||
// beforeOperation 操作前回调
|
||||
func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) {
|
||||
// 检查是否应该跳过追踪
|
||||
if p.shouldSkipTracing(db) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx := db.Statement.Context
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// 获取操作信息
|
||||
operation := p.getOperationType(db)
|
||||
tableName := p.getTableName(db)
|
||||
|
||||
// 检查是否应该排除此表
|
||||
if p.isExcludedTable(tableName) {
|
||||
return
|
||||
}
|
||||
|
||||
// 开始追踪
|
||||
ctx, span := p.tracer.StartDBSpan(ctx, operation, tableName)
|
||||
|
||||
// 添加基础属性
|
||||
p.tracer.AddSpanAttributes(span,
|
||||
attribute.String("db.system", "postgresql"),
|
||||
attribute.String("db.operation", operation),
|
||||
)
|
||||
|
||||
if tableName != "" {
|
||||
p.tracer.AddSpanAttributes(span, attribute.String("db.table", tableName))
|
||||
}
|
||||
|
||||
// 保存追踪信息到GORM context
|
||||
db.Set(gormSpanKey, span)
|
||||
db.Set(gormOperationKey, operation)
|
||||
db.Set(gormTableNameKey, tableName)
|
||||
db.Set(gormStartTimeKey, time.Now())
|
||||
|
||||
// 更新statement context
|
||||
db.Statement.Context = ctx
|
||||
}
|
||||
|
||||
// afterOperation 操作后回调
|
||||
func (p *GormTracingPlugin) afterOperation(db *gorm.DB) {
|
||||
// 获取span
|
||||
spanValue, exists := db.Get(gormSpanKey)
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
span, ok := spanValue.(trace.Span)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
defer span.End()
|
||||
|
||||
// 获取操作信息
|
||||
operation, _ := db.Get(gormOperationKey)
|
||||
tableName, _ := db.Get(gormTableNameKey)
|
||||
startTime, _ := db.Get(gormStartTimeKey)
|
||||
|
||||
// 计算执行时间
|
||||
var duration time.Duration
|
||||
if st, ok := startTime.(time.Time); ok {
|
||||
duration = time.Since(st)
|
||||
p.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("db.duration_ms", duration.Milliseconds()),
|
||||
)
|
||||
}
|
||||
|
||||
// 添加SQL信息
|
||||
if p.config.IncludeSQL && db.Statement.SQL.String() != "" {
|
||||
sql := db.Statement.SQL.String()
|
||||
if p.config.SanitizeSQL {
|
||||
sql = p.sanitizeSQL(sql)
|
||||
}
|
||||
p.tracer.AddSpanAttributes(span, attribute.String("db.statement", sql))
|
||||
}
|
||||
|
||||
// 添加影响行数
|
||||
if db.Statement.RowsAffected >= 0 {
|
||||
p.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("db.rows_affected", db.Statement.RowsAffected),
|
||||
)
|
||||
}
|
||||
|
||||
// 处理错误
|
||||
if db.Error != nil {
|
||||
p.tracer.SetSpanError(span, db.Error)
|
||||
span.SetStatus(codes.Error, db.Error.Error())
|
||||
|
||||
p.logger.Error("数据库操作失败",
|
||||
zap.String("operation", fmt.Sprintf("%v", operation)),
|
||||
zap.String("table", fmt.Sprintf("%v", tableName)),
|
||||
zap.Error(db.Error),
|
||||
zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
|
||||
)
|
||||
} else {
|
||||
p.tracer.SetSpanSuccess(span)
|
||||
span.SetStatus(codes.Ok, "success")
|
||||
|
||||
// 检查慢查询
|
||||
if duration > p.config.SlowThreshold {
|
||||
p.tracer.AddSpanAttributes(span,
|
||||
attribute.Bool("db.slow_query", true),
|
||||
)
|
||||
|
||||
p.logger.Warn("慢SQL查询检测",
|
||||
zap.String("operation", fmt.Sprintf("%v", operation)),
|
||||
zap.String("table", fmt.Sprintf("%v", tableName)),
|
||||
zap.Duration("duration", duration),
|
||||
zap.String("sql", db.Statement.SQL.String()),
|
||||
zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// shouldSkipTracing 检查是否应该跳过追踪
|
||||
func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool {
|
||||
// 检查是否已有span(避免重复追踪)
|
||||
if _, exists := db.Get(gormSpanKey); exists {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getOperationType 获取操作类型
|
||||
func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string {
|
||||
switch db.Statement.ReflectValue.Kind() {
|
||||
default:
|
||||
sql := strings.ToUpper(strings.TrimSpace(db.Statement.SQL.String()))
|
||||
if sql == "" {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
if strings.HasPrefix(sql, "SELECT") {
|
||||
return "select"
|
||||
} else if strings.HasPrefix(sql, "INSERT") {
|
||||
return "insert"
|
||||
} else if strings.HasPrefix(sql, "UPDATE") {
|
||||
return "update"
|
||||
} else if strings.HasPrefix(sql, "DELETE") {
|
||||
return "delete"
|
||||
} else if strings.HasPrefix(sql, "CREATE") {
|
||||
return "create"
|
||||
} else if strings.HasPrefix(sql, "DROP") {
|
||||
return "drop"
|
||||
} else if strings.HasPrefix(sql, "ALTER") {
|
||||
return "alter"
|
||||
}
|
||||
|
||||
return "query"
|
||||
}
|
||||
}
|
||||
|
||||
// getTableName 获取表名
|
||||
func (p *GormTracingPlugin) getTableName(db *gorm.DB) string {
|
||||
if db.Statement.Table != "" {
|
||||
return db.Statement.Table
|
||||
}
|
||||
|
||||
if db.Statement.Schema != nil && db.Statement.Schema.Table != "" {
|
||||
return db.Statement.Schema.Table
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// isExcludedTable 检查是否为排除的表
|
||||
func (p *GormTracingPlugin) isExcludedTable(tableName string) bool {
|
||||
for _, excluded := range p.config.ExcludeTables {
|
||||
if tableName == excluded {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// sanitizeSQL 清理SQL语句,移除敏感信息
|
||||
func (p *GormTracingPlugin) sanitizeSQL(sql string) string {
|
||||
// 简单的SQL清理,将参数替换为占位符
|
||||
// 在生产环境中,您可能需要更复杂的清理逻辑
|
||||
return strings.ReplaceAll(sql, "'", "?")
|
||||
}
|
||||
407
internal/shared/tracing/redis_wrapper.go
Normal file
407
internal/shared/tracing/redis_wrapper.go
Normal file
@@ -0,0 +1,407 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// TracedRedisCache Redis缓存自动追踪包装器
|
||||
type TracedRedisCache struct {
|
||||
client redis.UniversalClient
|
||||
tracer *Tracer
|
||||
logger *zap.Logger
|
||||
prefix string
|
||||
config RedisTracingConfig
|
||||
}
|
||||
|
||||
// RedisTracingConfig Redis追踪配置
|
||||
type RedisTracingConfig struct {
|
||||
IncludeKeys bool
|
||||
IncludeValues bool
|
||||
MaxKeyLength int
|
||||
MaxValueLength int
|
||||
SlowThreshold time.Duration
|
||||
SanitizeValues bool
|
||||
}
|
||||
|
||||
// DefaultRedisTracingConfig 默认Redis追踪配置
|
||||
func DefaultRedisTracingConfig() RedisTracingConfig {
|
||||
return RedisTracingConfig{
|
||||
IncludeKeys: true,
|
||||
IncludeValues: false, // 生产环境建议设为false保护敏感数据
|
||||
MaxKeyLength: 100,
|
||||
MaxValueLength: 1000,
|
||||
SlowThreshold: 50 * time.Millisecond,
|
||||
SanitizeValues: true,
|
||||
}
|
||||
}
|
||||
|
||||
// NewTracedRedisCache 创建带追踪的Redis缓存
|
||||
func NewTracedRedisCache(client redis.UniversalClient, tracer *Tracer, logger *zap.Logger, prefix string) interfaces.CacheService {
|
||||
return &TracedRedisCache{
|
||||
client: client,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
prefix: prefix,
|
||||
config: DefaultRedisTracingConfig(),
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (c *TracedRedisCache) Name() string {
|
||||
return "redis-cache"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (c *TracedRedisCache) Initialize(ctx context.Context) error {
|
||||
c.logger.Info("Redis缓存服务已初始化")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (c *TracedRedisCache) HealthCheck(ctx context.Context) error {
|
||||
_, err := c.client.Ping(ctx).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (c *TracedRedisCache) Shutdown(ctx context.Context) error {
|
||||
c.logger.Info("Redis缓存服务已关闭")
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func (c *TracedRedisCache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||
// 开始追踪
|
||||
ctx, span := c.tracer.StartCacheSpan(ctx, "get", key)
|
||||
defer span.End()
|
||||
|
||||
// 添加基础属性
|
||||
c.addBaseAttributes(span, "get", key)
|
||||
|
||||
// 记录开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 构建完整键名
|
||||
fullKey := c.buildKey(key)
|
||||
|
||||
// 执行Redis操作
|
||||
result, err := c.client.Get(ctx, fullKey).Result()
|
||||
|
||||
// 计算执行时间
|
||||
duration := time.Since(startTime)
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
|
||||
)
|
||||
|
||||
// 检查慢操作
|
||||
if duration > c.config.SlowThreshold {
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Bool("redis.slow_operation", true),
|
||||
)
|
||||
c.logger.Warn("Redis慢操作检测",
|
||||
zap.String("operation", "get"),
|
||||
zap.String("key", c.sanitizeKey(key)),
|
||||
zap.Duration("duration", duration),
|
||||
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
}
|
||||
|
||||
// 处理结果
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
// 缓存未命中
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Bool("redis.hit", false),
|
||||
attribute.String("redis.result", "miss"),
|
||||
)
|
||||
c.tracer.SetSpanSuccess(span)
|
||||
return interfaces.ErrCacheMiss
|
||||
} else {
|
||||
// Redis错误
|
||||
c.tracer.SetSpanError(span, err)
|
||||
c.logger.Error("Redis GET操作失败",
|
||||
zap.String("key", c.sanitizeKey(key)),
|
||||
zap.Error(err),
|
||||
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存命中
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Bool("redis.hit", true),
|
||||
attribute.String("redis.result", "hit"),
|
||||
attribute.Int("redis.value_size", len(result)),
|
||||
)
|
||||
|
||||
// 反序列化
|
||||
if err := c.deserialize(result, dest); err != nil {
|
||||
c.tracer.SetSpanError(span, err)
|
||||
return err
|
||||
}
|
||||
|
||||
c.tracer.SetSpanSuccess(span)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func (c *TracedRedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
|
||||
// 开始追踪
|
||||
ctx, span := c.tracer.StartCacheSpan(ctx, "set", key)
|
||||
defer span.End()
|
||||
|
||||
// 添加基础属性
|
||||
c.addBaseAttributes(span, "set", key)
|
||||
|
||||
// 处理TTL
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
if duration, ok := ttl[0].(time.Duration); ok {
|
||||
expiration = duration
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("redis.ttl_seconds", int64(expiration.Seconds())),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 记录开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 序列化值
|
||||
serialized, err := c.serialize(value)
|
||||
if err != nil {
|
||||
c.tracer.SetSpanError(span, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// 构建完整键名
|
||||
fullKey := c.buildKey(key)
|
||||
|
||||
// 执行Redis操作
|
||||
err = c.client.Set(ctx, fullKey, serialized, expiration).Err()
|
||||
|
||||
// 计算执行时间
|
||||
duration := time.Since(startTime)
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
|
||||
attribute.Int("redis.value_size", len(serialized)),
|
||||
)
|
||||
|
||||
// 检查慢操作
|
||||
if duration > c.config.SlowThreshold {
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Bool("redis.slow_operation", true),
|
||||
)
|
||||
c.logger.Warn("Redis慢操作检测",
|
||||
zap.String("operation", "set"),
|
||||
zap.String("key", c.sanitizeKey(key)),
|
||||
zap.Duration("duration", duration),
|
||||
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
}
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
c.tracer.SetSpanError(span, err)
|
||||
c.logger.Error("Redis SET操作失败",
|
||||
zap.String("key", c.sanitizeKey(key)),
|
||||
zap.Error(err),
|
||||
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
c.tracer.SetSpanSuccess(span)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (c *TracedRedisCache) Delete(ctx context.Context, keys ...string) error {
|
||||
// 开始追踪
|
||||
ctx, span := c.tracer.StartCacheSpan(ctx, "delete", strings.Join(keys, ","))
|
||||
defer span.End()
|
||||
|
||||
// 添加基础属性
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.String("redis.operation", "delete"),
|
||||
attribute.Int("redis.key_count", len(keys)),
|
||||
)
|
||||
|
||||
// 记录开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 构建完整键名
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = c.buildKey(key)
|
||||
}
|
||||
|
||||
// 执行Redis操作
|
||||
deleted, err := c.client.Del(ctx, fullKeys...).Result()
|
||||
|
||||
// 计算执行时间
|
||||
duration := time.Since(startTime)
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
|
||||
attribute.Int64("redis.deleted_count", deleted),
|
||||
)
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
c.tracer.SetSpanError(span, err)
|
||||
c.logger.Error("Redis DELETE操作失败",
|
||||
zap.Strings("keys", c.sanitizeKeys(keys)),
|
||||
zap.Error(err),
|
||||
zap.String("trace_id", c.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
c.tracer.SetSpanSuccess(span)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (c *TracedRedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
// 开始追踪
|
||||
ctx, span := c.tracer.StartCacheSpan(ctx, "exists", key)
|
||||
defer span.End()
|
||||
|
||||
// 添加基础属性
|
||||
c.addBaseAttributes(span, "exists", key)
|
||||
|
||||
// 记录开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 构建完整键名
|
||||
fullKey := c.buildKey(key)
|
||||
|
||||
// 执行Redis操作
|
||||
count, err := c.client.Exists(ctx, fullKey).Result()
|
||||
|
||||
// 计算执行时间
|
||||
duration := time.Since(startTime)
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("redis.duration_ms", duration.Milliseconds()),
|
||||
attribute.Bool("redis.exists", count > 0),
|
||||
)
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
c.tracer.SetSpanError(span, err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
c.tracer.SetSpanSuccess(span)
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取(基础实现)
|
||||
func (c *TracedRedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
|
||||
result := make(map[string]interface{})
|
||||
|
||||
// 简单实现:逐个获取(实际应用中可以使用MGET优化)
|
||||
for _, key := range keys {
|
||||
var value interface{}
|
||||
if err := c.Get(ctx, key, &value); err == nil {
|
||||
result[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置(基础实现)
|
||||
func (c *TracedRedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
|
||||
// 简单实现:逐个设置(实际应用中可以使用pipeline优化)
|
||||
for key, value := range data {
|
||||
if err := c.Set(ctx, key, value, ttl...); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeletePattern 按模式删除(基础实现)
|
||||
func (c *TracedRedisCache) DeletePattern(ctx context.Context, pattern string) error {
|
||||
// 这里需要实现模式删除逻辑
|
||||
return fmt.Errorf("DeletePattern not implemented")
|
||||
}
|
||||
|
||||
// Keys 获取匹配的键(基础实现)
|
||||
func (c *TracedRedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
// 这里需要实现键匹配逻辑
|
||||
return nil, fmt.Errorf("Keys not implemented")
|
||||
}
|
||||
|
||||
// Stats 获取缓存统计(基础实现)
|
||||
func (c *TracedRedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
|
||||
return interfaces.CacheStats{}, fmt.Errorf("Stats not implemented")
|
||||
}
|
||||
|
||||
// 辅助方法
|
||||
|
||||
// addBaseAttributes 添加基础属性
|
||||
func (c *TracedRedisCache) addBaseAttributes(span trace.Span, operation, key string) {
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.String("redis.operation", operation),
|
||||
attribute.String("db.system", "redis"),
|
||||
)
|
||||
|
||||
if c.config.IncludeKeys {
|
||||
sanitizedKey := c.sanitizeKey(key)
|
||||
if len(sanitizedKey) <= c.config.MaxKeyLength {
|
||||
c.tracer.AddSpanAttributes(span,
|
||||
attribute.String("redis.key", sanitizedKey),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildKey 构建完整的Redis键名
|
||||
func (c *TracedRedisCache) buildKey(key string) string {
|
||||
if c.prefix == "" {
|
||||
return key
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", c.prefix, key)
|
||||
}
|
||||
|
||||
// sanitizeKey 清理键名用于日志记录
|
||||
func (c *TracedRedisCache) sanitizeKey(key string) string {
|
||||
if len(key) <= c.config.MaxKeyLength {
|
||||
return key
|
||||
}
|
||||
return key[:c.config.MaxKeyLength] + "..."
|
||||
}
|
||||
|
||||
// sanitizeKeys 批量清理键名
|
||||
func (c *TracedRedisCache) sanitizeKeys(keys []string) []string {
|
||||
result := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
result[i] = c.sanitizeKey(key)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// serialize 序列化值(简单实现)
|
||||
func (c *TracedRedisCache) serialize(value interface{}) (string, error) {
|
||||
// 这里应该使用JSON或其他序列化方法
|
||||
return fmt.Sprintf("%v", value), nil
|
||||
}
|
||||
|
||||
// deserialize 反序列化值(简单实现)
|
||||
func (c *TracedRedisCache) deserialize(data string, dest interface{}) error {
|
||||
// 这里应该实现真正的反序列化逻辑
|
||||
return fmt.Errorf("deserialize not fully implemented")
|
||||
}
|
||||
189
internal/shared/tracing/service_wrapper.go
Normal file
189
internal/shared/tracing/service_wrapper.go
Normal file
@@ -0,0 +1,189 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/domains/user/dto"
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// ServiceWrapper 服务包装器,提供自动追踪能力
|
||||
type ServiceWrapper struct {
|
||||
tracer *Tracer
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewServiceWrapper 创建服务包装器
|
||||
func NewServiceWrapper(tracer *Tracer, logger *zap.Logger) *ServiceWrapper {
|
||||
return &ServiceWrapper{
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// TraceServiceCall 追踪服务调用的通用方法
|
||||
func (w *ServiceWrapper) TraceServiceCall(
|
||||
ctx context.Context,
|
||||
serviceName, methodName string,
|
||||
fn func(context.Context) error,
|
||||
) error {
|
||||
// 创建span名称
|
||||
spanName := fmt.Sprintf("%s.%s", serviceName, methodName)
|
||||
|
||||
// 开始追踪
|
||||
ctx, span := w.tracer.StartSpan(ctx, spanName)
|
||||
defer span.End()
|
||||
|
||||
// 添加基础属性
|
||||
w.tracer.AddSpanAttributes(span,
|
||||
attribute.String("service.name", serviceName),
|
||||
attribute.String("service.method", methodName),
|
||||
attribute.String("service.type", "business"),
|
||||
)
|
||||
|
||||
// 记录开始时间
|
||||
startTime := time.Now()
|
||||
|
||||
// 执行原始方法
|
||||
err := fn(ctx)
|
||||
|
||||
// 计算执行时间
|
||||
duration := time.Since(startTime)
|
||||
w.tracer.AddSpanAttributes(span,
|
||||
attribute.Int64("service.duration_ms", duration.Milliseconds()),
|
||||
)
|
||||
|
||||
// 标记慢方法
|
||||
if duration > 100*time.Millisecond {
|
||||
w.tracer.AddSpanAttributes(span,
|
||||
attribute.Bool("service.slow_method", true),
|
||||
)
|
||||
w.logger.Warn("慢方法检测",
|
||||
zap.String("service", serviceName),
|
||||
zap.String("method", methodName),
|
||||
zap.Duration("duration", duration),
|
||||
zap.String("trace_id", w.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
}
|
||||
|
||||
// 处理错误
|
||||
if err != nil {
|
||||
w.tracer.SetSpanError(span, err)
|
||||
w.logger.Error("服务方法执行失败",
|
||||
zap.String("service", serviceName),
|
||||
zap.String("method", methodName),
|
||||
zap.Error(err),
|
||||
zap.String("trace_id", w.tracer.GetTraceID(ctx)),
|
||||
)
|
||||
} else {
|
||||
w.tracer.SetSpanSuccess(span)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// TracedUserService 自动追踪的用户服务包装器
|
||||
type TracedUserService struct {
|
||||
service interfaces.UserService
|
||||
wrapper *ServiceWrapper
|
||||
}
|
||||
|
||||
// NewTracedUserService 创建带追踪的用户服务
|
||||
func NewTracedUserService(service interfaces.UserService, wrapper *ServiceWrapper) interfaces.UserService {
|
||||
return &TracedUserService{
|
||||
service: service,
|
||||
wrapper: wrapper,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *TracedUserService) Name() string {
|
||||
return "user-service"
|
||||
}
|
||||
|
||||
func (t *TracedUserService) Initialize(ctx context.Context) error {
|
||||
return t.wrapper.TraceServiceCall(ctx, "user", "initialize", t.service.Initialize)
|
||||
}
|
||||
|
||||
func (t *TracedUserService) HealthCheck(ctx context.Context) error {
|
||||
return t.service.HealthCheck(ctx) // 不追踪健康检查
|
||||
}
|
||||
|
||||
func (t *TracedUserService) Shutdown(ctx context.Context) error {
|
||||
return t.wrapper.TraceServiceCall(ctx, "user", "shutdown", t.service.Shutdown)
|
||||
}
|
||||
|
||||
func (t *TracedUserService) Register(ctx context.Context, req *dto.RegisterRequest) (*entities.User, error) {
|
||||
var result *entities.User
|
||||
var err error
|
||||
|
||||
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "register", func(ctx context.Context) error {
|
||||
result, err = t.service.Register(ctx, req)
|
||||
return err
|
||||
})
|
||||
|
||||
if traceErr != nil {
|
||||
return nil, traceErr
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (t *TracedUserService) LoginWithPassword(ctx context.Context, req *dto.LoginWithPasswordRequest) (*entities.User, error) {
|
||||
var result *entities.User
|
||||
var err error
|
||||
|
||||
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "login_password", func(ctx context.Context) error {
|
||||
result, err = t.service.LoginWithPassword(ctx, req)
|
||||
return err
|
||||
})
|
||||
|
||||
if traceErr != nil {
|
||||
return nil, traceErr
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (t *TracedUserService) LoginWithSMS(ctx context.Context, req *dto.LoginWithSMSRequest) (*entities.User, error) {
|
||||
var result *entities.User
|
||||
var err error
|
||||
|
||||
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "login_sms", func(ctx context.Context) error {
|
||||
result, err = t.service.LoginWithSMS(ctx, req)
|
||||
return err
|
||||
})
|
||||
|
||||
if traceErr != nil {
|
||||
return nil, traceErr
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (t *TracedUserService) ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error {
|
||||
return t.wrapper.TraceServiceCall(ctx, "user", "change_password", func(ctx context.Context) error {
|
||||
return t.service.ChangePassword(ctx, userID, req)
|
||||
})
|
||||
}
|
||||
|
||||
func (t *TracedUserService) GetByID(ctx context.Context, id string) (*entities.User, error) {
|
||||
var result *entities.User
|
||||
var err error
|
||||
|
||||
traceErr := t.wrapper.TraceServiceCall(ctx, "user", "get_by_id", func(ctx context.Context) error {
|
||||
result, err = t.service.GetByID(ctx, id)
|
||||
return err
|
||||
})
|
||||
|
||||
if traceErr != nil {
|
||||
return nil, traceErr
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
474
internal/shared/tracing/tracer.go
Normal file
474
internal/shared/tracing/tracer.go
Normal file
@@ -0,0 +1,474 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/codes"
|
||||
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
|
||||
"go.opentelemetry.io/otel/sdk/resource"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TracerConfig 追踪器配置
|
||||
type TracerConfig struct {
|
||||
ServiceName string
|
||||
ServiceVersion string
|
||||
Environment string
|
||||
Endpoint string
|
||||
SampleRate float64
|
||||
Enabled bool
|
||||
}
|
||||
|
||||
// DefaultTracerConfig 默认追踪器配置
|
||||
func DefaultTracerConfig() TracerConfig {
|
||||
return TracerConfig{
|
||||
ServiceName: "tyapi-server",
|
||||
ServiceVersion: "1.0.0",
|
||||
Environment: "development",
|
||||
Endpoint: "http://localhost:4317",
|
||||
SampleRate: 0.1,
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
// Tracer 链路追踪器
|
||||
type Tracer struct {
|
||||
config TracerConfig
|
||||
logger *zap.Logger
|
||||
provider *sdktrace.TracerProvider
|
||||
tracer trace.Tracer
|
||||
mutex sync.RWMutex
|
||||
initialized bool
|
||||
shutdown func(context.Context) error
|
||||
}
|
||||
|
||||
// NewTracer 创建链路追踪器
|
||||
func NewTracer(config TracerConfig, logger *zap.Logger) *Tracer {
|
||||
return &Tracer{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize 初始化追踪器
|
||||
func (t *Tracer) Initialize(ctx context.Context) error {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
if t.initialized {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !t.config.Enabled {
|
||||
t.logger.Info("Tracing is disabled")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建资源
|
||||
res, err := resource.New(ctx,
|
||||
resource.WithAttributes(
|
||||
attribute.String("service.name", t.config.ServiceName),
|
||||
attribute.String("service.version", t.config.ServiceVersion),
|
||||
attribute.String("environment", t.config.Environment),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create resource: %w", err)
|
||||
}
|
||||
|
||||
// 创建采样器
|
||||
sampler := sdktrace.TraceIDRatioBased(t.config.SampleRate)
|
||||
|
||||
// 创建导出器
|
||||
var spanProcessor sdktrace.SpanProcessor
|
||||
if t.config.Endpoint != "" {
|
||||
// 使用OTLP gRPC导出器(支持Jaeger、Tempo等)
|
||||
exporter, err := otlptracegrpc.New(ctx,
|
||||
otlptracegrpc.WithEndpoint(t.config.Endpoint),
|
||||
otlptracegrpc.WithInsecure(), // 开发环境使用,生产环境应配置TLS
|
||||
otlptracegrpc.WithTimeout(time.Second*10),
|
||||
otlptracegrpc.WithRetry(otlptracegrpc.RetryConfig{
|
||||
Enabled: true,
|
||||
InitialInterval: time.Millisecond * 100,
|
||||
MaxInterval: time.Second * 5,
|
||||
MaxElapsedTime: time.Second * 30,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
t.logger.Warn("Failed to create OTLP exporter, using noop exporter",
|
||||
zap.Error(err),
|
||||
zap.String("endpoint", t.config.Endpoint))
|
||||
spanProcessor = sdktrace.NewSimpleSpanProcessor(&noopExporter{})
|
||||
} else {
|
||||
// 在生产环境中使用批处理器以提高性能
|
||||
spanProcessor = sdktrace.NewBatchSpanProcessor(exporter,
|
||||
sdktrace.WithBatchTimeout(time.Second*5),
|
||||
sdktrace.WithMaxExportBatchSize(512),
|
||||
sdktrace.WithMaxQueueSize(2048),
|
||||
sdktrace.WithExportTimeout(time.Second*30),
|
||||
)
|
||||
t.logger.Info("OTLP exporter initialized successfully",
|
||||
zap.String("endpoint", t.config.Endpoint))
|
||||
}
|
||||
} else {
|
||||
// 如果没有配置端点,使用空导出器
|
||||
spanProcessor = sdktrace.NewSimpleSpanProcessor(&noopExporter{})
|
||||
t.logger.Info("Using noop exporter (no endpoint configured)")
|
||||
}
|
||||
|
||||
// 创建TracerProvider
|
||||
provider := sdktrace.NewTracerProvider(
|
||||
sdktrace.WithResource(res),
|
||||
sdktrace.WithSampler(sampler),
|
||||
sdktrace.WithSpanProcessor(spanProcessor),
|
||||
)
|
||||
|
||||
// 设置全局TracerProvider
|
||||
otel.SetTracerProvider(provider)
|
||||
|
||||
// 创建Tracer
|
||||
tracer := provider.Tracer(t.config.ServiceName)
|
||||
|
||||
t.provider = provider
|
||||
t.tracer = tracer
|
||||
t.shutdown = func(ctx context.Context) error {
|
||||
return provider.Shutdown(ctx)
|
||||
}
|
||||
t.initialized = true
|
||||
|
||||
t.logger.Info("Tracing initialized successfully",
|
||||
zap.String("service", t.config.ServiceName),
|
||||
zap.Float64("sample_rate", t.config.SampleRate))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartSpan 开始一个新的span
|
||||
func (t *Tracer) StartSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
if !t.initialized || !t.config.Enabled {
|
||||
return ctx, trace.SpanFromContext(ctx)
|
||||
}
|
||||
|
||||
return t.tracer.Start(ctx, name, opts...)
|
||||
}
|
||||
|
||||
// StartHTTPSpan 开始一个HTTP span
|
||||
func (t *Tracer) StartHTTPSpan(ctx context.Context, method, path string) (context.Context, trace.Span) {
|
||||
spanName := fmt.Sprintf("%s %s", method, path)
|
||||
|
||||
// 检查是否已有错误标记,如果有则使用"error"作为操作名
|
||||
// 这样可以匹配Jaeger采样配置中的错误操作策略
|
||||
if ctx.Value("otel_error_request") != nil {
|
||||
spanName = "error"
|
||||
}
|
||||
|
||||
ctx, span := t.StartSpan(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindServer),
|
||||
trace.WithAttributes(
|
||||
attribute.String("http.method", method),
|
||||
attribute.String("http.route", path),
|
||||
),
|
||||
)
|
||||
|
||||
// 保存原始操作名,以便在错误发生时可以更新
|
||||
if ctx.Value("otel_error_request") == nil {
|
||||
ctx = context.WithValue(ctx, "otel_original_operation", spanName)
|
||||
}
|
||||
|
||||
return ctx, span
|
||||
}
|
||||
|
||||
// StartDBSpan 开始一个数据库span
|
||||
func (t *Tracer) StartDBSpan(ctx context.Context, operation, table string) (context.Context, trace.Span) {
|
||||
spanName := fmt.Sprintf("db.%s.%s", operation, table)
|
||||
|
||||
return t.StartSpan(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindClient),
|
||||
trace.WithAttributes(
|
||||
attribute.String("db.operation", operation),
|
||||
attribute.String("db.table", table),
|
||||
attribute.String("db.system", "postgresql"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// StartCacheSpan 开始一个缓存span
|
||||
func (t *Tracer) StartCacheSpan(ctx context.Context, operation, key string) (context.Context, trace.Span) {
|
||||
spanName := fmt.Sprintf("cache.%s", operation)
|
||||
|
||||
return t.StartSpan(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindClient),
|
||||
trace.WithAttributes(
|
||||
attribute.String("cache.operation", operation),
|
||||
attribute.String("cache.system", "redis"),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// StartExternalAPISpan 开始一个外部API调用span
|
||||
func (t *Tracer) StartExternalAPISpan(ctx context.Context, service, operation string) (context.Context, trace.Span) {
|
||||
spanName := fmt.Sprintf("api.%s.%s", service, operation)
|
||||
|
||||
return t.StartSpan(ctx, spanName,
|
||||
trace.WithSpanKind(trace.SpanKindClient),
|
||||
trace.WithAttributes(
|
||||
attribute.String("api.service", service),
|
||||
attribute.String("api.operation", operation),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
// AddSpanAttributes 添加span属性
|
||||
func (t *Tracer) AddSpanAttributes(span trace.Span, attrs ...attribute.KeyValue) {
|
||||
if span.IsRecording() {
|
||||
span.SetAttributes(attrs...)
|
||||
}
|
||||
}
|
||||
|
||||
// SetSpanError 设置span错误
|
||||
func (t *Tracer) SetSpanError(span trace.Span, err error) {
|
||||
if span.IsRecording() {
|
||||
span.SetStatus(codes.Error, err.Error())
|
||||
span.RecordError(err)
|
||||
|
||||
// 将span操作名更新为"error",以匹配Jaeger采样配置
|
||||
// 注意:这是一种变通方法,因为OpenTelemetry不支持直接更改span名称
|
||||
// 我们通过添加特殊属性来标识这是一个错误span
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 记录错误日志,包含trace ID便于关联
|
||||
if t.logger != nil {
|
||||
ctx := trace.ContextWithSpan(context.Background(), span)
|
||||
t.logger.Error("操作发生错误",
|
||||
zap.Error(err),
|
||||
zap.String("trace_id", t.GetTraceID(ctx)),
|
||||
zap.String("span_id", t.GetSpanID(ctx)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetSpanSuccess 设置span成功
|
||||
func (t *Tracer) SetSpanSuccess(span trace.Span) {
|
||||
if span.IsRecording() {
|
||||
span.SetStatus(codes.Ok, "success")
|
||||
}
|
||||
}
|
||||
|
||||
// SetHTTPStatus 根据HTTP状态码设置span状态
|
||||
func (t *Tracer) SetHTTPStatus(span trace.Span, statusCode int) {
|
||||
if !span.IsRecording() {
|
||||
return
|
||||
}
|
||||
|
||||
// 添加HTTP状态码属性
|
||||
span.SetAttributes(attribute.Int("http.status_code", statusCode))
|
||||
|
||||
// 对于4xx和5xx错误,标记为错误并应用错误采样策略
|
||||
if statusCode >= 400 {
|
||||
errorMsg := fmt.Sprintf("HTTP %d", statusCode)
|
||||
span.SetStatus(codes.Error, errorMsg)
|
||||
|
||||
// 添加错误操作标记,以匹配Jaeger采样配置
|
||||
span.SetAttributes(
|
||||
attribute.String("error.operation", "true"),
|
||||
attribute.String("operation.type", "error"),
|
||||
)
|
||||
|
||||
// 记录HTTP错误
|
||||
if t.logger != nil {
|
||||
ctx := trace.ContextWithSpan(context.Background(), span)
|
||||
t.logger.Warn("HTTP请求错误",
|
||||
zap.Int("status_code", statusCode),
|
||||
zap.String("trace_id", t.GetTraceID(ctx)),
|
||||
zap.String("span_id", t.GetSpanID(ctx)),
|
||||
)
|
||||
}
|
||||
} else {
|
||||
span.SetStatus(codes.Ok, "success")
|
||||
}
|
||||
}
|
||||
|
||||
// GetTraceID 获取当前上下文的trace ID
|
||||
func (t *Tracer) GetTraceID(ctx context.Context) string {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span.SpanContext().IsValid() {
|
||||
return span.SpanContext().TraceID().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetSpanID 获取当前上下文的span ID
|
||||
func (t *Tracer) GetSpanID(ctx context.Context) string {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
if span.SpanContext().IsValid() {
|
||||
return span.SpanContext().SpanID().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsTracing 检查是否正在追踪
|
||||
func (t *Tracer) IsTracing(ctx context.Context) bool {
|
||||
span := trace.SpanFromContext(ctx)
|
||||
return span.SpanContext().IsValid() && span.IsRecording()
|
||||
}
|
||||
|
||||
// Shutdown 关闭追踪器
|
||||
func (t *Tracer) Shutdown(ctx context.Context) error {
|
||||
t.mutex.Lock()
|
||||
defer t.mutex.Unlock()
|
||||
|
||||
if !t.initialized || t.shutdown == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := t.shutdown(ctx)
|
||||
if err != nil {
|
||||
t.logger.Error("Failed to shutdown tracer", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
t.initialized = false
|
||||
t.logger.Info("Tracer shutdown successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats 获取追踪统计信息
|
||||
func (t *Tracer) GetStats() map[string]interface{} {
|
||||
t.mutex.RLock()
|
||||
defer t.mutex.RUnlock()
|
||||
|
||||
return map[string]interface{}{
|
||||
"initialized": t.initialized,
|
||||
"enabled": t.config.Enabled,
|
||||
"service_name": t.config.ServiceName,
|
||||
"service_version": t.config.ServiceVersion,
|
||||
"environment": t.config.Environment,
|
||||
"sample_rate": t.config.SampleRate,
|
||||
"endpoint": t.config.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// 实现Service接口
|
||||
|
||||
// Name 返回服务名称
|
||||
func (t *Tracer) Name() string {
|
||||
return "tracer"
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (t *Tracer) HealthCheck(ctx context.Context) error {
|
||||
if !t.config.Enabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !t.initialized {
|
||||
return fmt.Errorf("tracer not initialized")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// noopExporter 简单的无操作导出器(用于演示)
|
||||
type noopExporter struct{}
|
||||
|
||||
func (e *noopExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error {
|
||||
// 在实际应用中,这里应该将spans发送到Jaeger或其他追踪系统
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *noopExporter) Shutdown(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TraceMiddleware 追踪中间件工厂
|
||||
func (t *Tracer) TraceMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !t.initialized || !t.config.Enabled {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 开始HTTP span
|
||||
ctx, span := t.StartHTTPSpan(c.Request.Context(), c.Request.Method, c.FullPath())
|
||||
defer span.End()
|
||||
|
||||
// 将trace ID添加到响应头
|
||||
traceID := t.GetTraceID(ctx)
|
||||
if traceID != "" {
|
||||
c.Header("X-Trace-ID", traceID)
|
||||
}
|
||||
|
||||
// 将span上下文存储到gin上下文
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
|
||||
// 处理请求
|
||||
c.Next()
|
||||
|
||||
// 设置HTTP状态码
|
||||
t.SetHTTPStatus(span, c.Writer.Status())
|
||||
|
||||
// 添加响应信息
|
||||
t.AddSpanAttributes(span,
|
||||
attribute.Int("http.status_code", c.Writer.Status()),
|
||||
attribute.Int("http.response_size", c.Writer.Size()),
|
||||
)
|
||||
|
||||
// 添加错误信息
|
||||
if len(c.Errors) > 0 {
|
||||
errMsg := c.Errors.String()
|
||||
t.SetSpanError(span, fmt.Errorf(errMsg))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GinTraceMiddleware 兼容旧的方法名,保持向后兼容
|
||||
func (t *Tracer) GinTraceMiddleware() gin.HandlerFunc {
|
||||
return t.TraceMiddleware()
|
||||
}
|
||||
|
||||
// WithTracing 添加追踪到上下文的辅助函数
|
||||
func WithTracing(ctx context.Context, tracer *Tracer, name string) (context.Context, trace.Span) {
|
||||
return tracer.StartSpan(ctx, name)
|
||||
}
|
||||
|
||||
// TraceFunction 追踪函数执行的辅助函数
|
||||
func (t *Tracer) TraceFunction(ctx context.Context, name string, fn func(context.Context) error) error {
|
||||
ctx, span := t.StartSpan(ctx, name)
|
||||
defer span.End()
|
||||
|
||||
err := fn(ctx)
|
||||
if err != nil {
|
||||
t.SetSpanError(span, err)
|
||||
} else {
|
||||
t.SetSpanSuccess(span)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// TraceFunctionWithResult 追踪带返回值的函数执行
|
||||
func TraceFunctionWithResult[T any](ctx context.Context, tracer *Tracer, name string, fn func(context.Context) (T, error)) (T, error) {
|
||||
ctx, span := tracer.StartSpan(ctx, name)
|
||||
defer span.End()
|
||||
|
||||
result, err := fn(ctx)
|
||||
if err != nil {
|
||||
tracer.SetSpanError(span, err)
|
||||
} else {
|
||||
tracer.SetSpanSuccess(span)
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
Reference in New Issue
Block a user