feat(架构): 完善基础架构设计

This commit is contained in:
2025-07-02 16:17:59 +08:00
parent 03e615a8fd
commit 5b4392894f
89 changed files with 18555 additions and 3521 deletions

View File

@@ -36,7 +36,7 @@ func (h *HealthChecker) RegisterService(service interfaces.Service) {
defer h.mutex.Unlock()
h.services[service.Name()] = service
h.logger.Info("Registered service for health check", zap.String("service", service.Name()))
h.logger.Info("服务已注册健康检查", zap.String("service", service.Name()))
}
// CheckHealth 检查单个服务健康状态
@@ -47,8 +47,8 @@ func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *in
h.mutex.RUnlock()
return &interfaces.HealthStatus{
Status: "DOWN",
Message: "Service not found",
Details: map[string]interface{}{"error": "service not registered"},
Message: "服务未找到",
Details: map[string]interface{}{"error": "服务未注册"},
CheckedAt: time.Now().Unix(),
ResponseTime: 0,
}
@@ -79,24 +79,24 @@ func (h *HealthChecker) CheckHealth(ctx context.Context, serviceName string) *in
if err != nil {
status.Status = "DOWN"
status.Message = "Health check failed"
status.Message = "健康检查失败"
status.Details = map[string]interface{}{
"error": err.Error(),
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Warn("Service health check failed",
h.logger.Warn("服务健康检查失败",
zap.String("service", serviceName),
zap.Error(err),
zap.Int64("response_time_ms", responseTime))
} else {
status.Status = "UP"
status.Message = "Service is healthy"
status.Message = "服务运行正常"
status.Details = map[string]interface{}{
"service_name": serviceName,
"check_time": start.Format(time.RFC3339),
}
h.logger.Debug("Service health check passed",
h.logger.Debug("服务健康检查通过",
zap.String("service", serviceName),
zap.Int64("response_time_ms", responseTime))
}
@@ -173,13 +173,13 @@ func (h *HealthChecker) GetOverallStatus(ctx context.Context) *interfaces.Health
// 确定整体状态
if healthyCount == totalCount {
overall.Status = "UP"
overall.Message = "All services are healthy"
overall.Message = "所有服务运行正常"
} else if healthyCount == 0 {
overall.Status = "DOWN"
overall.Message = "All services are down"
overall.Message = "所有服务均已下线"
} else {
overall.Status = "DEGRADED"
overall.Message = fmt.Sprintf("%d of %d services are healthy", healthyCount, totalCount)
overall.Message = fmt.Sprintf("%d/%d 个服务运行正常", healthyCount, totalCount)
}
return overall
@@ -205,7 +205,7 @@ func (h *HealthChecker) RemoveService(serviceName string) {
delete(h.services, serviceName)
delete(h.cache, serviceName)
h.logger.Info("Removed service from health check", zap.String("service", serviceName))
h.logger.Info("服务已从健康检查中移除", zap.String("service", serviceName))
}
// ClearCache 清除缓存
@@ -214,7 +214,7 @@ func (h *HealthChecker) ClearCache() {
defer h.mutex.Unlock()
h.cache = make(map[string]*interfaces.HealthStatus)
h.logger.Debug("Health check cache cleared")
h.logger.Debug("健康检查缓存已清除")
}
// GetCacheStats 获取缓存统计
@@ -243,7 +243,7 @@ func (h *HealthChecker) SetCacheTTL(ttl time.Duration) {
defer h.mutex.Unlock()
h.cacheTTL = ttl
h.logger.Info("Updated health check cache TTL", zap.Duration("ttl", ttl))
h.logger.Info("健康检查缓存TTL已更新", zap.Duration("ttl", ttl))
}
// StartPeriodicCheck 启动定期健康检查
@@ -251,12 +251,12 @@ func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Du
ticker := time.NewTicker(interval)
defer ticker.Stop()
h.logger.Info("Started periodic health check", zap.Duration("interval", interval))
h.logger.Info("已启动定期健康检查", zap.Duration("interval", interval))
for {
select {
case <-ctx.Done():
h.logger.Info("Stopped periodic health check")
h.logger.Info("已停止定期健康检查")
return
case <-ticker.C:
h.performPeriodicCheck(ctx)
@@ -268,14 +268,14 @@ func (h *HealthChecker) StartPeriodicCheck(ctx context.Context, interval time.Du
func (h *HealthChecker) performPeriodicCheck(ctx context.Context) {
overall := h.GetOverallStatus(ctx)
h.logger.Info("Periodic health check completed",
h.logger.Info("定期健康检查已完成",
zap.String("overall_status", overall.Status),
zap.String("message", overall.Message),
zap.Int64("response_time_ms", overall.ResponseTime))
// 如果有服务下线,记录警告
if overall.Status != "UP" {
h.logger.Warn("Some services are not healthy",
h.logger.Warn("部分服务不健康",
zap.String("status", overall.Status),
zap.Any("details", overall.Details))
}

View File

@@ -0,0 +1,587 @@
package hooks
import (
"context"
"fmt"
"reflect"
"sort"
"sync"
"time"
"go.uber.org/zap"
)
// HookPriority is the execution priority of a hook. When hooks are registered
// for an event they are sorted so that larger values come first.
type HookPriority int

const (
	// PriorityLowest is the lowest priority.
	PriorityLowest HookPriority = 0
	// PriorityLow is a low priority.
	PriorityLow HookPriority = 25
	// PriorityNormal is the normal priority.
	PriorityNormal HookPriority = 50
	// PriorityHigh is a high priority.
	PriorityHigh HookPriority = 75
	// PriorityHighest is the highest priority.
	PriorityHighest HookPriority = 100
)
// HookFunc is the signature of a hook callback. It receives the context and
// the event payload and returns an error when the hook fails.
type HookFunc func(ctx context.Context, data interface{}) error
// Hook describes a single registered hook.
type Hook struct {
	// Name identifies the hook; it must be non-empty and unique per event.
	Name string
	// Func is the callback invoked when the event is triggered.
	Func HookFunc
	// Priority orders hooks within an event (higher values first).
	Priority HookPriority
	// Async makes the hook run in its own goroutine without being awaited.
	Async bool
	// Timeout bounds one execution; a zero value is replaced by the system
	// default when the hook is registered.
	Timeout time.Duration
}
// HookResult is the outcome of executing one hook.
type HookResult struct {
	// HookName is the name of the executed hook.
	HookName string `json:"hook_name"`
	// Success reports whether the hook completed without error.
	Success bool `json:"success"`
	// Duration is the time spent executing the hook.
	Duration time.Duration `json:"duration"`
	// Error holds the error text when the hook failed.
	Error string `json:"error,omitempty"`
}
// HookConfig configures a HookSystem.
type HookConfig struct {
	// DefaultTimeout is applied to hooks registered with a zero Timeout.
	DefaultTimeout time.Duration
	// TrackDuration controls whether execution durations are accumulated in
	// the per-hook statistics.
	TrackDuration bool
	// ErrorStrategy decides how Trigger reacts to failing hooks.
	ErrorStrategy ErrorStrategy
}
// ErrorStrategy selects how hook failures are handled while triggering an
// event.
type ErrorStrategy int

const (
	// ContinueOnError keeps executing the remaining hooks after a failure.
	ContinueOnError ErrorStrategy = iota
	// StopOnError stops executing hooks at the first failure.
	StopOnError
	// CollectErrors executes all hooks and reports every failure together.
	CollectErrors
)
// DefaultHookConfig returns the default configuration: a 30-second hook
// timeout, duration tracking enabled and the ContinueOnError strategy.
func DefaultHookConfig() HookConfig {
	var cfg HookConfig
	cfg.DefaultTimeout = 30 * time.Second
	cfg.TrackDuration = true
	cfg.ErrorStrategy = ContinueOnError
	return cfg
}
// HookSystem stores hooks per event name, executes them when an event is
// triggered and keeps per-hook execution statistics.
type HookSystem struct {
	// hooks maps an event name to its hooks, kept sorted by priority.
	hooks map[string][]*Hook
	// config holds the system-wide settings.
	config HookConfig
	// logger receives registration and execution log entries.
	logger *zap.Logger
	// mutex guards hooks and stats.
	mutex sync.RWMutex
	// stats maps "<event>.<hook name>" to the statistics of that hook.
	stats map[string]*HookStats
}
// HookStats aggregates execution statistics for one hook.
type HookStats struct {
	// TotalExecutions counts every recorded execution.
	TotalExecutions int `json:"total_executions"`
	// Successes counts executions recorded as successful.
	Successes int `json:"successes"`
	// Failures counts executions recorded as failed.
	Failures int `json:"failures"`
	// TotalDuration is the accumulated execution time (only updated when
	// duration tracking is enabled).
	TotalDuration time.Duration `json:"total_duration"`
	// AverageDuration is TotalDuration divided by TotalExecutions.
	AverageDuration time.Duration `json:"average_duration"`
	// LastExecution is the time the most recent execution was recorded.
	LastExecution time.Time `json:"last_execution"`
	// LastError is the error text of the most recent failure; it is not
	// cleared by a later success.
	LastError string `json:"last_error,omitempty"`
}
// NewHookSystem creates an empty hook system using the given configuration.
//
// A nil logger is replaced by a no-op logger so that the system can be used
// without logging instead of panicking on the first log call.
func NewHookSystem(config HookConfig, logger *zap.Logger) *HookSystem {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &HookSystem{
		hooks:  make(map[string][]*Hook),
		config: config,
		logger: logger,
		stats:  make(map[string]*HookStats),
	}
}
// Register adds hook to the list of hooks for event.
//
// It returns an error when hook is nil, has an empty name or a nil function,
// or when a hook with the same name is already registered for the event. A
// zero hook.Timeout is replaced (on the passed hook itself) by the configured
// default timeout. Hooks are kept ordered by descending priority; hooks with
// equal priority keep their registration order.
func (hs *HookSystem) Register(event string, hook *Hook) error {
	if hook == nil {
		return fmt.Errorf("hook cannot be nil")
	}
	if hook.Name == "" {
		return fmt.Errorf("hook name cannot be empty")
	}
	if hook.Func == nil {
		return fmt.Errorf("hook function cannot be nil")
	}
	if hook.Timeout == 0 {
		hook.Timeout = hs.config.DefaultTimeout
	}

	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	// Reject a second hook with the same name for this event.
	for _, existingHook := range hs.hooks[event] {
		if existingHook.Name == hook.Name {
			return fmt.Errorf("hook %s already registered for event %s", hook.Name, event)
		}
	}

	hooks := append(hs.hooks[event], hook)

	// A stable sort keeps registration order among equal priorities, which
	// sort.Slice does not guarantee.
	sort.SliceStable(hooks, func(i, j int) bool {
		return hooks[i].Priority > hooks[j].Priority
	})
	hs.hooks[event] = hooks

	// Start with empty statistics for the new hook.
	hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
	hs.stats[hookKey] = &HookStats{}

	hs.logger.Info("Registered hook",
		zap.String("event", event),
		zap.String("hook_name", hook.Name),
		zap.Int("priority", int(hook.Priority)),
		zap.Bool("async", hook.Async))

	return nil
}
// RegisterFunc is a shorthand for Register that builds a synchronous hook
// from a name, a priority and a function, using the configured default
// timeout.
func (hs *HookSystem) RegisterFunc(event, name string, priority HookPriority, fn HookFunc) error {
	return hs.Register(event, &Hook{
		Name:     name,
		Func:     fn,
		Priority: priority,
		Async:    false,
		Timeout:  hs.config.DefaultTimeout,
	})
}
// RegisterAsyncFunc is a shorthand for Register that builds an asynchronous
// hook from a name, a priority and a function, using the configured default
// timeout.
func (hs *HookSystem) RegisterAsyncFunc(event, name string, priority HookPriority, fn HookFunc) error {
	return hs.Register(event, &Hook{
		Name:     name,
		Func:     fn,
		Priority: priority,
		Async:    true,
		Timeout:  hs.config.DefaultTimeout,
	})
}
// Unregister removes the hook named hookName from event together with its
// statistics. It returns an error when no such hook is registered.
//
// When the last hook of an event is removed the event itself is dropped, so
// it no longer shows up in GetEvents (matching what ClearEvent does).
func (hs *HookSystem) Unregister(event, hookName string) error {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	hooks := hs.hooks[event]
	for i, hook := range hooks {
		if hook.Name != hookName {
			continue
		}

		// Shift the tail down and clear the vacated slot so the backing
		// array does not keep a stale pointer alive.
		last := len(hooks) - 1
		copy(hooks[i:], hooks[i+1:])
		hooks[last] = nil
		hooks = hooks[:last]

		if len(hooks) == 0 {
			delete(hs.hooks, event)
		} else {
			hs.hooks[event] = hooks
		}

		// Drop the statistics of the removed hook.
		hookKey := fmt.Sprintf("%s.%s", event, hookName)
		delete(hs.stats, hookKey)

		hs.logger.Info("Unregistered hook",
			zap.String("event", event),
			zap.String("hook_name", hookName))

		return nil
	}

	return fmt.Errorf("hook %s not found for event %s", hookName, event)
}
// Trigger executes the hooks registered for event in priority order and
// returns one HookResult per executed hook.
//
// It returns (nil, nil) when no hooks are registered. The returned error
// depends on the configured ErrorStrategy: StopOnError aborts at and returns
// the first failure, CollectErrors runs every hook and returns a combined
// error, and ContinueOnError only logs a warning and returns a nil error.
func (hs *HookSystem) Trigger(ctx context.Context, event string, data interface{}) ([]HookResult, error) {
	// Work on a snapshot so hooks run without holding the lock.
	hs.mutex.RLock()
	pending := make([]*Hook, len(hs.hooks[event]))
	copy(pending, hs.hooks[event])
	hs.mutex.RUnlock()

	if len(pending) == 0 {
		hs.logger.Debug("No hooks registered for event", zap.String("event", event))
		return nil, nil
	}

	hs.logger.Debug("Triggering event",
		zap.String("event", event),
		zap.Int("hook_count", len(pending)))

	results := make([]HookResult, 0, len(pending))
	var failures []error

	for _, hook := range pending {
		outcome := hs.executeHook(ctx, event, hook, data)
		results = append(results, outcome)
		if outcome.Success {
			continue
		}

		failures = append(failures, fmt.Errorf("hook %s failed: %s", hook.Name, outcome.Error))
		if hs.config.ErrorStrategy == StopOnError {
			break
		}
	}

	if len(failures) == 0 {
		return results, nil
	}

	switch hs.config.ErrorStrategy {
	case StopOnError:
		return results, failures[0]
	case CollectErrors:
		return results, fmt.Errorf("multiple hook errors: %v", failures)
	case ContinueOnError:
		// Failures are tolerated; only record that they happened.
		hs.logger.Warn("Some hooks failed but continuing execution",
			zap.String("event", event),
			zap.Int("error_count", len(failures)))
	}

	return results, nil
}
// executeHook runs one hook for event, records the outcome in the statistics
// and returns it.
//
// The result is a named return value on purpose: the deferred function fills
// in Duration after the return statement has run, and only a named result
// makes that visible to the caller (with an unnamed result the returned
// Duration was always zero).
//
// Asynchronous hooks are started in their own goroutine and reported as
// successful immediately; their Duration only covers the dispatch. A failure
// of an asynchronous hook is logged because it can no longer be returned.
func (hs *HookSystem) executeHook(ctx context.Context, event string, hook *Hook, data interface{}) (result HookResult) {
	hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
	start := time.Now()

	result = HookResult{
		HookName: hook.Name,
		Success:  false,
	}

	// Record duration and statistics whichever path returns.
	defer func() {
		result.Duration = time.Since(start)
		hs.updateStats(hookKey, result)
	}()

	if hook.Async {
		go func() {
			if err := hs.doExecuteHook(ctx, hook, data); err != nil {
				hs.logger.Error("Async hook execution failed",
					zap.String("event", event),
					zap.String("hook_name", hook.Name),
					zap.Error(err))
			}
		}()

		result.Success = true // dispatching an async hook always counts as success
		return result
	}

	// Synchronous execution.
	err := hs.doExecuteHook(ctx, hook, data)
	if err != nil {
		result.Error = err.Error()
		hs.logger.Error("Hook execution failed",
			zap.String("event", event),
			zap.String("hook_name", hook.Name),
			zap.Error(err))
		return result
	}

	result.Success = true
	hs.logger.Debug("Hook executed successfully",
		zap.String("event", event),
		zap.String("hook_name", hook.Name))

	return result
}
// doExecuteHook calls hook.Func and waits for it to finish, for the hook's
// timeout to elapse, or for ctx to be cancelled, whichever happens first.
//
// A Timeout of zero or less means "no timeout" (previously it produced an
// already-expired context, so the hook failed immediately). A panic inside
// the hook is recovered and returned as an error. Cancellation of the parent
// context is reported as a cancellation rather than as a timeout.
//
// The hook runs in its own goroutine; when this method returns early the
// goroutine is not stopped, it merely receives a cancelled context and ends
// when hook.Func returns.
func (hs *HookSystem) doExecuteHook(ctx context.Context, hook *Hook, data interface{}) error {
	var (
		hookCtx context.Context
		cancel  context.CancelFunc
	)
	if hook.Timeout > 0 {
		hookCtx, cancel = context.WithTimeout(ctx, hook.Timeout)
	} else {
		hookCtx, cancel = context.WithCancel(ctx)
	}
	defer cancel()

	// Buffered so the goroutine can always deliver its result and exit, even
	// when nobody is waiting any more.
	errChan := make(chan error, 1)

	go func() {
		defer func() {
			if r := recover(); r != nil {
				errChan <- fmt.Errorf("hook panicked: %v", r)
			}
		}()

		errChan <- hook.Func(hookCtx, data)
	}()

	select {
	case err := <-errChan:
		return err
	case <-hookCtx.Done():
		// Prefer a result that is already available over the context error.
		select {
		case err := <-errChan:
			return err
		default:
		}

		// Only blame the hook timeout when the parent context is still alive.
		if ctx.Err() == nil && hookCtx.Err() == context.DeadlineExceeded {
			return fmt.Errorf("hook execution timeout after %v", hook.Timeout)
		}
		return fmt.Errorf("hook execution canceled: %w", hookCtx.Err())
	}
}
// updateStats folds result into the statistics stored under hookKey,
// creating the entry when it does not exist yet.
func (hs *HookSystem) updateStats(hookKey string, result HookResult) {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	entry, ok := hs.stats[hookKey]
	if !ok {
		entry = &HookStats{}
		hs.stats[hookKey] = entry
	}

	entry.TotalExecutions++
	entry.LastExecution = time.Now()

	switch {
	case result.Success:
		entry.Successes++
	default:
		entry.Failures++
		entry.LastError = result.Error
	}

	// Durations are only accumulated when tracking is enabled.
	if !hs.config.TrackDuration {
		return
	}
	entry.TotalDuration += result.Duration
	entry.AverageDuration = entry.TotalDuration / time.Duration(entry.TotalExecutions)
}
// GetHooks returns a copy of the hook list registered for event. The slice
// is a snapshot; the *Hook values it contains are shared with the system.
func (hs *HookSystem) GetHooks(event string) []*Hook {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	registered := hs.hooks[event]
	snapshot := make([]*Hook, len(registered))
	copy(snapshot, registered)
	return snapshot
}
// GetEvents returns the names of all events present in the hook table,
// sorted alphabetically.
func (hs *HookSystem) GetEvents() []string {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	names := make([]string, 0, len(hs.hooks))
	for name := range hs.hooks {
		names = append(names, name)
	}
	sort.Strings(names)

	return names
}
// GetStats returns a copy of the statistics of every hook, keyed by
// "<event>.<hook name>". The returned values are independent copies.
func (hs *HookSystem) GetStats() map[string]*HookStats {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	snapshot := make(map[string]*HookStats)
	for key := range hs.stats {
		copied := *hs.stats[key]
		snapshot[key] = &copied
	}
	return snapshot
}
// GetEventStats returns a copy of the statistics of the hooks currently
// registered for event, keyed by hook name.
//
// Statistics are looked up by their exact "<event>.<hook name>" key. Event
// names themselves contain dots (for example "user.created"), so matching on
// the "<event>." prefix would wrongly include hooks of longer event names
// that start with the same segment.
func (hs *HookSystem) GetEventStats(event string) map[string]*HookStats {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	registered := hs.hooks[event]
	eventStats := make(map[string]*HookStats, len(registered))

	for _, hook := range registered {
		stat, ok := hs.stats[fmt.Sprintf("%s.%s", event, hook.Name)]
		if !ok {
			continue
		}
		statCopy := *stat
		eventStats[hook.Name] = &statCopy
	}

	return eventStats
}
// Clear removes every hook and all statistics by replacing both maps with
// empty ones.
func (hs *HookSystem) Clear() {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	hs.hooks = make(map[string][]*Hook)
	hs.stats = make(map[string]*HookStats)

	hs.logger.Info("Cleared all hooks")
}
// ClearEvent removes all hooks registered for event together with their
// statistics.
//
// Statistics are deleted by their exact "<event>.<hook name>" key. Event
// names contain dots (for example "user.created"), so deleting by the
// "<event>." prefix would also wipe the statistics of longer event names
// that start with the same segment.
func (hs *HookSystem) ClearEvent(event string) {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	// Remove the statistics first, while the hook names are still known.
	for _, hook := range hs.hooks[event] {
		delete(hs.stats, fmt.Sprintf("%s.%s", event, hook.Name))
	}

	// Remove the hooks themselves.
	delete(hs.hooks, event)

	hs.logger.Info("Cleared hooks for event", zap.String("event", event))
}
// Count returns the number of hooks registered across all events.
func (hs *HookSystem) Count() int {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	var sum int
	for event := range hs.hooks {
		sum += len(hs.hooks[event])
	}
	return sum
}
// EventCount returns the number of hooks registered for event.
func (hs *HookSystem) EventCount(event string) int {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	return len(hs.hooks[event])
}
// Service interface implementation.

// Name returns the service name, "hook-system".
func (hs *HookSystem) Name() string {
	return "hook-system"
}
// Initialize logs that the hook system was initialized; it performs no other
// work and always returns nil.
func (hs *HookSystem) Initialize(ctx context.Context) error {
	hs.logger.Info("Hook system initialized")
	return nil
}
// Start logs that the hook system was started; it performs no other work and
// always returns nil.
func (hs *HookSystem) Start(ctx context.Context) error {
	hs.logger.Info("Hook system started")
	return nil
}
// HealthCheck always reports the hook system as healthy by returning nil.
func (hs *HookSystem) HealthCheck(ctx context.Context) error {
	return nil
}
// Shutdown logs that the hook system was shut down and returns nil. It does
// not wait for asynchronous hooks that may still be running.
func (hs *HookSystem) Shutdown(ctx context.Context) error {
	hs.logger.Info("Hook system shutdown")
	return nil
}
// Convenience methods.

// OnUserCreated registers a synchronous hook for the "user.created" event.
func (hs *HookSystem) OnUserCreated(name string, priority HookPriority, fn HookFunc) error {
	return hs.RegisterFunc("user.created", name, priority, fn)
}
// OnUserUpdated registers a synchronous hook for the "user.updated" event.
func (hs *HookSystem) OnUserUpdated(name string, priority HookPriority, fn HookFunc) error {
	return hs.RegisterFunc("user.updated", name, priority, fn)
}
// OnUserDeleted registers a synchronous hook for the "user.deleted" event.
func (hs *HookSystem) OnUserDeleted(name string, priority HookPriority, fn HookFunc) error {
	return hs.RegisterFunc("user.deleted", name, priority, fn)
}
// OnOrderCreated registers a synchronous hook for the "order.created" event.
func (hs *HookSystem) OnOrderCreated(name string, priority HookPriority, fn HookFunc) error {
	return hs.RegisterFunc("order.created", name, priority, fn)
}
// OnOrderCompleted registers a synchronous hook for the "order.completed"
// event.
func (hs *HookSystem) OnOrderCompleted(name string, priority HookPriority, fn HookFunc) error {
	return hs.RegisterFunc("order.completed", name, priority, fn)
}
// TriggerUserCreated triggers the "user.created" event with user as payload.
func (hs *HookSystem) TriggerUserCreated(ctx context.Context, user interface{}) ([]HookResult, error) {
	return hs.Trigger(ctx, "user.created", user)
}
// TriggerUserUpdated triggers the "user.updated" event with user as payload.
func (hs *HookSystem) TriggerUserUpdated(ctx context.Context, user interface{}) ([]HookResult, error) {
	return hs.Trigger(ctx, "user.updated", user)
}
// TriggerUserDeleted triggers the "user.deleted" event with user as payload.
func (hs *HookSystem) TriggerUserDeleted(ctx context.Context, user interface{}) ([]HookResult, error) {
	return hs.Trigger(ctx, "user.deleted", user)
}
// HookBuilder builds a Hook step by step through chained method calls.
type HookBuilder struct {
	// hook is the hook being configured; Build returns this same pointer.
	hook *Hook
}
// NewHookBuilder starts a builder for a hook with the given name and
// function. The hook defaults to normal priority, synchronous execution and
// a 30-second timeout.
func NewHookBuilder(name string, fn HookFunc) *HookBuilder {
	hook := new(Hook)
	hook.Name = name
	hook.Func = fn
	hook.Priority = PriorityNormal
	hook.Async = false
	hook.Timeout = 30 * time.Second

	return &HookBuilder{hook: hook}
}
// WithPriority sets the hook priority and returns the builder for chaining.
func (hb *HookBuilder) WithPriority(priority HookPriority) *HookBuilder {
	hb.hook.Priority = priority
	return hb
}
// WithTimeout sets the hook timeout and returns the builder for chaining.
func (hb *HookBuilder) WithTimeout(timeout time.Duration) *HookBuilder {
	hb.hook.Timeout = timeout
	return hb
}
// Async marks the hook for asynchronous execution and returns the builder
// for chaining.
func (hb *HookBuilder) Async() *HookBuilder {
	hb.hook.Async = true
	return hb
}
// Build returns the configured hook. The returned pointer is the builder's
// own hook, so later builder calls keep modifying it.
func (hb *HookBuilder) Build() *Hook {
	return hb.hook
}
// TypedHookFunc is a hook callback whose payload has the concrete type T
// instead of interface{}.
type TypedHookFunc[T any] func(ctx context.Context, data T) error
// RegisterTypedFunc registers fn as a synchronous hook for event. The event
// payload is type-asserted to T before fn is called; a payload of a
// different type makes the hook fail with an error naming the expected and
// the actual type.
//
// A nil fn is rejected here, because the wrapping closure handed to
// RegisterFunc is never nil and would otherwise slip past its nil check.
func RegisterTypedFunc[T any](hs *HookSystem, event, name string, priority HookPriority, fn TypedHookFunc[T]) error {
	if fn == nil {
		return fmt.Errorf("hook function cannot be nil")
	}

	hookFunc := func(ctx context.Context, data interface{}) error {
		typedData, ok := data.(T)
		if !ok {
			// String() also yields a usable name for unnamed types such as
			// pointers and slices, where Name() returns "".
			expected := reflect.TypeOf((*T)(nil)).Elem().String()
			return fmt.Errorf("invalid data type for hook %s, expected %s, got %T", name, expected, data)
		}
		return fn(ctx, typedData)
	}

	return hs.RegisterFunc(event, name, priority, hookFunc)
}

View File

@@ -20,7 +20,7 @@ func NewResponseBuilder() interfaces.ResponseBuilder {
// Success 成功响应
func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...string) {
msg := "Success"
msg := "操作成功"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -38,7 +38,7 @@ func (r *ResponseBuilder) Success(c *gin.Context, data interface{}, message ...s
// Created 创建成功响应
func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...string) {
msg := "Created successfully"
msg := "创建成功"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -58,7 +58,7 @@ func (r *ResponseBuilder) Created(c *gin.Context, data interface{}, message ...s
func (r *ResponseBuilder) Error(c *gin.Context, err error) {
// 根据错误类型确定状态码
statusCode := http.StatusInternalServerError
message := "Internal server error"
message := "服务器内部错误"
errorDetail := err.Error()
// 这里可以根据不同的错误类型设置不同的状态码
@@ -93,7 +93,7 @@ func (r *ResponseBuilder) BadRequest(c *gin.Context, message string, errors ...i
// Unauthorized 401错误响应
func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
msg := "Unauthorized"
msg := "用户未登录或认证已过期"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -110,7 +110,7 @@ func (r *ResponseBuilder) Unauthorized(c *gin.Context, message ...string) {
// Forbidden 403错误响应
func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
msg := "Forbidden"
msg := "权限不足,无法访问此资源"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -127,7 +127,7 @@ func (r *ResponseBuilder) Forbidden(c *gin.Context, message ...string) {
// NotFound 404错误响应
func (r *ResponseBuilder) NotFound(c *gin.Context, message ...string) {
msg := "Resource not found"
msg := "请求的资源不存在"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -156,7 +156,7 @@ func (r *ResponseBuilder) Conflict(c *gin.Context, message string) {
// InternalError 500错误响应
func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
msg := "Internal server error"
msg := "服务器内部错误"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}
@@ -175,7 +175,7 @@ func (r *ResponseBuilder) InternalError(c *gin.Context, message ...string) {
func (r *ResponseBuilder) Paginated(c *gin.Context, data interface{}, pagination interfaces.PaginationMeta) {
response := interfaces.APIResponse{
Success: true,
Message: "Success",
Message: "查询成功",
Data: data,
Pagination: &pagination,
RequestID: r.getRequestID(c),
@@ -215,9 +215,35 @@ func BuildPagination(page, pageSize int, total int64) interfaces.PaginationMeta
// CustomResponse 自定义响应
func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data interface{}) {
var message string
switch statusCode {
case http.StatusOK:
message = "请求成功"
case http.StatusCreated:
message = "创建成功"
case http.StatusNoContent:
message = "无内容"
case http.StatusBadRequest:
message = "请求参数错误"
case http.StatusUnauthorized:
message = "认证失败"
case http.StatusForbidden:
message = "权限不足"
case http.StatusNotFound:
message = "资源不存在"
case http.StatusConflict:
message = "资源冲突"
case http.StatusTooManyRequests:
message = "请求过于频繁"
case http.StatusInternalServerError:
message = "服务器内部错误"
default:
message = "未知状态"
}
response := interfaces.APIResponse{
Success: statusCode >= 200 && statusCode < 300,
Message: http.StatusText(statusCode),
Message: message,
Data: data,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
@@ -230,7 +256,7 @@ func (r *ResponseBuilder) CustomResponse(c *gin.Context, statusCode int, data in
func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
response := interfaces.APIResponse{
Success: false,
Message: "Validation failed",
Message: "请求参数验证失败",
Errors: errors,
RequestID: r.getRequestID(c),
Timestamp: time.Now().Unix(),
@@ -241,7 +267,7 @@ func (r *ResponseBuilder) ValidationError(c *gin.Context, errors interface{}) {
// TooManyRequests 限流错误响应
func (r *ResponseBuilder) TooManyRequests(c *gin.Context, message ...string) {
msg := "Too many requests"
msg := "请求过于频繁,请稍后再试"
if len(message) > 0 && message[0] != "" {
msg = message[0]
}

View File

@@ -8,6 +8,8 @@ import (
"time"
"github.com/gin-gonic/gin"
swaggerFiles "github.com/swaggo/files"
ginSwagger "github.com/swaggo/gin-swagger"
"go.uber.org/zap"
"tyapi-server/internal/config"
@@ -51,7 +53,7 @@ func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
// 注册路由
r.engine.Handle(handler.GetMethod(), handler.GetPath(), append(middlewares, handler.Handle)...)
r.logger.Info("Registered HTTP handler",
r.logger.Info("已注册HTTP处理器",
zap.String("method", handler.GetMethod()),
zap.String("path", handler.GetPath()))
@@ -62,7 +64,7 @@ func (r *GinRouter) RegisterHandler(handler interfaces.HTTPHandler) error {
func (r *GinRouter) RegisterMiddleware(middleware interfaces.Middleware) error {
r.middlewares = append(r.middlewares, middleware)
r.logger.Info("Registered middleware",
r.logger.Info("已注册中间件",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
@@ -93,7 +95,7 @@ func (r *GinRouter) Start(addr string) error {
IdleTimeout: r.config.Server.IdleTimeout,
}
r.logger.Info("Starting HTTP server", zap.String("addr", addr))
r.logger.Info("正在启动HTTP服务器", zap.String("addr", addr))
// 启动服务器
if err := r.server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
@@ -109,15 +111,15 @@ func (r *GinRouter) Stop(ctx context.Context) error {
return nil
}
r.logger.Info("Stopping HTTP server...")
r.logger.Info("正在关闭HTTP服务器...")
// 优雅关闭服务器
if err := r.server.Shutdown(ctx); err != nil {
r.logger.Error("Failed to shutdown server gracefully", zap.Error(err))
r.logger.Error("优雅关闭服务器失败", zap.Error(err))
return err
}
r.logger.Info("HTTP server stopped")
r.logger.Info("HTTP服务器已关闭")
return nil
}
@@ -137,7 +139,7 @@ func (r *GinRouter) applyMiddlewares() {
for _, middleware := range r.middlewares {
if middleware.IsGlobal() {
r.engine.Use(middleware.Handle())
r.logger.Debug("Applied global middleware",
r.logger.Debug("已应用全局中间件",
zap.String("name", middleware.GetName()),
zap.Int("priority", middleware.GetPriority()))
}
@@ -156,6 +158,18 @@ func (r *GinRouter) SetupDefaultRoutes() {
})
})
// 详细健康检查
r.engine.GET("/health/detailed", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"status": "healthy",
"timestamp": time.Now().Unix(),
"service": r.config.App.Name,
"version": r.config.App.Version,
"uptime": time.Now().Unix(),
"environment": r.config.App.Env,
})
})
// API信息
r.engine.GET("/info", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
@@ -166,11 +180,37 @@ func (r *GinRouter) SetupDefaultRoutes() {
})
})
// Swagger文档路由 (仅在开发环境启用)
if !r.config.App.IsProduction() {
// Swagger UI
r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
// API文档重定向
r.engine.GET("/docs", func(c *gin.Context) {
c.Redirect(http.StatusMovedPermanently, "/swagger/index.html")
})
// API文档信息
r.engine.GET("/api/docs", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"swagger_ui": fmt.Sprintf("http://%s/swagger/index.html", c.Request.Host),
"openapi_json": fmt.Sprintf("http://%s/swagger/doc.json", c.Request.Host),
"redoc": fmt.Sprintf("http://%s/redoc", c.Request.Host),
"message": "API文档已可用",
})
})
r.logger.Info("Swagger documentation enabled",
zap.String("swagger_url", "/swagger/index.html"),
zap.String("docs_url", "/docs"),
zap.String("api_docs_url", "/api/docs"))
}
// 404处理
r.engine.NoRoute(func(c *gin.Context) {
c.JSON(http.StatusNotFound, gin.H{
"success": false,
"message": "Route not found",
"message": "路由未找到",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),
@@ -181,7 +221,7 @@ func (r *GinRouter) SetupDefaultRoutes() {
r.engine.NoMethod(func(c *gin.Context) {
c.JSON(http.StatusMethodNotAllowed, gin.H{
"success": false,
"message": "Method not allowed",
"message": "请求方法不允许",
"path": c.Request.URL.Path,
"method": c.Request.Method,
"timestamp": time.Now().Unix(),

View File

@@ -42,13 +42,13 @@ func (v *RequestValidator) Validate(c *gin.Context, dto interface{}) error {
// ValidateQuery 验证查询参数
func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindQuery(dto); err != nil {
v.response.BadRequest(c, "Invalid query parameters", err.Error())
v.response.BadRequest(c, "查询参数格式错误", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
v.response.ValidationError(c, validationErrors)
return err
}
return nil
@@ -57,13 +57,13 @@ func (v *RequestValidator) ValidateQuery(c *gin.Context, dto interface{}) error
// ValidateParam 验证路径参数
func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error {
if err := c.ShouldBindUri(dto); err != nil {
v.response.BadRequest(c, "Invalid path parameters", err.Error())
v.response.BadRequest(c, "路径参数格式错误", err.Error())
return err
}
if err := v.validator.Struct(dto); err != nil {
validationErrors := v.formatValidationErrors(err)
v.response.BadRequest(c, "Validation failed", validationErrors)
v.response.ValidationError(c, validationErrors)
return err
}
return nil
@@ -73,7 +73,7 @@ func (v *RequestValidator) ValidateParam(c *gin.Context, dto interface{}) error
func (v *RequestValidator) BindAndValidate(c *gin.Context, dto interface{}) error {
// 绑定请求体
if err := c.ShouldBindJSON(dto); err != nil {
v.response.BadRequest(c, "Invalid request body", err.Error())
v.response.BadRequest(c, "请求体格式错误", err.Error())
return err
}
@@ -115,44 +115,74 @@ func (v *RequestValidator) getErrorMessage(fieldError validator.FieldError) stri
tag := fieldError.Tag()
param := fieldError.Param()
fieldDisplayName := v.getFieldDisplayName(field)
switch tag {
case "required":
return fmt.Sprintf("%s is required", field)
return fmt.Sprintf("%s 不能为空", fieldDisplayName)
case "email":
return fmt.Sprintf("%s must be a valid email address", field)
return fmt.Sprintf("%s 必须是有效的邮箱地址", fieldDisplayName)
case "min":
return fmt.Sprintf("%s must be at least %s characters", field, param)
return fmt.Sprintf("%s 长度不能少于 %s 位", fieldDisplayName, param)
case "max":
return fmt.Sprintf("%s must be at most %s characters", field, param)
return fmt.Sprintf("%s 长度不能超过 %s 位", fieldDisplayName, param)
case "len":
return fmt.Sprintf("%s must be exactly %s characters", field, param)
return fmt.Sprintf("%s 长度必须为 %s 位", fieldDisplayName, param)
case "gt":
return fmt.Sprintf("%s must be greater than %s", field, param)
return fmt.Sprintf("%s 必须大于 %s", fieldDisplayName, param)
case "gte":
return fmt.Sprintf("%s must be greater than or equal to %s", field, param)
return fmt.Sprintf("%s 必须大于等于 %s", fieldDisplayName, param)
case "lt":
return fmt.Sprintf("%s must be less than %s", field, param)
return fmt.Sprintf("%s 必须小于 %s", fieldDisplayName, param)
case "lte":
return fmt.Sprintf("%s must be less than or equal to %s", field, param)
return fmt.Sprintf("%s 必须小于等于 %s", fieldDisplayName, param)
case "oneof":
return fmt.Sprintf("%s must be one of [%s]", field, param)
return fmt.Sprintf("%s 必须是以下值之一:[%s]", fieldDisplayName, param)
case "url":
return fmt.Sprintf("%s must be a valid URL", field)
return fmt.Sprintf("%s 必须是有效的URL地址", fieldDisplayName)
case "alpha":
return fmt.Sprintf("%s must contain only alphabetic characters", field)
return fmt.Sprintf("%s 只能包含字母", fieldDisplayName)
case "alphanum":
return fmt.Sprintf("%s must contain only alphanumeric characters", field)
return fmt.Sprintf("%s 只能包含字母和数字", fieldDisplayName)
case "numeric":
return fmt.Sprintf("%s must be numeric", field)
return fmt.Sprintf("%s 必须是数字", fieldDisplayName)
case "phone":
return fmt.Sprintf("%s must be a valid phone number", field)
return fmt.Sprintf("%s 必须是有效的手机号", fieldDisplayName)
case "username":
return fmt.Sprintf("%s must be a valid username", field)
return fmt.Sprintf("%s 格式不正确,只能包含字母、数字、下划线,且不能以数字开头", fieldDisplayName)
case "strong_password":
return fmt.Sprintf("%s 强度不足必须包含大小写字母和数字且不少于8位", fieldDisplayName)
case "eqfield":
return fmt.Sprintf("%s 必须与 %s 一致", fieldDisplayName, v.getFieldDisplayName(param))
default:
return fmt.Sprintf("%s is invalid", field)
return fmt.Sprintf("%s 格式不正确", fieldDisplayName)
}
}
// getFieldDisplayName maps a field name to its Chinese display name and
// returns the field name unchanged when no mapping exists.
//
// Both snake_case and Go-style names are accepted. "ConfirmPassword" is
// included so that this mapping agrees with the one used by the Chinese
// validator (RequestValidatorZh).
func (v *RequestValidator) getFieldDisplayName(field string) string {
	switch field {
	case "phone":
		return "手机号"
	case "password", "Password":
		return "密码"
	case "confirm_password", "ConfirmPassword":
		return "确认密码"
	case "old_password":
		return "原密码"
	case "new_password", "NewPassword":
		return "新密码"
	case "confirm_new_password":
		return "确认新密码"
	case "code":
		return "验证码"
	case "username":
		return "用户名"
	case "email":
		return "邮箱"
	case "display_name":
		return "显示名称"
	case "scene":
		return "使用场景"
	default:
		return field
	}
}
// toSnakeCase 转换为snake_case
func (v *RequestValidator) toSnakeCase(str string) string {
var result strings.Builder

View File

@@ -0,0 +1,294 @@
package http
import (
"strings"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"github.com/go-playground/locales/zh"
ut "github.com/go-playground/universal-translator"
"github.com/go-playground/validator/v10"
zh_translations "github.com/go-playground/validator/v10/translations/zh"
)
// RequestValidatorZh is a request validator that reports validation errors
// in Chinese.
type RequestValidatorZh struct {
	// validator performs the struct validation.
	validator *validator.Validate
	// translator turns validation errors into Chinese messages.
	translator ut.Translator
	// response writes error responses to the client.
	response interfaces.ResponseBuilder
}
// NewRequestValidatorZh creates a request validator whose error messages are
// translated to Chinese. Errors are reported through response.
func NewRequestValidatorZh(response interfaces.ResponseBuilder) interfaces.RequestValidator {
	// Create the validator instance.
	validate := validator.New()

	// Create the Chinese locale; it is used both as fallback and as the only
	// supported locale.
	zhLocale := zh.New()
	uni := ut.New(zhLocale, zhLocale)

	// Fetch the Chinese translator.
	// NOTE(review): the second return value of GetTranslator is discarded.
	trans, _ := uni.GetTranslator("zh")

	// Register the built-in Chinese translations.
	// NOTE(review): the return value of RegisterDefaultTranslations is not
	// checked here — confirm that ignoring it is intended.
	zh_translations.RegisterDefaultTranslations(validate, trans)

	// Register the custom validators together with their translations.
	registerCustomValidatorsZh(validate, trans)

	return &RequestValidatorZh{
		validator:  validate,
		translator: trans,
		response:   response,
	}
}
// Validate runs struct validation on dto. On failure it writes a validation
// error response with the translated messages and returns the validation
// error; on success it returns nil.
func (v *RequestValidatorZh) Validate(c *gin.Context, dto interface{}) error {
	err := v.validator.Struct(dto)
	if err == nil {
		return nil
	}

	v.response.ValidationError(c, v.formatValidationErrorsZh(err))
	return err
}
// ValidateQuery binds the query string into dto and validates it. A binding
// failure produces a bad-request response, a validation failure produces a
// validation error response; in both cases the error is returned.
func (v *RequestValidatorZh) ValidateQuery(c *gin.Context, dto interface{}) error {
	bindErr := c.ShouldBindQuery(dto)
	if bindErr != nil {
		v.response.BadRequest(c, "查询参数格式错误", bindErr.Error())
		return bindErr
	}

	checkErr := v.validator.Struct(dto)
	if checkErr == nil {
		return nil
	}

	v.response.ValidationError(c, v.formatValidationErrorsZh(checkErr))
	return checkErr
}
// ValidateParam binds the URI path parameters into dto and validates it. A
// binding failure produces a bad-request response, a validation failure
// produces a validation error response; in both cases the error is returned.
func (v *RequestValidatorZh) ValidateParam(c *gin.Context, dto interface{}) error {
	bindErr := c.ShouldBindUri(dto)
	if bindErr != nil {
		v.response.BadRequest(c, "路径参数格式错误", bindErr.Error())
		return bindErr
	}

	checkErr := v.validator.Struct(dto)
	if checkErr == nil {
		return nil
	}

	v.response.ValidationError(c, v.formatValidationErrorsZh(checkErr))
	return checkErr
}
// BindAndValidate binds the JSON request body into dto and then validates
// it. A binding failure produces a bad-request response and returns the
// binding error; otherwise the result of Validate is returned.
func (v *RequestValidatorZh) BindAndValidate(c *gin.Context, dto interface{}) error {
	bindErr := c.ShouldBindJSON(dto)
	if bindErr == nil {
		// Body bound successfully; run the validation rules.
		return v.Validate(c, dto)
	}

	v.response.BadRequest(c, "请求体格式错误", bindErr.Error())
	return bindErr
}
// formatValidationErrorsZh converts a validation error into a map from
// snake_case field name to the translated messages for that field. An error
// that is not a validator.ValidationErrors yields an empty map.
func (v *RequestValidatorZh) formatValidationErrorsZh(err error) map[string][]string {
	messages := make(map[string][]string)

	fieldErrs, ok := err.(validator.ValidationErrors)
	if !ok {
		return messages
	}

	for _, fe := range fieldErrs {
		key := v.getFieldNameZh(fe)

		// Start from the translator's message.
		text := fe.Translate(v.translator)

		// When a Chinese display name exists, substitute it for the raw
		// field name inside the message.
		rawName := fe.Field()
		if label := v.getFieldDisplayName(rawName); label != rawName {
			text = strings.ReplaceAll(text, rawName, label)
		}

		// append allocates the slice on first use of a key.
		messages[key] = append(messages[key], text)
	}

	return messages
}
// getFieldNameZh returns the key under which a field's messages are reported:
// the name from fieldError.Field() converted to snake_case.
func (v *RequestValidatorZh) getFieldNameZh(fieldError validator.FieldError) string {
	return v.toSnakeCase(fieldError.Field())
}
// getFieldDisplayName maps a known field name (snake_case or Go-style) to its
// Chinese display label. Unknown names are returned unchanged.
func (v *RequestValidatorZh) getFieldDisplayName(field string) string {
	switch field {
	case "phone":
		return "手机号"
	case "password", "Password":
		return "密码"
	case "confirm_password", "ConfirmPassword":
		return "确认密码"
	case "old_password":
		return "原密码"
	case "new_password", "NewPassword":
		return "新密码"
	case "confirm_new_password":
		return "确认新密码"
	case "code":
		return "验证码"
	case "username":
		return "用户名"
	case "email":
		return "邮箱"
	case "display_name":
		return "显示名称"
	case "scene":
		return "使用场景"
	default:
		return field
	}
}
// toSnakeCase converts a Go-style identifier to snake_case.
//
// An underscore is inserted before an ASCII upper-case letter only at a word
// boundary: after a lower-case letter or digit, or before the last capital of
// an acronym that is followed by a lower-case letter. Runs of capitals are
// therefore kept together ("UserID" -> "user_id", "HTTPServer" ->
// "http_server") instead of being split letter by letter, and no underscore
// is doubled when one is already present ("A_B" -> "a_b").
func (v *RequestValidatorZh) toSnakeCase(str string) string {
	isUpper := func(r rune) bool { return r >= 'A' && r <= 'Z' }
	isLower := func(r rune) bool { return r >= 'a' && r <= 'z' }
	isDigit := func(r rune) bool { return r >= '0' && r <= '9' }

	runes := []rune(str)
	var result strings.Builder
	result.Grow(len(str) + 4)

	for i, r := range runes {
		if i > 0 && isUpper(r) {
			prev := runes[i-1]
			nextIsLower := i+1 < len(runes) && isLower(runes[i+1])
			if isLower(prev) || isDigit(prev) || (isUpper(prev) && nextIsLower) {
				result.WriteRune('_')
			}
		}
		result.WriteRune(r)
	}
	return strings.ToLower(result.String())
}
// registerCustomValidatorsZh installs the project's custom validation tags
// ("phone", "username", "strong_password") together with their Chinese
// messages, and overrides the message used for the "eqfield" tag.
//
// Results of the registration calls are not inspected.
func registerCustomValidatorsZh(v *validator.Validate, trans ut.Translator) {
	// addRule registers one custom tag and a message with a single {0}
	// placeholder that is filled with the field name.
	addRule := func(tag string, fn validator.Func, text string) {
		v.RegisterValidation(tag, fn)
		v.RegisterTranslation(tag, trans, func(tr ut.Translator) error {
			return tr.Add(tag, text, true)
		}, func(tr ut.Translator, fe validator.FieldError) string {
			msg, _ := tr.T(tag, fe.Field())
			return msg
		})
	}

	addRule("phone", validatePhoneZh, "{0}必须是有效的手机号")
	addRule("username", validateUsernameZh, "{0}格式不正确,只能包含字母、数字、下划线,且不能以数字开头")
	addRule("strong_password", validateStrongPasswordZh, "{0}强度不足必须包含大小写字母和数字且不少于8位")

	// eqfield: only the translation is replaced; {1} is the tag parameter.
	v.RegisterTranslation("eqfield", trans, func(tr ut.Translator) error {
		return tr.Add("eqfield", "{0}必须等于{1}", true)
	}, func(tr ut.Translator, fe validator.FieldError) string {
		msg, _ := tr.T("eqfield", fe.Field(), fe.Param())
		return msg
	})
}
// validatePhoneZh reports whether the field holds a mobile number in the
// accepted shape: exactly 11 ASCII digits, the first being '1'.
//
// An empty value passes; presence is left to the "required" tag.
func validatePhoneZh(fl validator.FieldLevel) bool {
	number := fl.Field().String()
	switch {
	case number == "":
		return true
	case len(number) != 11:
		return false
	case number[0] != '1':
		return false
	}

	// Any byte outside '0'..'9' (including bytes of multi-byte runes) is rejected.
	for i := 0; i < len(number); i++ {
		if number[i] < '0' || number[i] > '9' {
			return false
		}
	}
	return true
}
// validateUsernameZh reports whether the field is an acceptable username:
// 3 to 30 bytes long, made only of ASCII letters, digits and underscores, and
// not starting with a digit.
//
// An empty value passes; presence is left to the "required" tag.
func validateUsernameZh(fl validator.FieldLevel) bool {
	name := fl.Field().String()
	if name == "" {
		return true
	}
	if n := len(name); n < 3 || n > 30 {
		return false
	}

	isDigit := func(r rune) bool { return '0' <= r && r <= '9' }
	isLetter := func(r rune) bool { return ('a' <= r && r <= 'z') || ('A' <= r && r <= 'Z') }

	if isDigit(rune(name[0])) {
		return false
	}
	for _, r := range name {
		if !isLetter(r) && !isDigit(r) && r != '_' {
			return false
		}
	}
	return true
}
// validateStrongPasswordZh reports whether the field is a strong password: at
// least 8 bytes long and containing at least one ASCII upper-case letter, one
// ASCII lower-case letter and one ASCII digit.
//
// An empty value passes; presence is left to the "required" tag.
func validateStrongPasswordZh(fl validator.FieldLevel) bool {
	secret := fl.Field().String()
	if secret == "" {
		return true
	}
	if len(secret) < 8 {
		return false
	}

	var sawUpper, sawLower, sawDigit bool
	for _, ch := range secret {
		if 'A' <= ch && ch <= 'Z' {
			sawUpper = true
		} else if 'a' <= ch && ch <= 'z' {
			sawLower = true
		} else if '0' <= ch && ch <= '9' {
			sawDigit = true
		}
	}
	return sawUpper && sawLower && sawDigit
}
// ValidateStruct validates dto against its struct-tag rules and returns the
// raw validator error (nil on success). Unlike Validate it writes no response.
func (v *RequestValidatorZh) ValidateStruct(dto interface{}) error {
return v.validator.Struct(dto)
}

View File

@@ -76,9 +76,14 @@ type ResponseBuilder interface {
NotFound(c *gin.Context, message ...string)
Conflict(c *gin.Context, message string)
InternalError(c *gin.Context, message ...string)
ValidationError(c *gin.Context, errors interface{})
TooManyRequests(c *gin.Context, message ...string)
// 分页响应
Paginated(c *gin.Context, data interface{}, pagination PaginationMeta)
// 自定义响应
CustomResponse(c *gin.Context, statusCode int, data interface{})
}
// RequestValidator 请求验证器接口
@@ -90,6 +95,9 @@ type RequestValidator interface {
// 绑定和验证
BindAndValidate(c *gin.Context, dto interface{}) error
// 直接验证结构体
ValidateStruct(dto interface{}) error
}
// PaginationMeta 分页元数据

View File

@@ -2,6 +2,15 @@ package interfaces
import (
"context"
"errors"
"tyapi-server/internal/domains/user/dto"
"tyapi-server/internal/domains/user/entities"
)
// Common error definitions.
var (
// ErrCacheMiss is a sentinel error with the message "cache miss".
// NOTE(review): presumably returned by cache lookups when a key is absent — confirm against the cache implementations.
ErrCacheMiss = errors.New("cache miss")
)
// Service 通用服务接口
@@ -16,6 +25,22 @@ type Service interface {
Shutdown(ctx context.Context) error
}
// UserService is the user service interface. It embeds Service and adds the
// user account operations below.
type UserService interface {
Service
// Register creates a user from a registration request.
Register(ctx context.Context, req *dto.RegisterRequest) (*entities.User, error)
// LoginWithPassword logs a user in with a password request.
LoginWithPassword(ctx context.Context, req *dto.LoginWithPasswordRequest) (*entities.User, error)
// LoginWithSMS logs a user in with an SMS verification code request.
LoginWithSMS(ctx context.Context, req *dto.LoginWithSMSRequest) (*entities.User, error)
// ChangePassword changes the password of the user identified by userID.
ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error
// GetByID returns the user with the given ID.
GetByID(ctx context.Context, id string) (*entities.User, error)
}
// DomainService 领域服务接口,支持泛型
type DomainService[T Entity] interface {
Service

View File

@@ -0,0 +1,214 @@
package logger
import (
"context"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// LogLevel is a log severity level expressed as a string.
type LogLevel string
// Supported log levels.
const (
DebugLevel LogLevel = "debug"
InfoLevel LogLevel = "info"
WarnLevel LogLevel = "warn"
ErrorLevel LogLevel = "error"
)
// LogContext holds the contextual values that ContextualLogger attaches to
// every log entry. Empty fields are omitted from the output.
type LogContext struct {
RequestID string
UserID string
TraceID string
OperationName string
Layer string // repository/service/handler
Component string
}
// ContextualLogger is a context-aware logger: it wraps a zap logger and
// prepends the fields of its LogContext to each entry it writes.
type ContextualLogger struct {
// logger is the underlying zap logger that receives every entry.
logger *zap.Logger
// ctx is the set of contextual values added to each entry.
ctx LogContext
}
// NewContextualLogger wraps logger in a ContextualLogger with an empty
// LogContext.
func NewContextualLogger(logger *zap.Logger) *ContextualLogger {
	wrapped := new(ContextualLogger)
	wrapped.logger = logger
	return wrapped
}
// WithContext returns a copy of the logger enriched with the request_id,
// user_id and trace_id string values found in ctx.
//
// The copy starts from the receiver's current LogContext, so Layer, Component
// and OperationName set earlier in a builder chain are kept, and an ID that
// ctx does not provide keeps its previous value. A nil ctx is tolerated and
// leaves the context unchanged. The receiver itself is not modified.
func (l *ContextualLogger) WithContext(ctx context.Context) *ContextualLogger {
	logCtx := l.ctx

	if ctx != nil {
		if requestID := getStringFromContext(ctx, "request_id"); requestID != "" {
			logCtx.RequestID = requestID
		}
		if userID := getStringFromContext(ctx, "user_id"); userID != "" {
			logCtx.UserID = userID
		}
		if traceID := getStringFromContext(ctx, "trace_id"); traceID != "" {
			logCtx.TraceID = traceID
		}
	}

	return &ContextualLogger{
		logger: l.logger,
		ctx:    logCtx,
	}
}
// WithLayer returns a copy of the logger whose context carries the given
// layer name. The receiver is left untouched.
func (l *ContextualLogger) WithLayer(layer string) *ContextualLogger {
	clone := *l
	clone.ctx.Layer = layer
	return &clone
}
// WithComponent returns a copy of the logger whose context carries the given
// component name. The receiver is left untouched.
func (l *ContextualLogger) WithComponent(component string) *ContextualLogger {
	clone := *l
	clone.ctx.Component = component
	return &clone
}
// WithOperation returns a copy of the logger whose context carries the given
// operation name. The receiver is left untouched.
func (l *ContextualLogger) WithOperation(operation string) *ContextualLogger {
	clone := *l
	clone.ctx.OperationName = operation
	return &clone
}
// buildBaseFields returns the zap fields for every non-empty value of the
// logger's context, in the fixed order request_id, user_id, trace_id, layer,
// component, operation.
func (l *ContextualLogger) buildBaseFields() []zapcore.Field {
	pairs := [...]struct{ key, value string }{
		{"request_id", l.ctx.RequestID},
		{"user_id", l.ctx.UserID},
		{"trace_id", l.ctx.TraceID},
		{"layer", l.ctx.Layer},
		{"component", l.ctx.Component},
		{"operation", l.ctx.OperationName},
	}

	fields := make([]zapcore.Field, 0, len(pairs))
	for _, p := range pairs {
		if p.value == "" {
			continue
		}
		fields = append(fields, zap.String(p.key, p.value))
	}
	return fields
}
// LogTechnicalError writes an Error-level entry for a technical failure
// (intended for the repository layer). The entry carries the context fields,
// the error, error_type="technical", and then any extra fields.
func (l *ContextualLogger) LogTechnicalError(msg string, err error, fields ...zapcore.Field) {
	entry := append(l.buildBaseFields(),
		zap.Error(err),
		zap.String("error_type", "technical"),
	)
	entry = append(entry, fields...)
	l.logger.Error(msg, entry...)
}
// LogBusinessWarn writes a Warn-level business entry (intended for the
// service layer): context fields, log_type="business", then extra fields.
func (l *ContextualLogger) LogBusinessWarn(msg string, fields ...zapcore.Field) {
	entry := append(l.buildBaseFields(), zap.String("log_type", "business"))
	entry = append(entry, fields...)
	l.logger.Warn(msg, entry...)
}
// LogBusinessInfo writes an Info-level business entry (intended for the
// service layer): context fields, log_type="business", then extra fields.
func (l *ContextualLogger) LogBusinessInfo(msg string, fields ...zapcore.Field) {
	entry := append(l.buildBaseFields(), zap.String("log_type", "business"))
	entry = append(entry, fields...)
	l.logger.Info(msg, entry...)
}
// LogUserAction writes an Info-level user-action entry (intended for the
// handler layer): context fields, log_type="user_action", then extra fields.
func (l *ContextualLogger) LogUserAction(msg string, fields ...zapcore.Field) {
	entry := append(l.buildBaseFields(), zap.String("log_type", "user_action"))
	entry = append(entry, fields...)
	l.logger.Info(msg, entry...)
}
// LogRequestFailed writes an entry for a failed request (intended for the
// handler layer): context fields, log_type="request_failed",
// error_category=errorType, then extra fields.
//
// The entry is written at Info level.
func (l *ContextualLogger) LogRequestFailed(msg string, errorType string, fields ...zapcore.Field) {
	entry := append(l.buildBaseFields(),
		zap.String("log_type", "request_failed"),
		zap.String("error_category", errorType),
	)
	entry = append(entry, fields...)
	l.logger.Info(msg, entry...)
}
// getStringFromContext returns the string stored in ctx under key, or "" when
// the key is absent or its value is not a string.
func getStringFromContext(ctx context.Context, key string) string {
	// A failed (or nil) type assertion yields the zero string.
	text, _ := ctx.Value(key).(string)
	return text
}
// ErrorCategory is a coarse classification of an error.
type ErrorCategory string

// Known error categories.
const (
	DatabaseError    ErrorCategory = "database"
	NetworkError     ErrorCategory = "network"
	ValidationError  ErrorCategory = "validation"
	BusinessError    ErrorCategory = "business"
	AuthError        ErrorCategory = "auth"
	ExternalAPIError ErrorCategory = "external_api"
)

// CategorizeError classifies err by looking for keywords in its lower-cased
// message.
//
// Keyword groups are checked in a fixed order — database, network,
// validation, auth — and the first match wins, so a message such as
// "invalid token" is reported as ValidationError. Messages matching no group
// are reported as BusinessError. A nil err yields the empty category instead
// of panicking.
func CategorizeError(err error) ErrorCategory {
	if err == nil {
		return ""
	}

	errMsg := strings.ToLower(err.Error())
	containsAny := func(keywords ...string) bool {
		for _, kw := range keywords {
			if strings.Contains(errMsg, kw) {
				return true
			}
		}
		return false
	}

	switch {
	case containsAny("database", "sql", "gorm"):
		return DatabaseError
	case containsAny("network", "connection", "timeout"):
		return NetworkError
	case containsAny("validation", "invalid", "format"):
		return ValidationError
	case containsAny("unauthorized", "forbidden", "token"):
		return AuthError
	default:
		return BusinessError
	}
}

View File

@@ -0,0 +1,263 @@
package metrics
import (
"context"
"sync"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// BusinessMetrics is the business metrics collector. It forwards business
// events to a MetricsCollector and keeps a few in-memory totals.
type BusinessMetrics struct {
// metrics is the backend that receives every recorded metric.
metrics interfaces.MetricsCollector
logger *zap.Logger
// mutex guards userMetrics and orderMetrics.
mutex sync.RWMutex
// In-memory business totals, exposed through the Get*Stats methods.
userMetrics map[string]int64
orderMetrics map[string]int64
}
// NewBusinessMetrics creates a business metrics collector on top of metrics
// and registers all business metrics with it before returning.
func NewBusinessMetrics(metrics interfaces.MetricsCollector, logger *zap.Logger) *BusinessMetrics {
	collector := &BusinessMetrics{
		metrics:      metrics,
		logger:       logger,
		userMetrics:  map[string]int64{},
		orderMetrics: map[string]int64{},
	}
	collector.registerBusinessMetrics()
	return collector
}
// registerBusinessMetrics registers every business metric (users, orders,
// API, cache, database) with the underlying collector and logs once when done.
//
// Results of the individual registration calls are not inspected.
func (bm *BusinessMetrics) registerBusinessMetrics() {
	backend := bm.metrics

	// Users.
	backend.RegisterCounter("users_created_total", "Total number of users created", []string{"source"})
	backend.RegisterCounter("users_login_total", "Total number of user logins", []string{"method", "status"})
	backend.RegisterGauge("users_active_sessions", "Current number of active user sessions", nil)

	// Orders.
	backend.RegisterCounter("orders_created_total", "Total number of orders created", []string{"status"})
	backend.RegisterCounter("orders_amount_total", "Total order amount in cents", []string{"currency"})
	backend.RegisterHistogram("orders_processing_duration_seconds", "Order processing duration", []string{"status"}, []float64{0.1, 0.5, 1, 2, 5, 10, 30})

	// API.
	backend.RegisterCounter("api_errors_total", "Total number of API errors", []string{"endpoint", "error_type"})
	backend.RegisterHistogram("api_response_size_bytes", "API response size in bytes", []string{"endpoint"}, []float64{100, 1000, 10000, 100000})

	// Cache.
	backend.RegisterCounter("cache_operations_total", "Total number of cache operations", []string{"operation", "result"})
	backend.RegisterGauge("cache_memory_usage_bytes", "Cache memory usage in bytes", []string{"cache_type"})

	// Database.
	backend.RegisterHistogram("database_query_duration_seconds", "Database query duration", []string{"operation", "table"}, []float64{0.001, 0.01, 0.1, 1, 10})
	backend.RegisterCounter("database_errors_total", "Total number of database errors", []string{"operation", "error_type"})

	bm.logger.Info("Business metrics registered successfully")
}
// User相关指标
// RecordUserCreated counts one created user for the given source, both in
// the metrics backend and in the in-memory "created" total.
func (bm *BusinessMetrics) RecordUserCreated(source string) {
	labels := map[string]string{"source": source}
	bm.metrics.IncrementCounter("users_created_total", labels)

	bm.mutex.Lock()
	bm.userMetrics["created"] += 1
	bm.mutex.Unlock()

	bm.logger.Debug("Recorded user created", zap.String("source", source))
}
// RecordUserLogin counts one login attempt labelled with its method and
// status.
func (bm *BusinessMetrics) RecordUserLogin(method, status string) {
	labels := map[string]string{
		"method": method,
		"status": status,
	}
	bm.metrics.IncrementCounter("users_login_total", labels)

	bm.logger.Debug("Recorded user login",
		zap.String("method", method),
		zap.String("status", status),
	)
}
// UpdateActiveUserSessions sets the users_active_sessions gauge to count.
// The gauge carries no labels.
func (bm *BusinessMetrics) UpdateActiveUserSessions(count float64) {
bm.metrics.RecordGauge("users_active_sessions", count, nil)
}
// Order相关指标
// RecordOrderCreated records one created order: it bumps the order counters
// in the metrics backend and adds the order to the in-memory "created" and
// "amount" (in cents) totals.
//
// amount is in major currency units and is converted to cents by rounding to
// the nearest cent, so binary floating-point noise does not lose a cent
// (19.99 becomes 1999, not 1998).
func (bm *BusinessMetrics) RecordOrderCreated(status string, amount float64, currency string) {
	bm.metrics.IncrementCounter("orders_created_total", map[string]string{
		"status": status,
	})

	// Round half away from zero instead of truncating.
	scaled := amount * 100
	var amountCents int64
	if scaled < 0 {
		amountCents = int64(scaled - 0.5)
	} else {
		amountCents = int64(scaled + 0.5)
	}

	// NOTE(review): IncrementCounter takes no value, so orders_amount_total
	// grows by one per order rather than by the order amount; only the
	// in-memory "amount" total below reflects amountCents. Recording the real
	// amount would need an add-style method on the collector.
	bm.metrics.IncrementCounter("orders_amount_total", map[string]string{
		"currency": currency,
	})

	bm.mutex.Lock()
	bm.orderMetrics["created"]++
	bm.orderMetrics["amount"] += amountCents
	bm.mutex.Unlock()

	bm.logger.Debug("Recorded order created",
		zap.String("status", status),
		zap.Float64("amount", amount),
		zap.String("currency", currency))
}
// RecordOrderProcessingDuration observes duration in the
// orders_processing_duration_seconds histogram, labelled with status.
func (bm *BusinessMetrics) RecordOrderProcessingDuration(status string, duration float64) {
	labels := map[string]string{"status": status}
	bm.metrics.RecordHistogram("orders_processing_duration_seconds", duration, labels)
}
// API相关指标
// RecordAPIError counts one API error for the given endpoint and error type.
func (bm *BusinessMetrics) RecordAPIError(endpoint, errorType string) {
	labels := map[string]string{
		"endpoint":   endpoint,
		"error_type": errorType,
	}
	bm.metrics.IncrementCounter("api_errors_total", labels)

	bm.logger.Debug("Recorded API error",
		zap.String("endpoint", endpoint),
		zap.String("error_type", errorType),
	)
}
// RecordAPIResponseSize observes sizeBytes in the api_response_size_bytes
// histogram, labelled with endpoint.
func (bm *BusinessMetrics) RecordAPIResponseSize(endpoint string, sizeBytes float64) {
	labels := map[string]string{"endpoint": endpoint}
	bm.metrics.RecordHistogram("api_response_size_bytes", sizeBytes, labels)
}
// Cache相关指标
// RecordCacheOperation counts one cache operation labelled with the
// operation name and its result.
func (bm *BusinessMetrics) RecordCacheOperation(operation, result string) {
	labels := map[string]string{
		"operation": operation,
		"result":    result,
	}
	bm.metrics.IncrementCounter("cache_operations_total", labels)
}
// UpdateCacheMemoryUsage sets the cache_memory_usage_bytes gauge for the
// given cache type to usageBytes.
func (bm *BusinessMetrics) UpdateCacheMemoryUsage(cacheType string, usageBytes float64) {
	labels := map[string]string{"cache_type": cacheType}
	bm.metrics.RecordGauge("cache_memory_usage_bytes", usageBytes, labels)
}
// Database相关指标
// RecordDatabaseQuery observes duration in the
// database_query_duration_seconds histogram, labelled with the operation and
// table.
func (bm *BusinessMetrics) RecordDatabaseQuery(operation, table string, duration float64) {
	labels := map[string]string{
		"operation": operation,
		"table":     table,
	}
	bm.metrics.RecordHistogram("database_query_duration_seconds", duration, labels)
}
// RecordDatabaseError counts one database error for the given operation and
// error type.
func (bm *BusinessMetrics) RecordDatabaseError(operation, errorType string) {
	labels := map[string]string{
		"operation":  operation,
		"error_type": errorType,
	}
	bm.metrics.IncrementCounter("database_errors_total", labels)

	bm.logger.Debug("Recorded database error",
		zap.String("operation", operation),
		zap.String("error_type", errorType),
	)
}
// 获取统计信息
// GetUserStats returns a snapshot copy of the in-memory user totals. The
// returned map is owned by the caller.
func (bm *BusinessMetrics) GetUserStats() map[string]int64 {
	bm.mutex.RLock()
	defer bm.mutex.RUnlock()

	snapshot := make(map[string]int64, len(bm.userMetrics))
	for name, total := range bm.userMetrics {
		snapshot[name] = total
	}
	return snapshot
}
// GetOrderStats returns a snapshot copy of the in-memory order totals. The
// returned map is owned by the caller.
func (bm *BusinessMetrics) GetOrderStats() map[string]int64 {
	bm.mutex.RLock()
	defer bm.mutex.RUnlock()

	snapshot := make(map[string]int64, len(bm.orderMetrics))
	for name, total := range bm.orderMetrics {
		snapshot[name] = total
	}
	return snapshot
}
// GetOverallStats returns the user and order snapshots under the keys
// "user_stats" and "order_stats".
func (bm *BusinessMetrics) GetOverallStats() map[string]interface{} {
	users := bm.GetUserStats()
	orders := bm.GetOrderStats()

	overall := make(map[string]interface{}, 2)
	overall["user_stats"] = users
	overall["order_stats"] = orders
	return overall
}
// Reset discards the in-memory user and order totals by replacing both maps
// with empty ones, then logs the reset. Metrics already sent to the backend
// are not touched.
func (bm *BusinessMetrics) Reset() {
	bm.mutex.Lock()
	defer bm.mutex.Unlock()

	bm.userMetrics, bm.orderMetrics = make(map[string]int64), make(map[string]int64)
	bm.logger.Info("Business metrics reset")
}
// Context相关方法
// WithContext currently returns the receiver unchanged; ctx is not used.
// It is a placeholder for associating metrics with tracing information
// extracted from the context.
func (bm *BusinessMetrics) WithContext(ctx context.Context) *BusinessMetrics {
// Tracing information could be extracted from ctx here and attached to metrics.
return bm
}
// 实现Service接口如果需要
// Name returns the service name, "business-metrics".
func (bm *BusinessMetrics) Name() string {
return "business-metrics"
}
// Initialize logs that the service was initialized and always returns nil.
// ctx is not used.
func (bm *BusinessMetrics) Initialize(ctx context.Context) error {
bm.logger.Info("Business metrics service initialized")
return nil
}
// HealthCheck always reports healthy (returns nil); no check of the
// underlying collector is performed yet.
func (bm *BusinessMetrics) HealthCheck(ctx context.Context) error {
// Placeholder: a check of the metrics collector would go here.
return nil
}
// Shutdown logs that the service was shut down and always returns nil.
// ctx is not used.
func (bm *BusinessMetrics) Shutdown(ctx context.Context) error {
bm.logger.Info("Business metrics service shutdown")
return nil
}

View File

@@ -0,0 +1,353 @@
package metrics
import (
"net/http"
"strconv"
"sync"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
)
// PrometheusMetrics is a metrics collector backed by its own Prometheus
// registry.
type PrometheusMetrics struct {
logger *zap.Logger
// registry holds every collector exposed by GetHandler.
registry *prometheus.Registry
// mutex guards businessMetrics.
mutex sync.RWMutex
// Predefined metrics, created and registered by NewPrometheusMetrics.
httpRequests *prometheus.CounterVec
httpDuration *prometheus.HistogramVec
activeUsers prometheus.Gauge
dbConnections prometheus.Gauge
cacheHits *prometheus.CounterVec
// businessMetrics maps a metric name to the collector registered through
// RegisterCounter/RegisterGauge/RegisterHistogram.
businessMetrics map[string]prometheus.Collector
}
// NewPrometheusMetrics creates a collector with a fresh Prometheus registry
// and registers the predefined HTTP, active-user, database-connection and
// cache metrics on it. Registration uses MustRegister and therefore panics if
// any of them cannot be registered.
func NewPrometheusMetrics(logger *zap.Logger) *PrometheusMetrics {
	m := &PrometheusMetrics{
		logger:   logger,
		registry: prometheus.NewRegistry(),

		// HTTP request counter.
		httpRequests: prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "http_requests_total",
				Help: "Total number of HTTP requests",
			},
			[]string{"method", "path", "status"},
		),
		// HTTP request duration histogram.
		httpDuration: prometheus.NewHistogramVec(
			prometheus.HistogramOpts{
				Name:    "http_request_duration_seconds",
				Help:    "HTTP request duration in seconds",
				Buckets: prometheus.DefBuckets,
			},
			[]string{"method", "path"},
		),
		// Number of active users.
		activeUsers: prometheus.NewGauge(prometheus.GaugeOpts{
			Name: "active_users_total",
			Help: "Current number of active users",
		}),
		// Number of active database connections.
		dbConnections: prometheus.NewGauge(prometheus.GaugeOpts{
			Name: "database_connections_active",
			Help: "Current number of active database connections",
		}),
		// Cache operations by operation and result.
		cacheHits: prometheus.NewCounterVec(
			prometheus.CounterOpts{
				Name: "cache_operations_total",
				Help: "Total number of cache operations",
			},
			[]string{"operation", "result"},
		),

		businessMetrics: make(map[string]prometheus.Collector),
	}

	// Registered in the same order as the fields above.
	m.registry.MustRegister(
		m.httpRequests,
		m.httpDuration,
		m.activeUsers,
		m.dbConnections,
		m.cacheHits,
	)
	return m
}
// RecordHTTPRequest counts one HTTP request for (method, path, status code)
// and observes its duration in the request-duration histogram.
func (m *PrometheusMetrics) RecordHTTPRequest(method, path string, statusCode int, duration float64) {
	statusLabel := strconv.Itoa(statusCode)

	m.httpRequests.WithLabelValues(method, path, statusLabel).Inc()
	m.httpDuration.WithLabelValues(method, path).Observe(duration)

	m.logger.Debug("Recorded HTTP request metric",
		zap.String("method", method),
		zap.String("path", path),
		zap.String("status", statusLabel),
		zap.Float64("duration", duration),
	)
}
// RecordHTTPDuration observes duration in the request-duration histogram for
// (method, path) without touching the request counter.
func (m *PrometheusMetrics) RecordHTTPDuration(method, path string, duration float64) {
	observer := m.httpDuration.WithLabelValues(method, path)
	observer.Observe(duration)

	m.logger.Debug("Recorded HTTP duration metric",
		zap.String("method", method),
		zap.String("path", path),
		zap.Float64("duration", duration),
	)
}
// IncrementCounter adds one to the counter registered under name, creating
// the counter on first use.
//
// Both labelled counters (CounterVec) and counters registered without labels
// are supported. If labels do not match the label names the counter was
// registered with, the mismatch is logged and the increment is dropped
// instead of panicking.
func (m *PrometheusMetrics) IncrementCounter(name string, labels map[string]string) {
	collector, ok := m.getOrCreateCounter(name, labels)
	if !ok {
		return
	}

	switch c := collector.(type) {
	case *prometheus.CounterVec:
		// GetMetricWith returns an error where With would panic.
		counter, err := c.GetMetricWith(labels)
		if err != nil {
			m.logger.Error("Failed to resolve counter labels", zap.String("name", name), zap.Error(err))
			return
		}
		counter.Inc()
	case prometheus.Counter:
		// Counter registered without labels.
		c.Inc()
	}
}
// RecordGauge sets the gauge registered under name to value, creating the
// gauge on first use.
//
// Both labelled gauges (GaugeVec) and gauges registered without labels are
// supported. If labels do not match the label names the gauge was registered
// with, the mismatch is logged and the update is dropped instead of
// panicking.
func (m *PrometheusMetrics) RecordGauge(name string, value float64, labels map[string]string) {
	collector, ok := m.getOrCreateGauge(name, labels)
	if !ok {
		return
	}

	switch g := collector.(type) {
	case *prometheus.GaugeVec:
		// GetMetricWith returns an error where With would panic.
		gauge, err := g.GetMetricWith(labels)
		if err != nil {
			m.logger.Error("Failed to resolve gauge labels", zap.String("name", name), zap.Error(err))
			return
		}
		gauge.Set(value)
	case prometheus.Gauge:
		// Gauge registered without labels.
		g.Set(value)
	}
}
// RecordHistogram observes value in the histogram registered under name,
// creating the histogram on first use.
//
// Both labelled histograms (HistogramVec) and histograms registered without
// labels are supported. If labels do not match the label names the histogram
// was registered with, the mismatch is logged and the observation is dropped
// instead of panicking.
func (m *PrometheusMetrics) RecordHistogram(name string, value float64, labels map[string]string) {
	collector, ok := m.getOrCreateHistogram(name, labels)
	if !ok {
		return
	}

	switch h := collector.(type) {
	case *prometheus.HistogramVec:
		// GetMetricWith returns an error where With would panic.
		observer, err := h.GetMetricWith(labels)
		if err != nil {
			m.logger.Error("Failed to resolve histogram labels", zap.String("name", name), zap.Error(err))
			return
		}
		observer.Observe(value)
	case prometheus.Histogram:
		// Histogram registered without labels.
		h.Observe(value)
	}
}
// RegisterCounter registers a counter under name so it can later be updated
// by name. With labels it is a CounterVec, otherwise a plain Counter.
//
// Registering a name that is already known is a no-op. If the registry
// already holds an identical collector that is not yet tracked by name (as
// with cache_operations_total, which NewPrometheusMetrics registers directly
// on the registry), that existing collector is adopted instead of reporting
// an error. Any other registration failure is returned.
func (m *PrometheusMetrics) RegisterCounter(name, help string, labels []string) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if _, exists := m.businessMetrics[name]; exists {
		return nil // already registered
	}

	var counter prometheus.Collector
	if len(labels) > 0 {
		counter = prometheus.NewCounterVec(
			prometheus.CounterOpts{Name: name, Help: help},
			labels,
		)
	} else {
		counter = prometheus.NewCounter(
			prometheus.CounterOpts{Name: name, Help: help},
		)
	}

	if err := m.registry.Register(counter); err != nil {
		already, ok := err.(prometheus.AlreadyRegisteredError)
		if !ok {
			return err
		}
		// Same name, help and label names are already registered: reuse it.
		counter = already.ExistingCollector
	}

	m.businessMetrics[name] = counter
	m.logger.Info("Registered counter metric", zap.String("name", name))
	return nil
}
// RegisterGauge registers a gauge under name so it can later be updated by
// name. With labels it is a GaugeVec, otherwise a plain Gauge.
//
// Registering a name that is already known is a no-op. If the registry
// already holds an identical collector that is not yet tracked by name, that
// existing collector is adopted instead of reporting an error. Any other
// registration failure is returned.
func (m *PrometheusMetrics) RegisterGauge(name, help string, labels []string) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if _, exists := m.businessMetrics[name]; exists {
		return nil // already registered
	}

	var gauge prometheus.Collector
	if len(labels) > 0 {
		gauge = prometheus.NewGaugeVec(
			prometheus.GaugeOpts{Name: name, Help: help},
			labels,
		)
	} else {
		gauge = prometheus.NewGauge(
			prometheus.GaugeOpts{Name: name, Help: help},
		)
	}

	if err := m.registry.Register(gauge); err != nil {
		already, ok := err.(prometheus.AlreadyRegisteredError)
		if !ok {
			return err
		}
		// Same name, help and label names are already registered: reuse it.
		gauge = already.ExistingCollector
	}

	m.businessMetrics[name] = gauge
	m.logger.Info("Registered gauge metric", zap.String("name", name))
	return nil
}
// RegisterHistogram registers a histogram under name so it can later be
// updated by name. With labels it is a HistogramVec, otherwise a plain
// Histogram. A nil buckets slice selects prometheus.DefBuckets.
//
// Registering a name that is already known is a no-op. If the registry
// already holds an identical collector that is not yet tracked by name, that
// existing collector is adopted instead of reporting an error. Any other
// registration failure is returned.
func (m *PrometheusMetrics) RegisterHistogram(name, help string, labels []string, buckets []float64) error {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if _, exists := m.businessMetrics[name]; exists {
		return nil // already registered
	}

	if buckets == nil {
		buckets = prometheus.DefBuckets
	}
	opts := prometheus.HistogramOpts{
		Name:    name,
		Help:    help,
		Buckets: buckets,
	}

	var histogram prometheus.Collector
	if len(labels) > 0 {
		histogram = prometheus.NewHistogramVec(opts, labels)
	} else {
		histogram = prometheus.NewHistogram(opts)
	}

	if err := m.registry.Register(histogram); err != nil {
		already, ok := err.(prometheus.AlreadyRegisteredError)
		if !ok {
			return err
		}
		// Same name, help and label names are already registered: reuse it.
		histogram = already.ExistingCollector
	}

	m.businessMetrics[name] = histogram
	m.logger.Info("Registered histogram metric", zap.String("name", name))
	return nil
}
// GetHandler returns an HTTP handler that serves the metrics of this
// collector's own registry, using default handler options.
func (m *PrometheusMetrics) GetHandler() http.Handler {
return promhttp.HandlerFor(m.registry, promhttp.HandlerOpts{})
}
// 内部辅助方法
// getOrCreateCounter returns the collector tracked under name. When none
// exists it registers a new counter whose label names are the keys of labels,
// then looks the name up again. The boolean is false when the collector could
// not be found or created; a creation failure is logged.
func (m *PrometheusMetrics) getOrCreateCounter(name string, labels map[string]string) (prometheus.Collector, bool) {
	lookup := func() (prometheus.Collector, bool) {
		m.mutex.RLock()
		defer m.mutex.RUnlock()
		found, ok := m.businessMetrics[name]
		return found, ok
	}

	if found, ok := lookup(); ok {
		return found, true
	}

	// Auto-create: derive the label names from the supplied label set.
	keys := make([]string, 0, len(labels))
	for key := range labels {
		keys = append(keys, key)
	}
	if err := m.RegisterCounter(name, "Auto-created counter", keys); err != nil {
		m.logger.Error("Failed to auto-create counter", zap.String("name", name), zap.Error(err))
		return nil, false
	}
	return lookup()
}
// getOrCreateGauge returns the collector tracked under name. When none exists
// it registers a new gauge whose label names are the keys of labels, then
// looks the name up again. The boolean is false when the collector could not
// be found or created; a creation failure is logged.
func (m *PrometheusMetrics) getOrCreateGauge(name string, labels map[string]string) (prometheus.Collector, bool) {
	lookup := func() (prometheus.Collector, bool) {
		m.mutex.RLock()
		defer m.mutex.RUnlock()
		found, ok := m.businessMetrics[name]
		return found, ok
	}

	if found, ok := lookup(); ok {
		return found, true
	}

	// Auto-create: derive the label names from the supplied label set.
	keys := make([]string, 0, len(labels))
	for key := range labels {
		keys = append(keys, key)
	}
	if err := m.RegisterGauge(name, "Auto-created gauge", keys); err != nil {
		m.logger.Error("Failed to auto-create gauge", zap.String("name", name), zap.Error(err))
		return nil, false
	}
	return lookup()
}
// getOrCreateHistogram returns the collector tracked under name. When none
// exists it registers a new histogram with default buckets whose label names
// are the keys of labels, then looks the name up again. The boolean is false
// when the collector could not be found or created; a creation failure is
// logged.
func (m *PrometheusMetrics) getOrCreateHistogram(name string, labels map[string]string) (prometheus.Collector, bool) {
	lookup := func() (prometheus.Collector, bool) {
		m.mutex.RLock()
		defer m.mutex.RUnlock()
		found, ok := m.businessMetrics[name]
		return found, ok
	}

	if found, ok := lookup(); ok {
		return found, true
	}

	// Auto-create: derive the label names from the supplied label set.
	keys := make([]string, 0, len(labels))
	for key := range labels {
		keys = append(keys, key)
	}
	if err := m.RegisterHistogram(name, "Auto-created histogram", keys, nil); err != nil {
		m.logger.Error("Failed to auto-create histogram", zap.String("name", name), zap.Error(err))
		return nil, false
	}
	return lookup()
}
// UpdateActiveUsers sets the predefined active-users gauge to count.
func (m *PrometheusMetrics) UpdateActiveUsers(count float64) {
m.activeUsers.Set(count)
}
// UpdateDBConnections sets the predefined active-database-connections gauge
// to count.
func (m *PrometheusMetrics) UpdateDBConnections(count float64) {
m.dbConnections.Set(count)
}
// RecordCacheOperation adds one to the predefined cache-operations counter
// for the given operation and result label values.
func (m *PrometheusMetrics) RecordCacheOperation(operation, result string) {
m.cacheHits.WithLabelValues(operation, result).Inc()
}
// GetStats returns collector statistics; currently only
// "registered_metrics", the number of metrics tracked by name.
func (m *PrometheusMetrics) GetStats() map[string]interface{} {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	stats := make(map[string]interface{}, 1)
	stats["registered_metrics"] = len(m.businessMetrics)
	return stats
}

View File

@@ -42,31 +42,31 @@ func (m *JWTAuthMiddleware) Handle() gin.HandlerFunc {
// 获取Authorization头部
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
m.respondUnauthorized(c, "Missing authorization header")
m.respondUnauthorized(c, "缺少认证头部")
return
}
// 检查Bearer前缀
const bearerPrefix = "Bearer "
if !strings.HasPrefix(authHeader, bearerPrefix) {
m.respondUnauthorized(c, "Invalid authorization header format")
m.respondUnauthorized(c, "认证头部格式无效")
return
}
// 提取token
tokenString := authHeader[len(bearerPrefix):]
if tokenString == "" {
m.respondUnauthorized(c, "Missing token")
m.respondUnauthorized(c, "缺少认证令牌")
return
}
// 验证token
claims, err := m.validateToken(tokenString)
if err != nil {
m.logger.Warn("Invalid token",
m.logger.Warn("无效的认证令牌",
zap.Error(err),
zap.String("request_id", c.GetString("request_id")))
m.respondUnauthorized(c, "Invalid token")
m.respondUnauthorized(c, "认证令牌无效")
return
}
@@ -119,7 +119,7 @@ func (m *JWTAuthMiddleware) validateToken(tokenString string) (*JWTClaims, error
func (m *JWTAuthMiddleware) respondUnauthorized(c *gin.Context, message string) {
c.JSON(http.StatusUnauthorized, gin.H{
"success": false,
"message": "Unauthorized",
"message": "认证失败",
"error": message,
"request_id": c.GetString("request_id"),
"timestamp": time.Now().Unix(),

View File

@@ -2,11 +2,11 @@ package middleware
import (
"fmt"
"net/http"
"sync"
"time"
"tyapi-server/internal/config"
"tyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"golang.org/x/time/rate"
@@ -15,14 +15,16 @@ import (
// RateLimitMiddleware is the rate-limiting middleware.
type RateLimitMiddleware struct {
// config supplies the rate-limit settings.
config *config.Config
// response writes the unified error response when a request is rejected.
response interfaces.ResponseBuilder
// limiters holds one rate limiter per string key.
// NOTE(review): the key is presumably a client identifier — confirm in Handle.
limiters map[string]*rate.Limiter
// NOTE(review): mutex presumably guards limiters — confirm in Handle.
mutex sync.RWMutex
}
// NewRateLimitMiddleware 创建限流中间件
func NewRateLimitMiddleware(cfg *config.Config) *RateLimitMiddleware {
func NewRateLimitMiddleware(cfg *config.Config, response interfaces.ResponseBuilder) *RateLimitMiddleware {
return &RateLimitMiddleware{
config: cfg,
response: response,
limiters: make(map[string]*rate.Limiter),
}
}
@@ -48,15 +50,13 @@ func (m *RateLimitMiddleware) Handle() gin.HandlerFunc {
// 检查是否允许请求
if !limiter.Allow() {
// 添加限流头部信息
c.Header("X-RateLimit-Limit", fmt.Sprintf("%d", m.config.RateLimit.Requests))
c.Header("X-RateLimit-Window", m.config.RateLimit.Window.String())
c.Header("Retry-After", "60")
c.JSON(http.StatusTooManyRequests, gin.H{
"success": false,
"message": "Rate limit exceeded",
"error": "Too many requests",
})
// 使用统一的响应格式
m.response.TooManyRequests(c, "请求过于频繁,请稍后再试")
c.Abort()
return
}

View File

@@ -2,23 +2,35 @@ package middleware
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"tyapi-server/internal/shared/tracing"
)
// RequestLoggerMiddleware 请求日志中间件
type RequestLoggerMiddleware struct {
logger *zap.Logger
logger *zap.Logger
useColoredLog bool
isDevelopment bool
tracer *tracing.Tracer
}
// NewRequestLoggerMiddleware 创建请求日志中间件
func NewRequestLoggerMiddleware(logger *zap.Logger) *RequestLoggerMiddleware {
func NewRequestLoggerMiddleware(logger *zap.Logger, isDevelopment bool, tracer *tracing.Tracer) *RequestLoggerMiddleware {
return &RequestLoggerMiddleware{
logger: logger,
logger: logger,
useColoredLog: isDevelopment, // 开发环境使用彩色日志
isDevelopment: isDevelopment,
tracer: tracer,
}
}
@@ -34,24 +46,110 @@ func (m *RequestLoggerMiddleware) GetPriority() int {
// Handle 返回中间件处理函数
func (m *RequestLoggerMiddleware) Handle() gin.HandlerFunc {
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// 使用zap logger记录请求信息
m.logger.Info("HTTP Request",
zap.String("client_ip", param.ClientIP),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.String("protocol", param.Request.Proto),
zap.Int("status_code", param.StatusCode),
zap.Duration("latency", param.Latency),
zap.String("user_agent", param.Request.UserAgent()),
zap.Int("body_size", param.BodySize),
zap.String("referer", param.Request.Referer()),
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
)
if m.useColoredLog {
// 开发环境使用Gin默认的彩色日志格式
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
var statusColor, methodColor, resetColor string
if param.IsOutputColor() {
statusColor = param.StatusCodeColor()
methodColor = param.MethodColor()
resetColor = param.ResetColor()
}
// 返回空字符串因为我们已经用zap记录了
return ""
})
if param.Latency > time.Minute {
param.Latency = param.Latency.Truncate(time.Second)
}
// 获取TraceID
traceID := param.Request.Header.Get("X-Trace-ID")
if traceID == "" && m.tracer != nil {
traceID = m.tracer.GetTraceID(param.Request.Context())
}
// 检查是否为错误响应
if param.StatusCode >= 400 && m.tracer != nil {
span := trace.SpanFromContext(param.Request.Context())
if span.IsRecording() {
// 标记为错误操作确保100%采样
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
}
}
traceInfo := ""
if traceID != "" {
traceInfo = fmt.Sprintf(" | TraceID: %s", traceID)
}
return fmt.Sprintf("[GIN] %v |%s %3d %s| %13v | %15s |%s %-7s %s %#v%s\n%s",
param.TimeStamp.Format("2006/01/02 - 15:04:05"),
statusColor, param.StatusCode, resetColor,
param.Latency,
param.ClientIP,
methodColor, param.Method, resetColor,
param.Path,
traceInfo,
param.ErrorMessage,
)
})
} else {
// 生产环境使用结构化JSON日志
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
// 获取TraceID
traceID := param.Request.Header.Get("X-Trace-ID")
if traceID == "" && m.tracer != nil {
traceID = m.tracer.GetTraceID(param.Request.Context())
}
// 检查是否为错误响应
if param.StatusCode >= 400 && m.tracer != nil {
span := trace.SpanFromContext(param.Request.Context())
if span.IsRecording() {
// 标记为错误操作确保100%采样
span.SetAttributes(
attribute.String("error.operation", "true"),
attribute.String("operation.type", "error"),
)
// 对于服务器错误,记录更详细的日志
if param.StatusCode >= 500 {
m.logger.Error("服务器错误",
zap.Int("status_code", param.StatusCode),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.Duration("latency", param.Latency),
zap.String("client_ip", param.ClientIP),
zap.String("trace_id", traceID),
)
}
}
}
// 记录请求日志
logFields := []zap.Field{
zap.String("client_ip", param.ClientIP),
zap.String("method", param.Method),
zap.String("path", param.Path),
zap.String("protocol", param.Request.Proto),
zap.Int("status_code", param.StatusCode),
zap.Duration("latency", param.Latency),
zap.String("user_agent", param.Request.UserAgent()),
zap.Int("body_size", param.BodySize),
zap.String("referer", param.Request.Referer()),
zap.String("request_id", param.Request.Header.Get("X-Request-ID")),
}
// 添加TraceID
if traceID != "" {
logFields = append(logFields, zap.String("trace_id", traceID))
}
m.logger.Info("HTTP请求", logFields...)
return ""
})
}
}
// IsGlobal 是否为全局中间件
@@ -102,6 +200,70 @@ func (m *RequestIDMiddleware) IsGlobal() bool {
return true
}
// TraceIDMiddleware exposes the current trace ID on the response and in the
// gin context, and tags the active span when a request ends with an error
// status.
type TraceIDMiddleware struct {
	tracer *tracing.Tracer
}

// NewTraceIDMiddleware creates a trace ID middleware that uses tracer to look
// up trace IDs. tracer may be nil, in which case no trace ID is exposed.
func NewTraceIDMiddleware(tracer *tracing.Tracer) *TraceIDMiddleware {
	return &TraceIDMiddleware{
		tracer: tracer,
	}
}

// GetName returns the middleware name.
func (m *TraceIDMiddleware) GetName() string {
	return "trace_id"
}

// GetPriority returns the middleware priority.
func (m *TraceIDMiddleware) GetPriority() int {
	return 94 // second only to the request ID middleware
}

// Handle returns the gin handler.
//
// Before the rest of the chain runs, the trace ID (when one is available) is
// written to the X-Trace-ID response header and stored under the "trace_id"
// context key. After the chain has run, responses with status >= 400 mark the
// recording span with error attributes.
func (m *TraceIDMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		// Guard against a nil tracer instead of dereferencing it blindly.
		traceID := ""
		if m.tracer != nil {
			traceID = m.tracer.GetTraceID(c.Request.Context())
		}
		if traceID != "" {
			c.Header("X-Trace-ID", traceID)
			c.Set("trace_id", traceID)
		}

		c.Next()

		// Only error responses need to be tagged on the span.
		if c.Writer.Status() < 400 {
			return
		}
		span := trace.SpanFromContext(c.Request.Context())
		if !span.IsRecording() {
			return
		}

		// Mark the span as an error operation.
		span.SetAttributes(
			attribute.String("error.operation", "true"),
			attribute.String("operation.type", "error"),
		)

		// Flag the request context as an error request.
		// NOTE(review): this runs after c.Next(), so only middleware that is
		// still unwinding can observe the value; the plain string key is kept
		// because readers of "otel_error_request" are not visible here.
		c.Request = c.Request.WithContext(context.WithValue(
			c.Request.Context(),
			"otel_error_request",
			true,
		))
	}
}

// IsGlobal reports whether the middleware is applied globally.
func (m *TraceIDMiddleware) IsGlobal() bool {
	return true
}
// SecurityHeadersMiddleware is the security headers middleware. It carries no
// state of its own.
type SecurityHeadersMiddleware struct{}
@@ -183,13 +345,15 @@ func (m *ResponseTimeMiddleware) IsGlobal() bool {
type RequestBodyLoggerMiddleware struct {
logger *zap.Logger
enable bool
tracer *tracing.Tracer
}
// NewRequestBodyLoggerMiddleware 创建请求体日志中间件
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool) *RequestBodyLoggerMiddleware {
func NewRequestBodyLoggerMiddleware(logger *zap.Logger, enable bool, tracer *tracing.Tracer) *RequestBodyLoggerMiddleware {
return &RequestBodyLoggerMiddleware{
logger: logger,
enable: enable,
tracer: tracer,
}
}
@@ -220,13 +384,26 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
// 重新设置body供后续处理使用
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// 获取追踪ID
traceID := ""
if m.tracer != nil {
traceID = m.tracer.GetTraceID(c.Request.Context())
}
// 记录请求体(注意:生产环境中应该谨慎记录敏感信息)
m.logger.Debug("Request Body",
logFields := []zap.Field{
zap.String("method", c.Request.Method),
zap.String("path", c.Request.URL.Path),
zap.String("body", string(bodyBytes)),
zap.String("request_id", c.GetString("request_id")),
)
}
// 添加追踪ID
if traceID != "" {
logFields = append(logFields, zap.String("trace_id", traceID))
}
m.logger.Debug("请求体", logFields...)
}
}
}
@@ -239,3 +416,83 @@ func (m *RequestBodyLoggerMiddleware) Handle() gin.HandlerFunc {
func (m *RequestBodyLoggerMiddleware) IsGlobal() bool {
return false // 可选中间件,不是全局的
}
// ErrorTrackingMiddleware tags the active span and writes a log entry for
// requests that finish with handler errors or an error status code.
type ErrorTrackingMiddleware struct {
	logger *zap.Logger
	tracer *tracing.Tracer
}

// NewErrorTrackingMiddleware creates an error tracking middleware. tracer may
// be nil, in which case log entries carry no trace or span ID.
func NewErrorTrackingMiddleware(logger *zap.Logger, tracer *tracing.Tracer) *ErrorTrackingMiddleware {
	return &ErrorTrackingMiddleware{
		logger: logger,
		tracer: tracer,
	}
}

// GetName returns the middleware name.
func (m *ErrorTrackingMiddleware) GetName() string {
	return "error_tracking"
}

// GetPriority returns the middleware priority.
func (m *ErrorTrackingMiddleware) GetPriority() int {
	return 60 // low priority: runs after most other middleware
}

// Handle returns the gin handler.
//
// After the rest of the chain has run, a request is treated as failed when
// gin collected errors or the status code is >= 400. For failed requests with
// a recording span, the span is tagged and one log entry is written: Error
// level for 5xx, Warn level otherwise.
func (m *ErrorTrackingMiddleware) Handle() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Next()

		status := c.Writer.Status()
		if len(c.Errors) == 0 && status < 400 {
			return
		}

		// NOTE(review): logging is gated on the span being recorded, so
		// unsampled requests produce no error log — confirm this is intended.
		span := trace.SpanFromContext(c.Request.Context())
		if !span.IsRecording() {
			return
		}

		// Mark the span as an error operation.
		span.SetAttributes(
			attribute.String("error.operation", "true"),
			attribute.String("operation.type", "error"),
		)

		// FullPath is empty when no route matched (typical for 404s); fall
		// back to the raw URL path so the entry still identifies the request.
		path := c.FullPath()
		if path == "" {
			path = c.Request.URL.Path
		}

		logFields := []zap.Field{
			zap.Int("status_code", status),
			zap.String("method", c.Request.Method),
			zap.String("path", path),
			zap.String("client_ip", c.ClientIP()),
		}

		// Attach tracing identifiers when a tracer is configured.
		if m.tracer != nil {
			if traceID := m.tracer.GetTraceID(c.Request.Context()); traceID != "" {
				logFields = append(logFields, zap.String("trace_id", traceID))
			}
			if spanID := m.tracer.GetSpanID(c.Request.Context()); spanID != "" {
				logFields = append(logFields, zap.String("span_id", spanID))
			}
		}

		// Attach the errors collected by gin.
		if len(c.Errors) > 0 {
			logFields = append(logFields, zap.String("errors", c.Errors.String()))
		}

		// The status code decides the log level.
		if status >= 500 {
			m.logger.Error("服务器错误", logFields...)
		} else {
			m.logger.Warn("客户端错误", logFields...)
		}
	}
}

// IsGlobal reports whether the middleware is applied globally.
func (m *ErrorTrackingMiddleware) IsGlobal() bool {
	return true
}

View File

@@ -0,0 +1,389 @@
package resilience
import (
"context"
"errors"
"sync"
"time"
"go.uber.org/zap"
)
// CircuitState is the state of a circuit breaker.
type CircuitState int

const (
	// StateClosed is the closed (normal) state.
	StateClosed CircuitState = iota
	// StateOpen is the open (tripped) state.
	StateOpen
	// StateHalfOpen is the half-open (probing) state.
	StateHalfOpen
)

// String returns the upper-case name of the state, or "UNKNOWN" for values
// outside the declared constants.
func (s CircuitState) String() string {
	names := [...]string{
		StateClosed:   "CLOSED",
		StateOpen:     "OPEN",
		StateHalfOpen: "HALF_OPEN",
	}
	if s < 0 || int(s) >= len(names) {
		return "UNKNOWN"
	}
	return names[s]
}
// CircuitBreakerConfig holds the tunable parameters of a circuit breaker.
type CircuitBreakerConfig struct {
	// FailureThreshold is the failure threshold.
	FailureThreshold int
	// ResetTimeout is the reset timeout.
	ResetTimeout time.Duration
	// WindowSize is the size of the detection window.
	WindowSize int
	// HalfOpenMaxRequests is the number of requests allowed while half-open.
	HalfOpenMaxRequests int
	// SuccessThreshold is the success threshold (half-open -> closed).
	SuccessThreshold int
}

// DefaultCircuitBreakerConfig returns the default circuit breaker settings:
// 5 failures, a 60 second reset timeout, a window of 10 results, 3 half-open
// requests and 2 successes to close.
func DefaultCircuitBreakerConfig() CircuitBreakerConfig {
	var cfg CircuitBreakerConfig
	cfg.FailureThreshold = 5
	cfg.ResetTimeout = time.Minute
	cfg.WindowSize = 10
	cfg.HalfOpenMaxRequests = 3
	cfg.SuccessThreshold = 2
	return cfg
}
// CircuitBreaker is a circuit breaker. All mutable fields below are accessed
// while holding mutex.
type CircuitBreaker struct {
	config CircuitBreakerConfig
	logger *zap.Logger
	mutex  sync.RWMutex
	// Current state.
	state CircuitState
	// Counters.
	failures            int
	successes           int
	requests            int
	consecutiveFailures int
	// Timestamps.
	lastFailTime    time.Time
	lastStateChange time.Time
	// Statistics window, used as a ring buffer of recent results.
	window      []bool // true=success, false=failure
	windowIndex int
	windowFull  bool
	// State-change callback.
	onStateChange func(from, to CircuitState)
}
// NewCircuitBreaker creates a circuit breaker in the closed state.
//
// A non-positive config.WindowSize is replaced by the default window size: a
// zero-length window would make the ring-buffer update panic, and a negative
// size would make the allocation itself panic. A nil logger is replaced by a
// no-op logger so that logging calls never dereference nil.
func NewCircuitBreaker(config CircuitBreakerConfig, logger *zap.Logger) *CircuitBreaker {
	if config.WindowSize <= 0 {
		config.WindowSize = DefaultCircuitBreakerConfig().WindowSize
	}
	if logger == nil {
		logger = zap.NewNop()
	}

	return &CircuitBreaker{
		config:          config,
		logger:          logger,
		state:           StateClosed,
		window:          make([]bool, config.WindowSize),
		lastStateChange: time.Now(),
	}
}
// Execute runs fn through the breaker. When the breaker rejects the request,
// fn is not called and ErrCircuitBreakerOpen is returned. Otherwise fn's
// outcome and running time are recorded and fn's error is returned as is.
// ctx is currently not consulted.
func (cb *CircuitBreaker) Execute(ctx context.Context, fn func() error) error {
	if !cb.allowRequest() {
		return ErrCircuitBreakerOpen
	}

	startedAt := time.Now()
	result := fn()
	elapsed := time.Since(startedAt)

	cb.recordResult(result == nil, elapsed)
	return result
}
// allowRequest reports whether a request may proceed in the current state.
//
// Closed always allows. Open rejects until ResetTimeout has elapsed since the
// last state change, then switches to half-open and allows. Half-open allows
// while the recorded request count is below HalfOpenMaxRequests.
func (cb *CircuitBreaker) allowRequest() bool {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	now := time.Now()

	if cb.state == StateClosed {
		return true
	}
	if cb.state == StateHalfOpen {
		// requests is incremented when a result is recorded, so this bounds
		// completed rather than in-flight requests.
		return cb.requests < cb.config.HalfOpenMaxRequests
	}
	if cb.state != StateOpen {
		return false
	}

	// Open: stay closed to traffic until the reset timeout has passed.
	if now.Sub(cb.lastStateChange) <= cb.config.ResetTimeout {
		return false
	}
	cb.setState(StateHalfOpen)
	return true
}
// recordResult records the outcome of one executed request: it bumps the
// counters, feeds the sliding window and lets the state machine react.
func (cb *CircuitBreaker) recordResult(success bool, duration time.Duration) {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	cb.requests++
	cb.updateWindow(success)

	switch {
	case success:
		cb.successes++
		cb.consecutiveFailures = 0
		cb.onSuccess()
	default:
		cb.failures++
		cb.consecutiveFailures++
		cb.lastFailTime = time.Now()
		cb.onFailure()
	}

	fields := []zap.Field{
		zap.Bool("success", success),
		zap.Duration("duration", duration),
		zap.String("state", cb.state.String()),
		zap.Int("failures", cb.failures),
		zap.Int("successes", cb.successes),
	}
	cb.logger.Debug("Circuit breaker recorded result", fields...)
}
// updateWindow stores one result in the ring buffer and advances the write
// position; the window is marked full once the position wraps to the start.
func (cb *CircuitBreaker) updateWindow(success bool) {
	slot := cb.windowIndex
	cb.window[slot] = success

	if next := slot + 1; next < cb.config.WindowSize {
		cb.windowIndex = next
		return
	}

	// Wrapped around: every slot has been written at least once.
	cb.windowIndex = 0
	cb.windowFull = true
}
// onSuccess reacts to a successful result: in the half-open state the breaker
// closes once the success count reaches SuccessThreshold.
func (cb *CircuitBreaker) onSuccess() {
	if cb.state != StateHalfOpen {
		return
	}
	if cb.successes < cb.config.SuccessThreshold {
		return
	}
	cb.setState(StateClosed)
}
// onFailure reacts to a failed result: a closed breaker opens when the trip
// condition holds, and a half-open breaker re-opens immediately.
func (cb *CircuitBreaker) onFailure() {
	switch cb.state {
	case StateClosed:
		if cb.shouldTrip() {
			cb.setState(StateOpen)
		}
	case StateHalfOpen:
		cb.setState(StateOpen)
	}
}
// shouldTrip reports whether the breaker should open. It trips when the
// consecutive failure count reaches FailureThreshold, or when the window is
// full and at least 50% of its entries are failures.
func (cb *CircuitBreaker) shouldTrip() bool {
	if cb.consecutiveFailures >= cb.config.FailureThreshold {
		return true
	}
	if !cb.windowFull {
		return false
	}

	failed := 0
	for i := range cb.window {
		if !cb.window[i] {
			failed++
		}
	}
	return float64(failed)/float64(cb.config.WindowSize) >= 0.5
}
// setState switches the breaker to newState; it is a no-op when the state is
// unchanged.
//
// Entering the closed state clears all counters and the sliding window.
// Without clearing the window, failures recorded before the trip would still
// dominate it after recovery and a single new failure would re-open the
// breaker. Entering the half-open state clears the request and success
// counters used for probing.
//
// The state-change callback, when set, runs synchronously. The visible callers
// of setState hold cb.mutex, so a callback must not call back into the
// breaker.
func (cb *CircuitBreaker) setState(newState CircuitState) {
	if cb.state == newState {
		return
	}

	oldState := cb.state
	cb.state = newState
	cb.lastStateChange = time.Now()

	// Reset counters for the new state.
	switch newState {
	case StateClosed:
		cb.requests = 0
		cb.failures = 0
		cb.successes = 0
		cb.consecutiveFailures = 0
		// Start the failure-rate window from scratch after recovery.
		cb.window = make([]bool, len(cb.window))
		cb.windowIndex = 0
		cb.windowFull = false
	case StateHalfOpen:
		cb.requests = 0
		cb.successes = 0
	}

	cb.logger.Info("Circuit breaker state changed",
		zap.String("from", oldState.String()),
		zap.String("to", newState.String()),
		zap.Int("failures", cb.failures),
		zap.Int("successes", cb.successes))

	// Notify the state-change callback.
	if cb.onStateChange != nil {
		cb.onStateChange(oldState, newState)
	}
}
// GetState returns the current state of the breaker.
func (cb *CircuitBreaker) GetState() CircuitState {
	cb.mutex.RLock()
	current := cb.state
	cb.mutex.RUnlock()
	return current
}
// GetStats returns a snapshot of the breaker's counters, timestamps and the
// two configuration values included in CircuitBreakerStats.
func (cb *CircuitBreaker) GetStats() CircuitBreakerStats {
	cb.mutex.RLock()
	defer cb.mutex.RUnlock()

	var snapshot CircuitBreakerStats
	snapshot.State = cb.state.String()
	snapshot.Failures = cb.failures
	snapshot.Successes = cb.successes
	snapshot.Requests = cb.requests
	snapshot.ConsecutiveFailures = cb.consecutiveFailures
	snapshot.LastFailTime = cb.lastFailTime
	snapshot.LastStateChange = cb.lastStateChange
	snapshot.FailureThreshold = cb.config.FailureThreshold
	snapshot.ResetTimeout = cb.config.ResetTimeout
	return snapshot
}
// Reset forces the breaker back to the closed state and discards the sliding
// window.
func (cb *CircuitBreaker) Reset() {
	cb.mutex.Lock()
	defer cb.mutex.Unlock()

	cb.setState(StateClosed)

	// Replace the window with a fresh, empty ring buffer.
	fresh := make([]bool, cb.config.WindowSize)
	cb.window, cb.windowIndex, cb.windowFull = fresh, 0, false

	cb.logger.Info("Circuit breaker reset")
}
// SetStateChangeCallback installs the function invoked on every state
// transition, replacing any previous callback.
func (cb *CircuitBreaker) SetStateChangeCallback(callback func(from, to CircuitState)) {
	cb.mutex.Lock()
	cb.onStateChange = callback
	cb.mutex.Unlock()
}
// CircuitBreakerStats is a JSON-serialisable snapshot of a circuit breaker:
// its state name, counters, timestamps and the configured failure threshold
// and reset timeout.
type CircuitBreakerStats struct {
	State               string        `json:"state"`
	Failures            int           `json:"failures"`
	Successes           int           `json:"successes"`
	Requests            int           `json:"requests"`
	ConsecutiveFailures int           `json:"consecutive_failures"`
	LastFailTime        time.Time     `json:"last_fail_time"`
	LastStateChange     time.Time     `json:"last_state_change"`
	FailureThreshold    int           `json:"failure_threshold"`
	ResetTimeout        time.Duration `json:"reset_timeout"`
}
// Predefined errors.
var (
	// ErrCircuitBreakerOpen is returned by Execute when the breaker rejects a
	// request without running it.
	ErrCircuitBreakerOpen = errors.New("circuit breaker is open")
)
// Wrapper is a registry of named circuit breakers.
type Wrapper struct {
	breakers map[string]*CircuitBreaker
	logger   *zap.Logger
	mutex    sync.RWMutex
}

// NewWrapper creates an empty circuit breaker registry.
func NewWrapper(logger *zap.Logger) *Wrapper {
	w := new(Wrapper)
	w.breakers = map[string]*CircuitBreaker{}
	w.logger = logger
	return w
}

// GetOrCreate returns the breaker registered under name, creating it with
// config when it does not exist yet. config is ignored for existing breakers.
func (w *Wrapper) GetOrCreate(name string, config CircuitBreakerConfig) *CircuitBreaker {
	w.mutex.Lock()
	defer w.mutex.Unlock()

	breaker, found := w.breakers[name]
	if !found {
		breaker = NewCircuitBreaker(config, w.logger.Named(name))
		w.breakers[name] = breaker
		w.logger.Info("Created circuit breaker", zap.String("name", name))
	}
	return breaker
}

// Execute runs fn through the breaker registered under name, creating that
// breaker with the default configuration on first use.
func (w *Wrapper) Execute(ctx context.Context, name string, fn func() error) error {
	return w.GetOrCreate(name, DefaultCircuitBreakerConfig()).Execute(ctx, fn)
}

// GetStats returns a statistics snapshot for every registered breaker, keyed
// by breaker name.
func (w *Wrapper) GetStats() map[string]CircuitBreakerStats {
	w.mutex.RLock()
	defer w.mutex.RUnlock()

	all := make(map[string]CircuitBreakerStats, len(w.breakers))
	for key, breaker := range w.breakers {
		all[key] = breaker.GetStats()
	}
	return all
}

// ResetAll resets every registered breaker.
func (w *Wrapper) ResetAll() {
	w.mutex.RLock()
	defer w.mutex.RUnlock()

	for key, breaker := range w.breakers {
		breaker.Reset()
		w.logger.Info("Reset circuit breaker", zap.String("name", key))
	}
}

View File

@@ -0,0 +1,467 @@
package resilience
import (
	"context"
	"errors"
	"fmt"
	"math/rand"
	"sync"
	"time"

	"go.uber.org/zap"
)
// RetryConfig is the retry configuration.
type RetryConfig struct {
	// MaxAttempts is the maximum number of attempts.
	MaxAttempts int
	// InitialDelay is the initial delay.
	InitialDelay time.Duration
	// MaxDelay is the maximum delay.
	MaxDelay time.Duration
	// BackoffMultiplier is the backoff multiplier.
	BackoffMultiplier float64
	// JitterFactor is the jitter factor.
	JitterFactor float64
	// RetryCondition decides whether an error should be retried.
	RetryCondition func(error) bool
	// DelayFunc computes the delay before the next attempt.
	DelayFunc func(attempt int, config RetryConfig) time.Duration
}
// DefaultRetryConfig returns the default retry settings: 3 attempts, delays
// from 100ms up to 5s with a multiplier of 2 and a jitter factor of 0.1, using
// the default retry condition and jittered exponential backoff.
func DefaultRetryConfig() RetryConfig {
	var cfg RetryConfig
	cfg.MaxAttempts = 3
	cfg.InitialDelay = 100 * time.Millisecond
	cfg.MaxDelay = 5 * time.Second
	cfg.BackoffMultiplier = 2.0
	cfg.JitterFactor = 0.1
	cfg.RetryCondition = DefaultRetryCondition
	cfg.DelayFunc = ExponentialBackoffWithJitter
	return cfg
}
// RetryableError is implemented by errors that can tell whether retrying the
// failed operation makes sense.
type RetryableError interface {
	error
	IsRetryable() bool
}

// DefaultRetryCondition is the default retry condition.
//
// A nil error is never retried. When err, or any error it wraps, implements
// RetryableError, that error's IsRetryable result decides; errors.As is used
// so that a RetryableError wrapped with %w is still honoured. Every other
// error is retried.
func DefaultRetryCondition(err error) bool {
	if err == nil {
		return false
	}

	var retryable RetryableError
	if errors.As(err, &retryable) {
		return retryable.IsRetryable()
	}

	// By default every error is retried.
	return true
}
// IsRetryableHTTPError reports whether an HTTP status code is worth retrying:
// 429 (Too Many Requests) and every code from 500 upwards.
func IsRetryableHTTPError(statusCode int) bool {
	if statusCode == 429 {
		return true
	}
	return statusCode >= 500
}
// DelayFunction is the signature of a delay strategy: it receives the attempt
// index and the retry configuration and returns the delay to wait.
type DelayFunction func(attempt int, config RetryConfig) time.Duration

// FixedDelay is a delay strategy that always returns config.InitialDelay,
// regardless of the attempt index.
func FixedDelay(attempt int, config RetryConfig) time.Duration {
	return config.InitialDelay
}
// LinearBackoff is a delay strategy that grows linearly with the attempt
// index and is capped at config.MaxDelay.
//
// Attempt indices passed by the retry loop in this file start at 0, so the
// delay is (attempt+1) * InitialDelay. Multiplying by the bare index would
// make the first retry fire with no delay at all.
func LinearBackoff(attempt int, config RetryConfig) time.Duration {
	delay := time.Duration(attempt+1) * config.InitialDelay
	if delay > config.MaxDelay {
		delay = config.MaxDelay
	}
	return delay
}
// ExponentialBackoff is a delay strategy that starts at config.InitialDelay,
// multiplies the delay by config.BackoffMultiplier once per attempt index and
// caps the result at config.MaxDelay.
//
// The loop stops as soon as a growing delay reaches the cap. Continuing to
// multiply would eventually exceed the int64 range of time.Duration, where
// the float conversion is implementation-dependent and can yield a negative
// delay that slips past the cap check.
func ExponentialBackoff(attempt int, config RetryConfig) time.Duration {
	delay := config.InitialDelay
	for i := 0; i < attempt; i++ {
		next := float64(delay) * config.BackoffMultiplier
		// With a multiplier >= 1 and a non-negative value the delay can only
		// grow from here, so the final result is the cap.
		if config.BackoffMultiplier >= 1 && next >= 0 && next >= float64(config.MaxDelay) {
			return config.MaxDelay
		}
		delay = time.Duration(next)
	}
	if delay > config.MaxDelay {
		delay = config.MaxDelay
	}
	return delay
}
// ExponentialBackoffWithJitter is ExponentialBackoff with random jitter.
//
// When config.JitterFactor is positive, a uniformly distributed offset of up
// to +/- JitterFactor times the base delay is added. The jittered delay is
// then capped at config.MaxDelay, which a positive offset could otherwise
// exceed. A delay that ends up negative is replaced by config.InitialDelay.
func ExponentialBackoffWithJitter(attempt int, config RetryConfig) time.Duration {
	delay := ExponentialBackoff(attempt, config)

	// Add jitter.
	if factor := config.JitterFactor; factor > 0 {
		spread := float64(delay) * factor
		offset := (rand.Float64() - 0.5) * 2 * spread
		delay = time.Duration(float64(delay) + offset)
	}

	// Jitter must not push the delay beyond the configured maximum.
	if delay > config.MaxDelay {
		delay = config.MaxDelay
	}
	if delay < 0 {
		delay = config.InitialDelay
	}
	return delay
}
// RetryStats holds retry statistics in a JSON-serialisable form.
type RetryStats struct {
	TotalAttempts   int           `json:"total_attempts"`
	Successes       int           `json:"successes"`
	Failures        int           `json:"failures"`
	TotalRetries    int           `json:"total_retries"`
	AverageAttempts float64       `json:"average_attempts"`
	TotalDelay      time.Duration `json:"total_delay"`
	LastError       string        `json:"last_error,omitempty"`
}
// Retryer runs operations with retries according to a RetryConfig and
// accumulates statistics across all calls.
type Retryer struct {
	config RetryConfig
	logger *zap.Logger
	stats  RetryStats

	// mu guards stats and operations. Retryers handed out by a manager are
	// shared between callers, so the counters must not be updated unguarded.
	mu sync.Mutex
	// operations counts calls to ExecuteWithResult; it is the denominator of
	// RetryStats.AverageAttempts.
	operations int
}

// NewRetryer creates a retryer. A nil DelayFunc defaults to
// ExponentialBackoffWithJitter, a nil RetryCondition to DefaultRetryCondition
// and a nil logger to a no-op logger.
func NewRetryer(config RetryConfig, logger *zap.Logger) *Retryer {
	if config.DelayFunc == nil {
		config.DelayFunc = ExponentialBackoffWithJitter
	}
	if config.RetryCondition == nil {
		config.RetryCondition = DefaultRetryCondition
	}
	if logger == nil {
		logger = zap.NewNop()
	}

	return &Retryer{
		config: config,
		logger: logger,
	}
}

// Execute runs operation with retries and returns nil on the first success.
func (r *Retryer) Execute(ctx context.Context, operation func() error) error {
	return r.ExecuteWithResult(ctx, func() (interface{}, error) {
		return nil, operation()
	})
}

// ExecuteWithResult runs operation until it succeeds, the retry condition
// rejects its error, the attempt budget is used up or ctx is done. The value
// returned by operation is discarded; only the error outcome is reported.
//
// At least one attempt is made even when MaxAttempts is zero or negative. On
// cancellation ctx.Err() is returned; otherwise a failure is reported as an
// error that wraps the last operation error and states how many attempts were
// actually made.
func (r *Retryer) ExecuteWithResult(ctx context.Context, operation func() (interface{}, error)) error {
	maxAttempts := r.config.MaxAttempts
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	r.mu.Lock()
	r.operations++
	r.mu.Unlock()

	var lastErr error
	attempts := 0
	startTime := time.Now()

	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Stop before starting an attempt when the context is already done.
		if err := ctx.Err(); err != nil {
			return err
		}

		// Run the operation.
		attemptStart := time.Now()
		_, err := operation()
		attemptDuration := time.Since(attemptStart)

		attempts++
		r.recordAttempt(attempt, err)

		if err == nil {
			r.logger.Debug("Operation succeeded",
				zap.Int("attempt", attempt+1),
				zap.Duration("duration", attemptDuration))
			return nil
		}
		lastErr = err

		// Check whether the error may be retried at all.
		if !r.config.RetryCondition(err) {
			r.logger.Debug("Error is not retryable",
				zap.Error(err),
				zap.Int("attempt", attempt+1))
			break
		}

		// No delay is needed after the last attempt.
		if attempt == maxAttempts-1 {
			r.logger.Debug("Reached max attempts",
				zap.Error(err),
				zap.Int("max_attempts", maxAttempts))
			break
		}

		// Compute the delay before the next attempt.
		delay := r.config.DelayFunc(attempt, r.config)
		r.mu.Lock()
		r.stats.TotalDelay += delay
		r.mu.Unlock()

		r.logger.Debug("Operation failed, retrying",
			zap.Error(err),
			zap.Int("attempt", attempt+1),
			zap.Duration("delay", delay),
			zap.Duration("attempt_duration", attemptDuration))

		// Wait for the delay, releasing the timer early on cancellation.
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}

	if lastErr != nil {
		r.mu.Lock()
		r.stats.LastError = lastErr.Error()
		r.mu.Unlock()
	}

	// total_attempts reports the attempts made by this call.
	r.logger.Warn("Operation failed after all retries",
		zap.Error(lastErr),
		zap.Int("total_attempts", attempts),
		zap.Duration("total_duration", time.Since(startTime)))

	return fmt.Errorf("operation failed after %d attempts: %w", attempts, lastErr)
}

// recordAttempt updates the per-attempt counters. attempt is the zero-based
// attempt index; every attempt after the first counts as a retry, whether it
// succeeds or fails.
func (r *Retryer) recordAttempt(attempt int, err error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	r.stats.TotalAttempts++
	if attempt > 0 {
		r.stats.TotalRetries++
	}
	if err == nil {
		r.stats.Successes++
		return
	}
	r.stats.Failures++
}

// GetStats returns a snapshot of the accumulated statistics. AverageAttempts
// is the number of attempts per ExecuteWithResult call.
func (r *Retryer) GetStats() RetryStats {
	r.mu.Lock()
	defer r.mu.Unlock()

	snapshot := r.stats
	if r.operations > 0 {
		snapshot.AverageAttempts = float64(snapshot.TotalAttempts) / float64(r.operations)
	}
	return snapshot
}

// Reset clears the accumulated statistics.
func (r *Retryer) Reset() {
	r.mu.Lock()
	r.stats = RetryStats{}
	r.operations = 0
	r.mu.Unlock()

	r.logger.Debug("Retry stats reset")
}
// Retry runs operation with retries using a throwaway retryer configured with
// config and a no-op logger.
func Retry(ctx context.Context, config RetryConfig, operation func() error) error {
	return NewRetryer(config, zap.NewNop()).Execute(ctx, operation)
}
// RetryWithResult runs operation with retries and returns the value produced
// by the most recent attempt together with the retry outcome. The value is
// whatever the last executed attempt returned, even when that attempt failed.
func RetryWithResult[T any](ctx context.Context, config RetryConfig, operation func() (T, error)) (T, error) {
	var result T

	// Capture each attempt's value so it survives the untyped retry loop.
	capture := func() (interface{}, error) {
		value, opErr := operation()
		result = value
		return value, opErr
	}

	err := NewRetryer(config, zap.NewNop()).ExecuteWithResult(ctx, capture)
	return result, err
}
// Predefined retry configurations.

// QuickRetry returns a fast retry configuration (suited to lightweight
// operations): 3 attempts with delays from 50ms up to 500ms.
func QuickRetry() RetryConfig {
	var cfg RetryConfig
	cfg.MaxAttempts = 3
	cfg.InitialDelay = 50 * time.Millisecond
	cfg.MaxDelay = 500 * time.Millisecond
	cfg.BackoffMultiplier = 2.0
	cfg.JitterFactor = 0.1
	cfg.RetryCondition = DefaultRetryCondition
	cfg.DelayFunc = ExponentialBackoffWithJitter
	return cfg
}
// StandardRetry returns the standard retry configuration (suited to general
// operations); it is the same as DefaultRetryConfig.
func StandardRetry() RetryConfig {
	return DefaultRetryConfig()
}
// PatientRetry returns a patient retry configuration (suited to important
// operations): 5 attempts with delays from 200ms up to 10s and a jitter
// factor of 0.2.
func PatientRetry() RetryConfig {
	var cfg RetryConfig
	cfg.MaxAttempts = 5
	cfg.InitialDelay = 200 * time.Millisecond
	cfg.MaxDelay = 10 * time.Second
	cfg.BackoffMultiplier = 2.0
	cfg.JitterFactor = 0.2
	cfg.RetryCondition = DefaultRetryCondition
	cfg.DelayFunc = ExponentialBackoffWithJitter
	return cfg
}
// DatabaseRetry returns the retry configuration for database operations:
// 3 attempts with delays from 100ms up to 2s and a multiplier of 1.5.
func DatabaseRetry() RetryConfig {
	// Database-specific classification (connection timeouts, transient
	// network errors, ...) can be added here; for now the default retry
	// condition decides.
	shouldRetry := func(err error) bool {
		return DefaultRetryCondition(err)
	}

	var cfg RetryConfig
	cfg.MaxAttempts = 3
	cfg.InitialDelay = 100 * time.Millisecond
	cfg.MaxDelay = 2 * time.Second
	cfg.BackoffMultiplier = 1.5
	cfg.JitterFactor = 0.1
	cfg.RetryCondition = shouldRetry
	cfg.DelayFunc = ExponentialBackoffWithJitter
	return cfg
}
// HTTPRetry returns the retry configuration for HTTP operations: 3 attempts
// with delays from 200ms up to 5s and a jitter factor of 0.15.
func HTTPRetry() RetryConfig {
	// HTTP-specific retry conditions can be added here; for now the default
	// retry condition decides.
	shouldRetry := func(err error) bool {
		return DefaultRetryCondition(err)
	}

	var cfg RetryConfig
	cfg.MaxAttempts = 3
	cfg.InitialDelay = 200 * time.Millisecond
	cfg.MaxDelay = 5 * time.Second
	cfg.BackoffMultiplier = 2.0
	cfg.JitterFactor = 0.15
	cfg.RetryCondition = shouldRetry
	cfg.DelayFunc = ExponentialBackoffWithJitter
	return cfg
}
// RetryManager is a registry of named retryers.
type RetryManager struct {
	retryers map[string]*Retryer
	logger   *zap.Logger
	mutex    sync.RWMutex
}

// NewRetryManager creates an empty retryer registry.
func NewRetryManager(logger *zap.Logger) *RetryManager {
	rm := new(RetryManager)
	rm.retryers = map[string]*Retryer{}
	rm.logger = logger
	return rm
}

// GetOrCreate returns the retryer registered under name, creating it with
// config when it does not exist yet. config is ignored for existing retryers.
func (rm *RetryManager) GetOrCreate(name string, config RetryConfig) *Retryer {
	rm.mutex.Lock()
	defer rm.mutex.Unlock()

	existing, found := rm.retryers[name]
	if !found {
		existing = NewRetryer(config, rm.logger.Named(name))
		rm.retryers[name] = existing
		rm.logger.Info("Created retryer", zap.String("name", name))
	}
	return existing
}

// Execute runs operation with the retryer registered under name, creating
// that retryer with the default configuration on first use.
func (rm *RetryManager) Execute(ctx context.Context, name string, operation func() error) error {
	return rm.GetOrCreate(name, DefaultRetryConfig()).Execute(ctx, operation)
}

// GetStats returns the statistics of every registered retryer, keyed by
// retryer name.
func (rm *RetryManager) GetStats() map[string]RetryStats {
	rm.mutex.RLock()
	defer rm.mutex.RUnlock()

	all := make(map[string]RetryStats, len(rm.retryers))
	for key, entry := range rm.retryers {
		all[key] = entry.GetStats()
	}
	return all
}

// ResetAll resets the statistics of every registered retryer.
func (rm *RetryManager) ResetAll() {
	rm.mutex.RLock()
	defer rm.mutex.RUnlock()

	for key, entry := range rm.retryers {
		entry.Reset()
		rm.logger.Info("Reset retryer stats", zap.String("name", key))
	}
}
// RetryerWrapper offers one entry point per predefined retry configuration on
// top of a RetryManager.
type RetryerWrapper struct {
	manager *RetryManager
	logger  *zap.Logger
}

// NewRetryerWrapper creates a wrapper with its own retry manager.
func NewRetryerWrapper(logger *zap.Logger) *RetryerWrapper {
	rw := new(RetryerWrapper)
	rw.manager = NewRetryManager(logger)
	rw.logger = logger
	return rw
}

// run executes operation with the retryer registered under key, creating it
// with config on first use.
func (rw *RetryerWrapper) run(ctx context.Context, key string, config RetryConfig, operation func() error) error {
	return rw.manager.GetOrCreate(key, config).Execute(ctx, operation)
}

// ExecuteWithQuickRetry runs operation with the quick retry configuration,
// using the retryer named name+".quick".
func (rw *RetryerWrapper) ExecuteWithQuickRetry(ctx context.Context, name string, operation func() error) error {
	return rw.run(ctx, name+".quick", QuickRetry(), operation)
}

// ExecuteWithStandardRetry runs operation with the standard retry
// configuration, using the retryer named name+".standard".
func (rw *RetryerWrapper) ExecuteWithStandardRetry(ctx context.Context, name string, operation func() error) error {
	return rw.run(ctx, name+".standard", StandardRetry(), operation)
}

// ExecuteWithDatabaseRetry runs operation with the database retry
// configuration, using the retryer named name+".database".
func (rw *RetryerWrapper) ExecuteWithDatabaseRetry(ctx context.Context, name string, operation func() error) error {
	return rw.run(ctx, name+".database", DatabaseRetry(), operation)
}

// ExecuteWithHTTPRetry runs operation with the HTTP retry configuration,
// using the retryer named name+".http".
func (rw *RetryerWrapper) ExecuteWithHTTPRetry(ctx context.Context, name string, operation func() error) error {
	return rw.run(ctx, name+".http", HTTPRetry(), operation)
}

// ExecuteWithCustomRetry runs operation with a caller-supplied configuration,
// using the retryer named name+".custom". The configuration only takes effect
// the first time a given name is used.
func (rw *RetryerWrapper) ExecuteWithCustomRetry(ctx context.Context, name string, config RetryConfig, operation func() error) error {
	return rw.run(ctx, name+".custom", config, operation)
}

// GetManager returns the underlying retry manager.
func (rw *RetryerWrapper) GetManager() *RetryManager {
	return rw.manager
}

// GetAllStats returns the statistics of all retryers.
func (rw *RetryerWrapper) GetAllStats() map[string]RetryStats {
	return rw.manager.GetStats()
}

// ResetAllStats resets the statistics of all retryers.
func (rw *RetryerWrapper) ResetAllStats() {
	rw.manager.ResetAll()
}

View File

@@ -0,0 +1,612 @@
package saga
import (
"context"
"fmt"
"sync"
"time"
"go.uber.org/zap"
"tyapi-server/internal/shared/interfaces"
)
// SagaStatus is the status of a saga.
type SagaStatus int

const (
	// StatusPending means the saga is waiting to start.
	StatusPending SagaStatus = iota
	// StatusRunning means the saga is executing.
	StatusRunning
	// StatusCompleted means the saga has completed.
	StatusCompleted
	// StatusFailed means the saga has failed.
	StatusFailed
	// StatusCompensating means compensation is in progress.
	StatusCompensating
	// StatusCompensated means compensation has finished.
	StatusCompensated
	// StatusAborted means the saga was aborted.
	StatusAborted
)

// String returns the upper-case name of the status, or "UNKNOWN" for values
// outside the declared constants.
func (s SagaStatus) String() string {
	names := [...]string{
		StatusPending:      "PENDING",
		StatusRunning:      "RUNNING",
		StatusCompleted:    "COMPLETED",
		StatusFailed:       "FAILED",
		StatusCompensating: "COMPENSATING",
		StatusCompensated:  "COMPENSATED",
		StatusAborted:      "ABORTED",
	}
	if s < 0 || int(s) >= len(names) {
		return "UNKNOWN"
	}
	return names[s]
}
// StepStatus is the status of a single saga step.
type StepStatus int

const (
	// StepPending means the step is waiting to run.
	StepPending StepStatus = iota
	// StepRunning means the step is executing.
	StepRunning
	// StepCompleted means the step has completed.
	StepCompleted
	// StepFailed means the step has failed.
	StepFailed
	// StepCompensated means the step has been compensated.
	StepCompensated
	// StepSkipped means the step was skipped.
	StepSkipped
)

// String returns the upper-case name of the status, or "UNKNOWN" for values
// outside the declared constants.
func (s StepStatus) String() string {
	names := [...]string{
		StepPending:     "PENDING",
		StepRunning:     "RUNNING",
		StepCompleted:   "COMPLETED",
		StepFailed:      "FAILED",
		StepCompensated: "COMPENSATED",
		StepSkipped:     "SKIPPED",
	}
	if s < 0 || int(s) >= len(names) {
		return "UNKNOWN"
	}
	return names[s]
}
// SagaStep is one step of a saga: a forward action, an optional compensation
// and the bookkeeping recorded while the step runs.
type SagaStep struct {
	Name       string
	Action     func(ctx context.Context, data interface{}) error
	Compensate func(ctx context.Context, data interface{}) error
	Status     StepStatus
	Error      error
	StartTime  time.Time
	EndTime    time.Time
	RetryCount int
	MaxRetries int
	Timeout    time.Duration
}
// SagaConfig is the saga configuration.
type SagaConfig struct {
	// DefaultTimeout is the default step timeout.
	DefaultTimeout time.Duration
	// DefaultMaxRetries is the default number of step retries.
	DefaultMaxRetries int
	// Parallel selects parallel execution (only serial execution is currently supported).
	Parallel bool
	// EventBus is the event publisher.
	EventBus interfaces.EventBus
}
// DefaultSagaConfig returns the default saga settings: a 30 second step
// timeout, 3 retries per step, serial execution and no event bus.
func DefaultSagaConfig() SagaConfig {
	var cfg SagaConfig
	cfg.DefaultTimeout = 30 * time.Second
	cfg.DefaultMaxRetries = 3
	cfg.Parallel = false
	return cfg
}
// Saga is a distributed transaction made of ordered steps, each with an
// optional compensation.
type Saga struct {
	ID        string
	Name      string
	Steps     []*SagaStep
	Status    SagaStatus
	Data      interface{}
	StartTime time.Time
	EndTime   time.Time
	Error     error
	Config    SagaConfig
	logger    *zap.Logger
	mutex     sync.RWMutex
	// currentStep is the index of the step being executed; -1 before the
	// first step starts.
	currentStep int
	result      interface{}
}
// NewSaga creates a pending saga with no steps. currentStep starts at -1
// because no step has been started yet.
func NewSaga(id, name string, config SagaConfig, logger *zap.Logger) *Saga {
	saga := new(Saga)
	saga.ID = id
	saga.Name = name
	saga.Steps = make([]*SagaStep, 0)
	saga.Status = StatusPending
	saga.Config = config
	saga.logger = logger
	saga.currentStep = -1
	return saga
}
// AddStep appends a step that uses the saga's default retry count and timeout.
// It returns the saga so calls can be chained.
func (s *Saga) AddStep(name string, action, compensate func(ctx context.Context, data interface{}) error) *Saga {
	s.mutex.Lock()
	s.Steps = append(s.Steps, &SagaStep{
		Name:       name,
		Action:     action,
		Compensate: compensate,
		Status:     StepPending,
		MaxRetries: s.Config.DefaultMaxRetries,
		Timeout:    s.Config.DefaultTimeout,
	})
	s.mutex.Unlock()

	fields := []zap.Field{
		zap.String("saga_id", s.ID),
		zap.String("step_name", name),
	}
	s.logger.Debug("Added step to saga", fields...)

	return s
}
// AddStepWithConfig appends a step with an explicit retry count and timeout.
// It returns the saga so calls can be chained.
func (s *Saga) AddStepWithConfig(name string, action, compensate func(ctx context.Context, data interface{}) error, maxRetries int, timeout time.Duration) *Saga {
	s.mutex.Lock()
	s.Steps = append(s.Steps, &SagaStep{
		Name:       name,
		Action:     action,
		Compensate: compensate,
		Status:     StepPending,
		MaxRetries: maxRetries,
		Timeout:    timeout,
	})
	s.mutex.Unlock()

	fields := []zap.Field{
		zap.String("saga_id", s.ID),
		zap.String("step_name", name),
		zap.Int("max_retries", maxRetries),
		zap.Duration("timeout", timeout),
	}
	s.logger.Debug("Added step with config to saga", fields...)

	return s
}
// Execute runs the saga's steps in order with data as their shared input.
//
// The saga must be pending; otherwise an error is returned and nothing runs.
// When a step fails, the steps before it are compensated in reverse order.
// The saga then ends as StatusCompensated, or as StatusAborted when a
// compensation fails as well. On every terminal path the final status,
// EndTime and Error (nil on success, the returned error otherwise) are
// recorded under the saga's lock.
func (s *Saga) Execute(ctx context.Context, data interface{}) error {
	s.mutex.Lock()
	if s.Status != StatusPending {
		s.mutex.Unlock()
		return fmt.Errorf("saga %s is not in pending status", s.ID)
	}
	startedAt := time.Now()
	s.Status = StatusRunning
	s.Data = data
	s.StartTime = startedAt
	s.mutex.Unlock()

	// finish records the terminal state in one locked step and returns the
	// end time it stored.
	finish := func(status SagaStatus, err error) time.Time {
		s.mutex.Lock()
		defer s.mutex.Unlock()
		s.Status = status
		s.Error = err
		s.EndTime = time.Now()
		return s.EndTime
	}

	s.logger.Info("Starting saga execution",
		zap.String("saga_id", s.ID),
		zap.String("saga_name", s.Name),
		zap.Int("total_steps", len(s.Steps)))

	// Publish the saga-started event.
	s.publishEvent(ctx, "saga.started")

	// Run every step.
	for i, step := range s.Steps {
		s.mutex.Lock()
		s.currentStep = i
		s.mutex.Unlock()

		err := s.executeStep(ctx, step, data)
		if err == nil {
			continue
		}

		s.logger.Error("Step execution failed",
			zap.String("saga_id", s.ID),
			zap.String("step_name", step.Name),
			zap.Error(err))

		// Compensate the steps that ran before the failed one.
		if compensateErr := s.compensate(ctx, i-1); compensateErr != nil {
			s.logger.Error("Compensation failed",
				zap.String("saga_id", s.ID),
				zap.Error(compensateErr))

			failure := fmt.Errorf("saga execution failed and compensation failed: %w", compensateErr)
			finish(StatusAborted, failure)
			s.publishEvent(ctx, "saga.aborted")
			return failure
		}

		failure := fmt.Errorf("saga execution failed: %w", err)
		finish(StatusCompensated, failure)
		s.publishEvent(ctx, "saga.compensated")
		return failure
	}

	// Every step completed successfully.
	endedAt := finish(StatusCompleted, nil)

	s.logger.Info("Saga completed successfully",
		zap.String("saga_id", s.ID),
		zap.Duration("duration", endedAt.Sub(startedAt)))

	s.publishEvent(ctx, "saga.completed")

	return nil
}
// executeStep runs one step's action with retries under the step's timeout.
//
// The action is attempted up to MaxRetries+1 times (a negative MaxRetries is
// treated as 0 so the action always runs once), waiting (attempt+1)*100ms
// between attempts. Retrying stops as soon as the step context is done. On
// success the step is marked completed and nil is returned; otherwise the
// step is marked failed and the last action error is returned.
func (s *Saga) executeStep(ctx context.Context, step *SagaStep, data interface{}) error {
	step.Status = StepRunning
	step.StartTime = time.Now()

	s.logger.Debug("Executing step",
		zap.String("saga_id", s.ID),
		zap.String("step_name", step.Name))

	// One timeout covers all attempts of this step.
	stepCtx, cancel := context.WithTimeout(ctx, step.Timeout)
	defer cancel()

	maxRetries := step.MaxRetries
	if maxRetries < 0 {
		maxRetries = 0
	}

	// Retry loop.
	var lastErr error
retryLoop:
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			s.logger.Debug("Retrying step",
				zap.String("saga_id", s.ID),
				zap.String("step_name", step.Name),
				zap.Int("attempt", attempt))
		}

		err := step.Action(stepCtx, data)
		if err == nil {
			step.Status = StepCompleted
			step.EndTime = time.Now()

			s.logger.Debug("Step completed successfully",
				zap.String("saga_id", s.ID),
				zap.String("step_name", step.Name),
				zap.Duration("duration", step.EndTime.Sub(step.StartTime)))

			return nil
		}

		lastErr = err
		step.RetryCount = attempt

		// No wait is needed after the final attempt.
		if attempt == maxRetries {
			break
		}

		backoff := time.NewTimer(time.Duration(attempt+1) * 100 * time.Millisecond)
		select {
		case <-stepCtx.Done():
			// The context is done: stop retrying. A bare break would only
			// leave the select, so the loop is exited by label.
			backoff.Stop()
			break retryLoop
		case <-backoff.C:
			// Waited long enough; try again.
		}
	}

	// Every attempt failed.
	step.Status = StepFailed
	step.Error = lastErr
	step.EndTime = time.Now()

	return lastErr
}
// compensate undoes the steps from index fromStep down to 0, in reverse
// order. Steps that are not completed are marked skipped, completed steps
// without a compensation function are left as they are, and the first
// compensation error stops the process and is returned.
func (s *Saga) compensate(ctx context.Context, fromStep int) error {
	s.setStatus(StatusCompensating)

	s.logger.Info("Starting compensation",
		zap.String("saga_id", s.ID),
		zap.Int("from_step", fromStep))

	// Walk the steps backwards.
	for idx := fromStep; idx >= 0; idx-- {
		current := s.Steps[idx]

		switch {
		case current.Status != StepCompleted:
			// Only completed steps need to be undone.
			current.Status = StepSkipped
			continue
		case current.Compensate == nil:
			s.logger.Warn("No compensation function for step",
				zap.String("saga_id", s.ID),
				zap.String("step_name", current.Name))
			continue
		}

		s.logger.Debug("Compensating step",
			zap.String("saga_id", s.ID),
			zap.String("step_name", current.Name))

		// Each compensation gets the step's own timeout.
		compensateCtx, cancel := context.WithTimeout(ctx, current.Timeout)
		compErr := current.Compensate(compensateCtx, s.Data)
		cancel()

		if compErr != nil {
			s.logger.Error("Compensation failed for step",
				zap.String("saga_id", s.ID),
				zap.String("step_name", current.Name),
				zap.Error(compErr))
			return compErr
		}

		current.Status = StepCompensated
		s.logger.Debug("Step compensated successfully",
			zap.String("saga_id", s.ID),
			zap.String("step_name", current.Name))
	}

	s.logger.Info("Compensation completed",
		zap.String("saga_id", s.ID))

	return nil
}
// setStatus stores the saga status under the write lock.
func (s *Saga) setStatus(status SagaStatus) {
	s.mutex.Lock()
	s.Status = status
	s.mutex.Unlock()
}
// GetStatus returns the saga status, read under the read lock.
func (s *Saga) GetStatus() SagaStatus {
	s.mutex.RLock()
	current := s.Status
	s.mutex.RUnlock()
	return current
}
// GetProgress builds a snapshot of the saga's progress: how many steps have
// status StepCompleted, the resulting percentage (0 when the saga has no
// steps), and the time elapsed since StartTime.
func (s *Saga) GetProgress() SagaProgress {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	total := len(s.Steps)
	done := 0
	for i := range s.Steps {
		if s.Steps[i].Status == StepCompleted {
			done++
		}
	}

	// Guard the division so an empty saga reports 0%.
	percent := 0.0
	if total > 0 {
		percent = float64(done) / float64(total) * 100
	}

	return SagaProgress{
		SagaID:          s.ID,
		Status:          s.Status.String(),
		TotalSteps:      total,
		CompletedSteps:  done,
		CurrentStep:     s.currentStep + 1,
		PercentComplete: percent,
		StartTime:       s.StartTime,
		Duration:        time.Since(s.StartTime),
	}
}
// GetStepStatus returns one StepProgress entry per saga step, in step order.
//
// Duration is reported as:
//   - 0 for a step that has not started (zero StartTime),
//   - the time elapsed so far for a step that has started but not finished
//     (zero EndTime),
//   - EndTime - StartTime for a finished step.
//
// Error holds the step error's message, or "" when the step has no error.
func (s *Saga) GetStepStatus() []StepProgress {
	s.mutex.RLock()
	defer s.mutex.RUnlock()

	progress := make([]StepProgress, len(s.Steps))
	for i, step := range s.Steps {
		// Subtracting from a zero EndTime would yield a large negative
		// duration, so unfinished steps are handled explicitly.
		var elapsed time.Duration
		switch {
		case step.StartTime.IsZero():
			// Not started: leave the duration at zero.
		case step.EndTime.IsZero():
			elapsed = time.Since(step.StartTime)
		default:
			elapsed = step.EndTime.Sub(step.StartTime)
		}

		progress[i] = StepProgress{
			Name:       step.Name,
			Status:     step.Status.String(),
			RetryCount: step.RetryCount,
			StartTime:  step.StartTime,
			EndTime:    step.EndTime,
			Duration:   elapsed,
		}
		if step.Error != nil {
			progress[i].Error = step.Error.Error()
		}
	}
	return progress
}
// publishEvent prepares a SagaEvent describing the saga's current state.
// Nothing is built when no event bus is configured. The event is currently
// discarded: actual publishing is not implemented yet.
func (s *Saga) publishEvent(ctx context.Context, eventType string) {
	if s.Config.EventBus != nil {
		// TODO: implement the Event interface and hand this to the bus;
		// for now the event is constructed and dropped.
		_ = &SagaEvent{
			SagaID:    s.ID,
			SagaName:  s.Name,
			EventType: eventType,
			Status:    s.Status.String(),
			Timestamp: time.Now(),
			Data:      s.Data,
		}
	}
}
// SagaProgress is a JSON-serializable snapshot of a saga's overall progress.
type SagaProgress struct {
	// SagaID identifies the saga this snapshot belongs to.
	SagaID string `json:"saga_id"`
	// Status is the saga status rendered as a string.
	Status string `json:"status"`
	// TotalSteps is the number of steps registered on the saga.
	TotalSteps int `json:"total_steps"`
	// CompletedSteps counts the steps that have completed.
	CompletedSteps int `json:"completed_steps"`
	// CurrentStep is the step position; presumably 1-based — confirm with the producer.
	CurrentStep int `json:"current_step"`
	// PercentComplete is the completion percentage.
	PercentComplete float64 `json:"percent_complete"`
	// StartTime is when the saga started.
	StartTime time.Time `json:"start_time"`
	// Duration is the elapsed time reported for the saga.
	Duration time.Duration `json:"duration"`
}
// StepProgress is a JSON-serializable snapshot of a single saga step.
type StepProgress struct {
	// Name is the step name.
	Name string `json:"name"`
	// Status is the step status rendered as a string.
	Status string `json:"status"`
	// RetryCount is the retry counter recorded on the step.
	RetryCount int `json:"retry_count"`
	// StartTime is when the step started.
	StartTime time.Time `json:"start_time"`
	// EndTime is when the step ended.
	EndTime time.Time `json:"end_time"`
	// Duration is the time reported for the step.
	Duration time.Duration `json:"duration"`
	// Error is the step's error message; omitted from JSON when empty.
	Error string `json:"error,omitempty"`
}
// SagaEvent describes a saga lifecycle event in JSON-serializable form.
type SagaEvent struct {
	// SagaID identifies the saga that produced the event.
	SagaID string `json:"saga_id"`
	// SagaName is the saga's name.
	SagaName string `json:"saga_name"`
	// EventType is the caller-supplied event type.
	EventType string `json:"event_type"`
	// Status is the saga status rendered as a string.
	Status string `json:"status"`
	// Timestamp is when the event was created.
	Timestamp time.Time `json:"timestamp"`
	// Data is the saga payload; omitted from JSON when empty.
	Data interface{} `json:"data,omitempty"`
}
// SagaManager keeps an in-memory registry of sagas keyed by saga ID.
type SagaManager struct {
	// sagas maps saga ID to saga; guarded by mutex.
	sagas map[string]*Saga
	// logger is used for manager-level logging and as parent of saga loggers.
	logger *zap.Logger
	// mutex guards sagas.
	mutex sync.RWMutex
	// config is passed to every saga created by this manager.
	config SagaConfig
}
// NewSagaManager returns a manager with an empty saga registry that uses the
// given configuration and logger.
func NewSagaManager(config SagaConfig, logger *zap.Logger) *SagaManager {
	manager := new(SagaManager)
	manager.sagas = map[string]*Saga{}
	manager.logger = logger
	manager.config = config
	return manager
}
// CreateSaga builds a new saga with the manager's configuration, registers it
// under id (replacing any saga already stored under that ID) and returns it.
func (sm *SagaManager) CreateSaga(id, name string) *Saga {
	sagaLogger := sm.logger.Named("saga")
	created := NewSaga(id, name, sm.config, sagaLogger)

	sm.mutex.Lock()
	sm.sagas[id] = created
	sm.mutex.Unlock()

	sm.logger.Info("Created saga",
		zap.String("saga_id", id),
		zap.String("saga_name", name))
	return created
}
// GetSaga looks up a saga by ID. The boolean reports whether it was found.
func (sm *SagaManager) GetSaga(id string) (*Saga, bool) {
	sm.mutex.RLock()
	found, ok := sm.sagas[id]
	sm.mutex.RUnlock()
	return found, ok
}
// ListSagas returns all registered sagas in unspecified (map iteration) order.
func (sm *SagaManager) ListSagas() []*Saga {
	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	all := make([]*Saga, len(sm.sagas))
	next := 0
	for _, registered := range sm.sagas {
		all[next] = registered
		next++
	}
	return all
}
// GetSagaProgress returns the progress snapshot of the saga with the given ID.
// The boolean is false (with a zero SagaProgress) when no such saga exists.
func (sm *SagaManager) GetSagaProgress(id string) (SagaProgress, bool) {
	if saga, ok := sm.GetSaga(id); ok {
		return saga.GetProgress(), true
	}
	return SagaProgress{}, false
}
// RemoveSaga deletes the saga with the given ID from the registry. Removing
// an unknown ID is a no-op apart from the debug log entry.
func (sm *SagaManager) RemoveSaga(id string) {
	sm.mutex.Lock()
	defer sm.mutex.Unlock()

	delete(sm.sagas, id)
	idField := zap.String("saga_id", id)
	sm.logger.Debug("Removed saga", idField)
}
// GetStats reports the number of registered sagas ("total_sagas") and a
// per-status count keyed by the status string ("status_count").
func (sm *SagaManager) GetStats() map[string]interface{} {
	sm.mutex.RLock()
	defer sm.mutex.RUnlock()

	perStatus := map[string]int{}
	for _, registered := range sm.sagas {
		perStatus[registered.GetStatus().String()]++
	}

	stats := make(map[string]interface{}, 2)
	stats["total_sagas"] = len(sm.sagas)
	stats["status_count"] = perStatus
	return stats
}
// Service interface implementation.

// Name returns the service name under which the saga manager is registered.
func (sm *SagaManager) Name() string {
	const serviceName = "saga-manager"
	return serviceName
}
// Initialize logs that the saga manager is ready. It performs no other work
// and never fails.
func (sm *SagaManager) Initialize(ctx context.Context) error {
	const readyMessage = "Saga manager service initialized"
	sm.logger.Info(readyMessage)
	return nil
}
// HealthCheck always reports the saga manager as healthy.
func (sm *SagaManager) HealthCheck(ctx context.Context) error {
	var healthy error // nil: there is nothing to probe
	return healthy
}
// Shutdown logs that the saga manager is stopping. It releases no resources
// and never fails.
func (sm *SagaManager) Shutdown(ctx context.Context) error {
	const stoppedMessage = "Saga manager service shutdown"
	sm.logger.Info(stoppedMessage)
	return nil
}

View File

@@ -0,0 +1,130 @@
package sms
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
"go.uber.org/zap"
"tyapi-server/internal/config"
)
// Service is the SMS service abstraction used to deliver verification codes.
type Service interface {
	// SendVerificationCode sends code to the given phone number.
	SendVerificationCode(ctx context.Context, phone string, code string) error
	// GenerateCode returns a verification code of the requested length.
	GenerateCode(length int) string
}
// AliSMSService implements the SMS service on top of the Alibaba Cloud
// dysmsapi client.
type AliSMSService struct {
	// client is the Alibaba Cloud SMS API client.
	client *dysmsapi.Client
	// config supplies credentials, sign name and template code.
	config config.SMSConfig
	// logger records send results.
	logger *zap.Logger
}
// NewAliSMSService creates an Alibaba Cloud SMS service using the access key
// pair from cfg. The client is always created for the "cn-hangzhou" region.
// It returns an error when the SDK client cannot be constructed.
func NewAliSMSService(cfg config.SMSConfig, logger *zap.Logger) (*AliSMSService, error) {
	const region = "cn-hangzhou"

	smsClient, err := dysmsapi.NewClientWithAccessKey(region, cfg.AccessKeyID, cfg.AccessKeySecret)
	if err != nil {
		return nil, fmt.Errorf("创建短信客户端失败: %w", err)
	}

	svc := &AliSMSService{client: smsClient, config: cfg, logger: logger}
	return svc, nil
}
// SendVerificationCode sends code to phone through the configured Alibaba
// Cloud SMS template.
//
// It returns an error when the SDK call fails or when the API answers with a
// code other than "OK". ctx is accepted for interface compatibility but is
// not passed to the SDK call.
func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
	req := dysmsapi.CreateSendSmsRequest()
	req.Scheme = "https"
	req.PhoneNumbers = phone
	req.SignName = s.config.SignName
	req.TemplateCode = s.config.TemplateCode
	req.TemplateParam = fmt.Sprintf(`{"code":"%s"}`, code)

	resp, sendErr := s.client.SendSms(req)
	switch {
	case sendErr != nil:
		// Transport or SDK level failure.
		s.logger.Error("Failed to send SMS",
			zap.String("phone", phone),
			zap.Error(sendErr))
		return fmt.Errorf("短信发送失败: %w", sendErr)
	case resp.Code != "OK":
		// The request reached the API but was rejected.
		s.logger.Error("SMS send failed",
			zap.String("phone", phone),
			zap.String("code", resp.Code),
			zap.String("message", resp.Message))
		return fmt.Errorf("短信发送失败: %s - %s", resp.Code, resp.Message)
	}

	s.logger.Info("SMS sent successfully",
		zap.String("phone", phone),
		zap.String("bizId", resp.BizId))
	return nil
}
// GenerateCode returns a cryptographically random numeric verification code
// of exactly length digits, left-padded with zeros. A non-positive length
// falls back to 6 digits.
//
// The upper bound 10^length is computed with math/big, so long codes do not
// overflow a machine integer. If the system random source fails the function
// panics with a descriptive message, because the signature cannot report an
// error and returning a predictable code would be unsafe.
func (s *AliSMSService) GenerateCode(length int) string {
	if length <= 0 {
		length = 6
	}

	// Exclusive upper bound: 10^length.
	limit := new(big.Int).Exp(big.NewInt(10), big.NewInt(int64(length)), nil)

	n, err := rand.Int(rand.Reader, limit)
	if err != nil {
		panic(fmt.Errorf("sms: generating verification code: %w", err))
	}

	// Zero-pad to the requested width; *big.Int supports %d with width flags.
	return fmt.Sprintf("%0*d", length, n)
}
// pow10 returns 10 raised to the power n using repeated multiplication.
// For n <= 0 the result is 1. The result is a plain int and is not checked
// for overflow.
func pow10(n int) int {
	product := 1
	for remaining := n; remaining > 0; remaining-- {
		product *= 10
	}
	return product
}
// MockSMSService is a stand-in SMS service for development and testing.
type MockSMSService struct {
	// logger receives the simulated send records.
	logger *zap.Logger
}
// NewMockSMSService returns a mock SMS service that logs through logger.
func NewMockSMSService(logger *zap.Logger) *MockSMSService {
	mock := new(MockSMSService)
	mock.logger = logger
	return mock
}
// SendVerificationCode pretends to send an SMS: it only logs the phone number
// and code, and always succeeds.
func (s *MockSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
	fields := []zap.Field{
		zap.String("phone", phone),
		zap.String("code", code),
	}
	s.logger.Info("Mock SMS sent", fields...)
	return nil
}
// GenerateCode returns a fixed code made of the digit '1' repeated length
// times, so development and test flows are predictable. A non-positive
// length falls back to 6.
func (s *MockSMSService) GenerateCode(length int) string {
	if length <= 0 {
		length = 6
	}

	digits := make([]byte, length)
	for i := range digits {
		digits[i] = '1'
	}
	return string(digits)
}

View File

@@ -0,0 +1,292 @@
package tracing
import (
"context"
"fmt"
"reflect"
"runtime"
"strings"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
// TracableService is the minimal interface of a service that can be traced:
// it only has to expose its name.
type TracableService interface {
	// Name returns the service name.
	Name() string
}
// ServiceDecorator adds tracing around service method calls.
type ServiceDecorator struct {
	// tracer creates spans and sets their attributes.
	tracer *Tracer
	// logger reports slow and failed calls.
	logger *zap.Logger
	// config controls which methods are traced and what is recorded.
	config DecoratorConfig
}
// DecoratorConfig configures ServiceDecorator.
type DecoratorConfig struct {
	// EnableMethodTracing turns method tracing on or off.
	EnableMethodTracing bool
	// ExcludePatterns lists patterns for method names that must not be traced.
	ExcludePatterns []string
	// IncludeArguments enables recording call arguments on spans.
	IncludeArguments bool
	// IncludeResults enables recording call results on spans.
	IncludeResults bool
	// SlowMethodThreshold is the duration above which a call counts as slow.
	SlowMethodThreshold time.Duration
}
// DefaultDecoratorConfig returns the default decorator settings: tracing
// enabled, methods matching "Health", "Ping" or "Name" excluded, arguments
// recorded, results not recorded, and a 100ms slow-method threshold.
func DefaultDecoratorConfig() DecoratorConfig {
	var cfg DecoratorConfig
	cfg.EnableMethodTracing = true
	cfg.ExcludePatterns = []string{"Health", "Ping", "Name"}
	cfg.IncludeArguments = true
	cfg.IncludeResults = false
	cfg.SlowMethodThreshold = 100 * time.Millisecond
	return cfg
}
// NewServiceDecorator returns a decorator that uses the given tracer and
// logger together with DefaultDecoratorConfig.
func NewServiceDecorator(tracer *Tracer, logger *zap.Logger) *ServiceDecorator {
	decorator := new(ServiceDecorator)
	decorator.tracer = tracer
	decorator.logger = logger
	decorator.config = DefaultDecoratorConfig()
	return decorator
}
// WrapService builds a reflection-based proxy around service, intended to add
// tracing to all of its methods, and returns a pointer to the proxy value.
//
// NOTE(review): the proxy struct is created with reflect.StructOf using the
// unexported field names "target" and "decorator". The reflect documentation
// states StructOf panics on unexported StructFields, so this likely panics at
// runtime — confirm before relying on it.
// NOTE(review): the "target" field has type pointer-to-struct, so passing a
// non-pointer service presumably fails in Set — verify against callers.
func (d *ServiceDecorator) WrapService(service interface{}) interface{} {
	serviceValue := reflect.ValueOf(service)
	serviceType := reflect.TypeOf(service)
	// Work on the underlying struct type when given a pointer.
	if serviceType.Kind() == reflect.Ptr {
		serviceType = serviceType.Elem()
		serviceValue = serviceValue.Elem()
	}
	// Create the proxy struct.
	proxyType := d.createProxyType(serviceType)
	proxyValue := reflect.New(proxyType).Elem()
	// Store the original service and the decorator in the proxy.
	proxyValue.FieldByName("target").Set(reflect.ValueOf(service))
	proxyValue.FieldByName("decorator").Set(reflect.ValueOf(d))
	return proxyValue.Addr().Interface()
}
// createProxyType assembles a struct type for the proxy: a "target" field
// holding a pointer to the service, a "decorator" field, and one func-typed
// field per method of serviceType that shouldTraceMethod accepts.
//
// NOTE(review): serviceType is the dereferenced struct type, so NumMethod
// presumably only sees value-receiver methods — confirm this is intended.
// NOTE(review): method.Type taken from a non-interface type includes the
// receiver as its first parameter.
func (d *ServiceDecorator) createProxyType(serviceType reflect.Type) reflect.Type {
	// Derive the service name.
	serviceName := d.getServiceName(serviceType)
	// Fixed proxy fields.
	fields := []reflect.StructField{
		{
			Name: "target",
			Type: reflect.PtrTo(serviceType),
		},
		{
			Name: "decorator",
			Type: reflect.TypeOf(d),
		},
	}
	// Add one field per traced method.
	for i := 0; i < serviceType.NumMethod(); i++ {
		method := serviceType.Method(i)
		if d.shouldTraceMethod(method.Name) {
			// Field that is meant to hold the method implementation.
			fields = append(fields, reflect.StructField{
				Name: method.Name,
				Type: method.Type,
			})
		}
	}
	// Build the new struct type.
	proxyType := reflect.StructOf(fields)
	// Implement the interface methods (currently a placeholder).
	d.implementMethods(proxyType, serviceType, serviceName)
	return proxyType
}
// shouldTraceMethod reports whether a method should be traced: tracing must
// be enabled and the method name must not contain any configured exclude
// pattern as a substring.
func (d *ServiceDecorator) shouldTraceMethod(methodName string) bool {
	if d.config.EnableMethodTracing {
		excluded := d.config.ExcludePatterns
		for i := range excluded {
			if strings.Contains(methodName, excluded[i]) {
				return false
			}
		}
		return true
	}
	return false
}
// getServiceName derives a lower-case service name from the type name,
// dropping a trailing "Service" suffix if present.
func (d *ServiceDecorator) getServiceName(serviceType reflect.Type) string {
	// TrimSuffix is a no-op when the suffix is absent.
	trimmed := strings.TrimSuffix(serviceType.Name(), "Service")
	return strings.ToLower(trimmed)
}
// TraceMethodCall runs fn inside a span named "<serviceName>.<methodName>"
// and records duration, slow-call and error information on that span.
//
// fn receives the span-carrying context. args is used only to attach argument
// attributes when config.IncludeArguments is set. The results and error
// produced by fn are returned to the caller unchanged.
func (d *ServiceDecorator) TraceMethodCall(
	ctx context.Context,
	serviceName, methodName string,
	fn func(context.Context) ([]reflect.Value, error),
	args []reflect.Value,
) ([]reflect.Value, error) {
	// Build the span name.
	spanName := fmt.Sprintf("%s.%s", serviceName, methodName)
	// Start the span; it ends when this function returns.
	ctx, span := d.tracer.StartSpan(ctx, spanName)
	defer span.End()
	// Base attributes.
	d.tracer.AddSpanAttributes(span,
		attribute.String("service.name", serviceName),
		attribute.String("service.method", methodName),
		attribute.String("service.type", "business"),
	)
	// Argument attributes (if enabled).
	if d.config.IncludeArguments {
		d.addArgumentAttributes(span, args)
	}
	// Record the start time.
	startTime := time.Now()
	// Invoke the wrapped method.
	results, err := fn(ctx)
	// Compute and record the elapsed time.
	duration := time.Since(startTime)
	d.tracer.AddSpanAttributes(span,
		attribute.Int64("service.duration_ms", duration.Milliseconds()),
	)
	// Flag and log slow calls.
	if duration > d.config.SlowMethodThreshold {
		d.tracer.AddSpanAttributes(span,
			attribute.Bool("service.slow_method", true),
		)
		d.logger.Warn("慢方法检测",
			zap.String("service", serviceName),
			zap.String("method", methodName),
			zap.Duration("duration", duration),
			zap.String("trace_id", d.tracer.GetTraceID(ctx)),
		)
	}
	// Record the outcome.
	if err != nil {
		d.tracer.SetSpanError(span, err)
		d.logger.Error("服务方法执行失败",
			zap.String("service", serviceName),
			zap.String("method", methodName),
			zap.Error(err),
			zap.String("trace_id", d.tracer.GetTraceID(ctx)),
		)
	} else {
		d.tracer.SetSpanSuccess(span)
		// Result attributes (if enabled).
		if d.config.IncludeResults {
			d.addResultAttributes(span, results)
		}
	}
	return results, err
}
// addArgumentAttributes records call arguments on the span as
// "service.arg_<index>" attributes. A leading context.Context argument is
// skipped, as are arguments whose string form is empty or 1000+ bytes long.
func (d *ServiceDecorator) addArgumentAttributes(span trace.Span, args []reflect.Value) {
	for idx, arg := range args {
		// The context is plumbing, not business data.
		if idx == 0 && arg.Type().String() == "context.Context" {
			continue
		}

		text := d.extractValue(arg)
		// Skip empty values and cap the size to keep spans small.
		if text == "" || len(text) >= 1000 {
			continue
		}
		d.tracer.AddSpanAttributes(span,
			attribute.String(fmt.Sprintf("service.arg_%d", idx), text),
		)
	}
}
// addResultAttributes records call results on the span as
// "service.result_<index>" attributes. Results of type error are skipped
// (errors are reported separately), as are values whose string form is empty
// or 1000+ bytes long.
func (d *ServiceDecorator) addResultAttributes(span trace.Span, results []reflect.Value) {
	for idx, res := range results {
		if res.Type().String() == "error" {
			continue
		}

		text := d.extractValue(res)
		if text == "" || len(text) >= 1000 {
			continue
		}
		d.tracer.AddSpanAttributes(span,
			attribute.String(fmt.Sprintf("service.result_%d", idx), text),
		)
	}
}
// extractValue renders a reflected value as a short string for span
// attributes: strings, integers and booleans are printed as-is, floats with
// two decimals, nil pointers as "nil", other pointers as their pointee,
// slices and arrays as "[N items]", and everything else (including structs)
// as the type name. An invalid value yields "".
func (d *ServiceDecorator) extractValue(value reflect.Value) string {
	if !value.IsValid() {
		return ""
	}

	// Follow pointer chains down to the pointee.
	for value.Kind() == reflect.Ptr {
		if value.IsNil() {
			return "nil"
		}
		value = value.Elem()
	}

	switch value.Kind() {
	case reflect.String:
		return value.String()
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return fmt.Sprint(value.Int())
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return fmt.Sprint(value.Uint())
	case reflect.Float32, reflect.Float64:
		return fmt.Sprintf("%.2f", value.Float())
	case reflect.Bool:
		return fmt.Sprint(value.Bool())
	case reflect.Slice, reflect.Array:
		return fmt.Sprintf("[%d items]", value.Len())
	}

	// Structs and all remaining kinds are represented by their type name only.
	return value.Type().Name()
}
// implementMethods is a placeholder for wiring traced method implementations
// into the proxy type. It currently does nothing.
func (d *ServiceDecorator) implementMethods(proxyType, serviceType reflect.Type, serviceName string) {
	// Placeholder for the runtime method implementation.
	// A real implementation would need reflect.MakeFunc or another runtime proxy technique.
}
// GetFunctionName returns the unqualified name of the function value fn,
// i.e. the part of the runtime symbol name after the last '.'.
//
// It returns "" when fn is nil, is not a function, or cannot be resolved by
// the runtime, instead of panicking.
func GetFunctionName(fn interface{}) string {
	v := reflect.ValueOf(fn)
	// Value.Pointer panics for kinds that carry no code pointer.
	if v.Kind() != reflect.Func {
		return ""
	}

	// FuncForPC returns nil for a nil func value or an unknown PC.
	f := runtime.FuncForPC(v.Pointer())
	if f == nil {
		return ""
	}

	name := f.Name()
	if i := strings.LastIndex(name, "."); i >= 0 {
		return name[i+1:]
	}
	return name
}

View File

@@ -0,0 +1,320 @@
package tracing
import (
"context"
"fmt"
"strings"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
"gorm.io/gorm"
)
// Keys under which the plugin stores per-operation tracing state on the
// *gorm.DB between the before and after callbacks.
const (
	gormSpanKey      = "otel:span"       // the active span
	gormOperationKey = "otel:operation"  // the detected operation type
	gormTableNameKey = "otel:table_name" // the resolved table name
	gormStartTimeKey = "otel:start_time" // when the operation started
)
// GormTracingPlugin is a GORM plugin that traces database operations.
type GormTracingPlugin struct {
	// tracer creates spans and sets their attributes.
	tracer *Tracer
	// logger reports failed and slow operations.
	logger *zap.Logger
	// config controls what is recorded.
	config GormPluginConfig
}
// GormPluginConfig configures GormTracingPlugin.
type GormPluginConfig struct {
	// IncludeSQL enables recording the SQL statement on spans.
	IncludeSQL bool
	// IncludeValues is intended to control recording of bound values
	// (not referenced in the code visible here — TODO confirm).
	IncludeValues bool
	// SlowThreshold is the duration above which an operation counts as slow.
	SlowThreshold time.Duration
	// ExcludeTables lists table names that are not traced.
	ExcludeTables []string
	// SanitizeSQL enables sanitizing SQL before it is recorded.
	SanitizeSQL bool
}
// DefaultGormPluginConfig returns the default plugin settings: SQL recorded
// and sanitized, values not recorded, a 200ms slow-query threshold, and the
// "migrations" and "schema_migrations" tables excluded.
func DefaultGormPluginConfig() GormPluginConfig {
	var cfg GormPluginConfig
	cfg.IncludeSQL = true
	// Keep values off in production to avoid recording sensitive data.
	cfg.IncludeValues = false
	cfg.SlowThreshold = 200 * time.Millisecond
	cfg.ExcludeTables = []string{"migrations", "schema_migrations"}
	cfg.SanitizeSQL = true
	return cfg
}
// NewGormTracingPlugin returns a plugin that uses the given tracer and logger
// together with DefaultGormPluginConfig.
func NewGormTracingPlugin(tracer *Tracer, logger *zap.Logger) *GormTracingPlugin {
	plugin := new(GormTracingPlugin)
	plugin.tracer = tracer
	plugin.logger = logger
	plugin.config = DefaultGormPluginConfig()
	return plugin
}
// Name returns the plugin name; it also prefixes every callback name the
// plugin registers.
func (p *GormTracingPlugin) Name() string {
	const pluginName = "gorm-otel-tracing"
	return pluginName
}
// Initialize registers the plugin's before/after callbacks around GORM's
// built-in create, query, update, delete and raw processors.
//
// For each operation <op> it registers "<plugin name>:before_<op>" before
// "gorm:<op>" and "<plugin name>:after_<op>" after it, in that order. The
// first registration failure is returned, wrapped with the phase and
// operation that failed.
func (p *GormTracingPlugin) Initialize(db *gorm.DB) error {
	// registerFunc matches the Register method of GORM's callback builder.
	type registerFunc func(name string, fn func(*gorm.DB)) error

	callbacks := db.Callback()
	hooks := []struct {
		operation string
		before    registerFunc
		after     registerFunc
	}{
		{
			operation: "create",
			before:    callbacks.Create().Before("gorm:create").Register,
			after:     callbacks.Create().After("gorm:create").Register,
		},
		{
			operation: "query",
			before:    callbacks.Query().Before("gorm:query").Register,
			after:     callbacks.Query().After("gorm:query").Register,
		},
		{
			operation: "update",
			before:    callbacks.Update().Before("gorm:update").Register,
			after:     callbacks.Update().After("gorm:update").Register,
		},
		{
			operation: "delete",
			before:    callbacks.Delete().Before("gorm:delete").Register,
			after:     callbacks.Delete().After("gorm:delete").Register,
		},
		{
			operation: "raw",
			before:    callbacks.Raw().Before("gorm:raw").Register,
			after:     callbacks.Raw().After("gorm:raw").Register,
		},
	}

	for _, hook := range hooks {
		if err := hook.before(p.Name()+":before_"+hook.operation, p.beforeOperation); err != nil {
			return fmt.Errorf("failed to register before %s callback: %w", hook.operation, err)
		}
		if err := hook.after(p.Name()+":after_"+hook.operation, p.afterOperation); err != nil {
			return fmt.Errorf("failed to register after %s callback: %w", hook.operation, err)
		}
	}

	p.logger.Info("GORM追踪插件已初始化")
	return nil
}
// beforeOperation is the callback run before a traced GORM operation. Unless
// the operation is already traced or targets an excluded table, it starts a
// DB span, stores the span, operation, table name and start time on db under
// the otel:* keys, and replaces the statement context with the span context.
//
// NOTE(review): db.system is hard-coded to "postgresql" regardless of the
// dialect actually in use.
func (p *GormTracingPlugin) beforeOperation(db *gorm.DB) {
	// Skip when tracing is already active for this statement.
	if p.shouldSkipTracing(db) {
		return
	}
	ctx := db.Statement.Context
	if ctx == nil {
		ctx = context.Background()
	}
	// Collect operation information.
	operation := p.getOperationType(db)
	tableName := p.getTableName(db)
	// Skip excluded tables.
	if p.isExcludedTable(tableName) {
		return
	}
	// Start the span.
	ctx, span := p.tracer.StartDBSpan(ctx, operation, tableName)
	// Base attributes.
	p.tracer.AddSpanAttributes(span,
		attribute.String("db.system", "postgresql"),
		attribute.String("db.operation", operation),
	)
	if tableName != "" {
		p.tracer.AddSpanAttributes(span, attribute.String("db.table", tableName))
	}
	// Store tracing state for the after callback.
	db.Set(gormSpanKey, span)
	db.Set(gormOperationKey, operation)
	db.Set(gormTableNameKey, tableName)
	db.Set(gormStartTimeKey, time.Now())
	// Propagate the span context through the statement.
	db.Statement.Context = ctx
}
// afterOperation is the callback run after a traced GORM operation. It picks
// up the span stored by the before callback (returning silently when none is
// present), records duration, SQL and affected rows, marks the span as failed
// or successful, logs failures and slow queries, and ends the span.
func (p *GormTracingPlugin) afterOperation(db *gorm.DB) {
	// Fetch the span stored by the before callback.
	spanValue, exists := db.Get(gormSpanKey)
	if !exists {
		return
	}
	span, ok := spanValue.(trace.Span)
	if !ok {
		return
	}
	defer span.End()
	// Fetch the stored operation information.
	operation, _ := db.Get(gormOperationKey)
	tableName, _ := db.Get(gormTableNameKey)
	startTime, _ := db.Get(gormStartTimeKey)
	// Compute the elapsed time; stays zero when no start time was stored.
	var duration time.Duration
	if st, ok := startTime.(time.Time); ok {
		duration = time.Since(st)
		p.tracer.AddSpanAttributes(span,
			attribute.Int64("db.duration_ms", duration.Milliseconds()),
		)
	}
	// Attach the SQL statement (sanitized when configured).
	if p.config.IncludeSQL && db.Statement.SQL.String() != "" {
		sql := db.Statement.SQL.String()
		if p.config.SanitizeSQL {
			sql = p.sanitizeSQL(sql)
		}
		p.tracer.AddSpanAttributes(span, attribute.String("db.statement", sql))
	}
	// Attach the affected row count.
	if db.Statement.RowsAffected >= 0 {
		p.tracer.AddSpanAttributes(span,
			attribute.Int64("db.rows_affected", db.Statement.RowsAffected),
		)
	}
	// Record the outcome.
	if db.Error != nil {
		p.tracer.SetSpanError(span, db.Error)
		span.SetStatus(codes.Error, db.Error.Error())
		p.logger.Error("数据库操作失败",
			zap.String("operation", fmt.Sprintf("%v", operation)),
			zap.String("table", fmt.Sprintf("%v", tableName)),
			zap.Error(db.Error),
			zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
		)
	} else {
		p.tracer.SetSpanSuccess(span)
		span.SetStatus(codes.Ok, "success")
		// Flag and log slow queries; note the log carries the unsanitized SQL.
		if duration > p.config.SlowThreshold {
			p.tracer.AddSpanAttributes(span,
				attribute.Bool("db.slow_query", true),
			)
			p.logger.Warn("慢SQL查询检测",
				zap.String("operation", fmt.Sprintf("%v", operation)),
				zap.String("table", fmt.Sprintf("%v", tableName)),
				zap.Duration("duration", duration),
				zap.String("sql", db.Statement.SQL.String()),
				zap.String("trace_id", p.tracer.GetTraceID(db.Statement.Context)),
			)
		}
	}
}
// shouldSkipTracing reports whether a span is already stored on db, in which
// case the operation must not be traced a second time.
func (p *GormTracingPlugin) shouldSkipTracing(db *gorm.DB) bool {
	_, alreadyTraced := db.Get(gormSpanKey)
	return alreadyTraced
}
// getOperationType classifies the statement by the leading keyword of its SQL
// text (case-insensitive, ignoring surrounding whitespace). It returns
// "unknown" when no SQL has been built yet and "query" when the keyword is
// not one of SELECT, INSERT, UPDATE, DELETE, CREATE, DROP or ALTER.
func (p *GormTracingPlugin) getOperationType(db *gorm.DB) string {
	statement := strings.ToUpper(strings.TrimSpace(db.Statement.SQL.String()))
	if statement == "" {
		return "unknown"
	}

	// Checked in order; the first matching prefix wins.
	keywords := []struct {
		prefix    string
		operation string
	}{
		{"SELECT", "select"},
		{"INSERT", "insert"},
		{"UPDATE", "update"},
		{"DELETE", "delete"},
		{"CREATE", "create"},
		{"DROP", "drop"},
		{"ALTER", "alter"},
	}
	for _, kw := range keywords {
		if strings.HasPrefix(statement, kw.prefix) {
			return kw.operation
		}
	}
	return "query"
}
// getTableName resolves the table name for the statement: the explicit
// Statement.Table if set, otherwise the parsed schema's table, otherwise "".
func (p *GormTracingPlugin) getTableName(db *gorm.DB) string {
	stmt := db.Statement
	if explicit := stmt.Table; explicit != "" {
		return explicit
	}
	if schema := stmt.Schema; schema != nil {
		return schema.Table
	}
	return ""
}
// isExcludedTable reports whether tableName exactly matches one of the
// configured ExcludeTables entries.
func (p *GormTracingPlugin) isExcludedTable(tableName string) bool {
	skipList := p.config.ExcludeTables
	for i := 0; i < len(skipList); i++ {
		if skipList[i] == tableName {
			return true
		}
	}
	return false
}
// sanitizeSQL masks single-quoted string literals in sql, replacing each
// whole literal (quotes and contents) with a single '?' placeholder so that
// literal values do not reach span attributes.
//
// A doubled quote ('') inside a literal is treated as an escaped quote. An
// unterminated literal is masked through the end of the input. Backslash
// escapes, numeric literals and other quoting styles are not handled; use a
// real SQL tokenizer if stronger guarantees are needed.
func (p *GormTracingPlugin) sanitizeSQL(sql string) string {
	if !strings.Contains(sql, "'") {
		return sql
	}

	var out strings.Builder
	out.Grow(len(sql))

	// Byte-wise scanning is UTF-8 safe: '\'' never occurs inside a
	// multi-byte sequence.
	for i := 0; i < len(sql); {
		if sql[i] != '\'' {
			out.WriteByte(sql[i])
			i++
			continue
		}

		// Skip the opening quote, then consume up to the closing quote.
		i++
		for i < len(sql) {
			if sql[i] != '\'' {
				i++
				continue
			}
			if i+1 < len(sql) && sql[i+1] == '\'' {
				// Escaped quote inside the literal.
				i += 2
				continue
			}
			// Closing quote.
			i++
			break
		}
		out.WriteByte('?')
	}
	return out.String()
}

View File

@@ -0,0 +1,407 @@
package tracing
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/redis/go-redis/v9"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
	"go.uber.org/zap"

	"tyapi-server/internal/shared/interfaces"
)
// TracedRedisCache is a Redis-backed cache that traces every operation.
type TracedRedisCache struct {
	// client is the underlying Redis client.
	client redis.UniversalClient
	// tracer creates spans and sets their attributes.
	tracer *Tracer
	// logger reports slow and failed operations.
	logger *zap.Logger
	// prefix is a namespace prefix for stored keys.
	prefix string
	// config controls what is recorded.
	config RedisTracingConfig
}
// RedisTracingConfig configures tracing for TracedRedisCache.
type RedisTracingConfig struct {
	// IncludeKeys enables recording the cache key on spans.
	IncludeKeys bool
	// IncludeValues is intended to control recording of cached values
	// (not referenced in the code visible here — TODO confirm).
	IncludeValues bool
	// MaxKeyLength is the maximum key length used when recording keys.
	MaxKeyLength int
	// MaxValueLength is intended to bound recorded values
	// (not referenced in the code visible here — TODO confirm).
	MaxValueLength int
	// SlowThreshold is the duration above which an operation counts as slow.
	SlowThreshold time.Duration
	// SanitizeValues is intended to control value sanitizing
	// (not referenced in the code visible here — TODO confirm).
	SanitizeValues bool
}
// DefaultRedisTracingConfig returns the default tracing settings: keys
// recorded, values not recorded, key and value limits of 100 and 1000, a 50ms
// slow-operation threshold, and value sanitizing enabled.
func DefaultRedisTracingConfig() RedisTracingConfig {
	var cfg RedisTracingConfig
	cfg.IncludeKeys = true
	// Keep values off in production to protect sensitive data.
	cfg.IncludeValues = false
	cfg.MaxKeyLength = 100
	cfg.MaxValueLength = 1000
	cfg.SlowThreshold = 50 * time.Millisecond
	cfg.SanitizeValues = true
	return cfg
}
// NewTracedRedisCache returns a traced Redis cache that namespaces keys with
// prefix and uses DefaultRedisTracingConfig.
func NewTracedRedisCache(client redis.UniversalClient, tracer *Tracer, logger *zap.Logger, prefix string) interfaces.CacheService {
	cache := new(TracedRedisCache)
	cache.client = client
	cache.tracer = tracer
	cache.logger = logger
	cache.prefix = prefix
	cache.config = DefaultRedisTracingConfig()
	return cache
}
// Name returns the service name of the cache.
func (c *TracedRedisCache) Name() string {
	const serviceName = "redis-cache"
	return serviceName
}
// Initialize logs that the cache service is ready. It performs no other work
// and never fails.
func (c *TracedRedisCache) Initialize(ctx context.Context) error {
	const readyMessage = "Redis缓存服务已初始化"
	c.logger.Info(readyMessage)
	return nil
}
// HealthCheck pings Redis and returns the ping error, if any.
func (c *TracedRedisCache) HealthCheck(ctx context.Context) error {
	return c.client.Ping(ctx).Err()
}
// Shutdown logs the shutdown and closes the Redis client, returning the
// client's close error.
func (c *TracedRedisCache) Shutdown(ctx context.Context) error {
	c.logger.Info("Redis缓存服务已关闭")
	closeErr := c.client.Close()
	return closeErr
}
// Get reads the value stored under key (with the cache prefix applied) and
// decodes it into dest, tracing the operation.
//
// It returns interfaces.ErrCacheMiss when the key does not exist (the span is
// still marked successful), the Redis error for any other read failure, and
// the decode error when the stored value cannot be deserialized.
func (c *TracedRedisCache) Get(ctx context.Context, key string, dest interface{}) error {
	// Start the span; it ends when this function returns.
	ctx, span := c.tracer.StartCacheSpan(ctx, "get", key)
	defer span.End()
	// Base attributes.
	c.addBaseAttributes(span, "get", key)
	// Record the start time.
	startTime := time.Now()
	// Build the full key.
	fullKey := c.buildKey(key)
	// Run the Redis command.
	result, err := c.client.Get(ctx, fullKey).Result()
	// Compute and record the elapsed time.
	duration := time.Since(startTime)
	c.tracer.AddSpanAttributes(span,
		attribute.Int64("redis.duration_ms", duration.Milliseconds()),
	)
	// Flag and log slow operations.
	if duration > c.config.SlowThreshold {
		c.tracer.AddSpanAttributes(span,
			attribute.Bool("redis.slow_operation", true),
		)
		c.logger.Warn("Redis慢操作检测",
			zap.String("operation", "get"),
			zap.String("key", c.sanitizeKey(key)),
			zap.Duration("duration", duration),
			zap.String("trace_id", c.tracer.GetTraceID(ctx)),
		)
	}
	// Handle the command result.
	if err != nil {
		if err == redis.Nil {
			// Cache miss: not an error from the tracing point of view.
			c.tracer.AddSpanAttributes(span,
				attribute.Bool("redis.hit", false),
				attribute.String("redis.result", "miss"),
			)
			c.tracer.SetSpanSuccess(span)
			return interfaces.ErrCacheMiss
		} else {
			// Genuine Redis error.
			c.tracer.SetSpanError(span, err)
			c.logger.Error("Redis GET操作失败",
				zap.String("key", c.sanitizeKey(key)),
				zap.Error(err),
				zap.String("trace_id", c.tracer.GetTraceID(ctx)),
			)
			return err
		}
	}
	// Cache hit.
	c.tracer.AddSpanAttributes(span,
		attribute.Bool("redis.hit", true),
		attribute.String("redis.result", "hit"),
		attribute.Int("redis.value_size", len(result)),
	)
	// Decode the stored value into dest.
	if err := c.deserialize(result, dest); err != nil {
		c.tracer.SetSpanError(span, err)
		return err
	}
	c.tracer.SetSpanSuccess(span)
	return nil
}
// Set serializes value and stores it under key (with the cache prefix
// applied), tracing the operation.
//
// The optional first ttl argument is used as the expiration only when it is a
// time.Duration; any other type is ignored and the expiration stays zero.
// It returns the serialization error or the Redis error, if any.
func (c *TracedRedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
	// Start the span; it ends when this function returns.
	ctx, span := c.tracer.StartCacheSpan(ctx, "set", key)
	defer span.End()
	// Base attributes.
	c.addBaseAttributes(span, "set", key)
	// Resolve the TTL.
	var expiration time.Duration
	if len(ttl) > 0 {
		if duration, ok := ttl[0].(time.Duration); ok {
			expiration = duration
			c.tracer.AddSpanAttributes(span,
				attribute.Int64("redis.ttl_seconds", int64(expiration.Seconds())),
			)
		}
	}
	// Record the start time (serialization is included in the duration).
	startTime := time.Now()
	// Serialize the value.
	serialized, err := c.serialize(value)
	if err != nil {
		c.tracer.SetSpanError(span, err)
		return err
	}
	// Build the full key.
	fullKey := c.buildKey(key)
	// Run the Redis command.
	err = c.client.Set(ctx, fullKey, serialized, expiration).Err()
	// Compute and record the elapsed time.
	duration := time.Since(startTime)
	c.tracer.AddSpanAttributes(span,
		attribute.Int64("redis.duration_ms", duration.Milliseconds()),
		attribute.Int("redis.value_size", len(serialized)),
	)
	// Flag and log slow operations.
	if duration > c.config.SlowThreshold {
		c.tracer.AddSpanAttributes(span,
			attribute.Bool("redis.slow_operation", true),
		)
		c.logger.Warn("Redis慢操作检测",
			zap.String("operation", "set"),
			zap.String("key", c.sanitizeKey(key)),
			zap.Duration("duration", duration),
			zap.String("trace_id", c.tracer.GetTraceID(ctx)),
		)
	}
	// Handle the command error.
	if err != nil {
		c.tracer.SetSpanError(span, err)
		c.logger.Error("Redis SET操作失败",
			zap.String("key", c.sanitizeKey(key)),
			zap.Error(err),
			zap.String("trace_id", c.tracer.GetTraceID(ctx)),
		)
		return err
	}
	c.tracer.SetSpanSuccess(span)
	return nil
}
// Delete removes the given keys (with the cache prefix applied) in a single
// DEL command, tracing the operation. The span is started with the
// comma-joined raw keys. It returns the Redis error, if any.
//
// NOTE(review): an empty keys slice is passed straight to DEL; presumably
// Redis rejects a DEL without keys — confirm whether callers can hit this.
func (c *TracedRedisCache) Delete(ctx context.Context, keys ...string) error {
	// Start the span; it ends when this function returns.
	ctx, span := c.tracer.StartCacheSpan(ctx, "delete", strings.Join(keys, ","))
	defer span.End()
	// Base attributes.
	c.tracer.AddSpanAttributes(span,
		attribute.String("redis.operation", "delete"),
		attribute.Int("redis.key_count", len(keys)),
	)
	// Record the start time.
	startTime := time.Now()
	// Build the full keys.
	fullKeys := make([]string, len(keys))
	for i, key := range keys {
		fullKeys[i] = c.buildKey(key)
	}
	// Run the Redis command.
	deleted, err := c.client.Del(ctx, fullKeys...).Result()
	// Compute and record the elapsed time and deletion count.
	duration := time.Since(startTime)
	c.tracer.AddSpanAttributes(span,
		attribute.Int64("redis.duration_ms", duration.Milliseconds()),
		attribute.Int64("redis.deleted_count", deleted),
	)
	// Handle the command error.
	if err != nil {
		c.tracer.SetSpanError(span, err)
		c.logger.Error("Redis DELETE操作失败",
			zap.Strings("keys", c.sanitizeKeys(keys)),
			zap.Error(err),
			zap.String("trace_id", c.tracer.GetTraceID(ctx)),
		)
		return err
	}
	c.tracer.SetSpanSuccess(span)
	return nil
}
// Exists reports whether key (with the cache prefix applied) is present in
// Redis, tracing the operation. On a Redis error it returns false together
// with that error.
func (c *TracedRedisCache) Exists(ctx context.Context, key string) (bool, error) {
	// Start the span; it ends when this function returns.
	ctx, span := c.tracer.StartCacheSpan(ctx, "exists", key)
	defer span.End()
	// Base attributes.
	c.addBaseAttributes(span, "exists", key)
	// Record the start time.
	startTime := time.Now()
	// Build the full key.
	fullKey := c.buildKey(key)
	// Run the Redis command.
	count, err := c.client.Exists(ctx, fullKey).Result()
	// Compute and record the elapsed time and the result.
	duration := time.Since(startTime)
	c.tracer.AddSpanAttributes(span,
		attribute.Int64("redis.duration_ms", duration.Milliseconds()),
		attribute.Bool("redis.exists", count > 0),
	)
	// Handle the command error.
	if err != nil {
		c.tracer.SetSpanError(span, err)
		return false, err
	}
	c.tracer.SetSpanSuccess(span)
	return count > 0, nil
}
// GetMultiple fetches several keys one at a time via Get and returns the
// values that could be read, keyed by the requested key. Keys that miss or
// fail are silently omitted; the returned error is always nil.
//
// This is a basic implementation; MGET would avoid the per-key round trips.
func (c *TracedRedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
	found := make(map[string]interface{})
	for i := range keys {
		var decoded interface{}
		if err := c.Get(ctx, keys[i], &decoded); err != nil {
			continue
		}
		found[keys[i]] = decoded
	}
	return found, nil
}
// SetMultiple stores every entry of data via Set, forwarding ttl unchanged.
// It stops at and returns the first error; entries written before the failure
// are not rolled back.
//
// This is a basic implementation; a pipeline would avoid the per-key round trips.
func (c *TracedRedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
	for k := range data {
		if setErr := c.Set(ctx, k, data[k], ttl...); setErr != nil {
			return setErr
		}
	}
	return nil
}
// DeletePattern is meant to delete all keys matching pattern. It is not
// implemented yet and always returns an error.
func (c *TracedRedisCache) DeletePattern(ctx context.Context, pattern string) error {
	// TODO: implement pattern-based deletion.
	notImplemented := fmt.Errorf("DeletePattern not implemented")
	return notImplemented
}
// Keys is meant to list the keys matching pattern. It is not implemented yet
// and always returns a nil slice and an error.
func (c *TracedRedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
	// TODO: implement key matching.
	notImplemented := fmt.Errorf("Keys not implemented")
	return nil, notImplemented
}
// Stats is meant to report cache statistics. It is not implemented yet and
// always returns zero-valued stats and an error.
func (c *TracedRedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
	var empty interfaces.CacheStats
	return empty, fmt.Errorf("Stats not implemented")
}
// Helper methods.

// addBaseAttributes sets the operation name and db.system on the span and,
// when key recording is enabled, the sanitized key — but only if the
// sanitized key is no longer than MaxKeyLength.
func (c *TracedRedisCache) addBaseAttributes(span trace.Span, operation, key string) {
	c.tracer.AddSpanAttributes(span,
		attribute.String("redis.operation", operation),
		attribute.String("db.system", "redis"),
	)

	if !c.config.IncludeKeys {
		return
	}
	if safeKey := c.sanitizeKey(key); len(safeKey) <= c.config.MaxKeyLength {
		c.tracer.AddSpanAttributes(span,
			attribute.String("redis.key", safeKey),
		)
	}
}
// buildKey returns the Redis key actually used for storage: "<prefix>:<key>"
// when a prefix is configured, otherwise key unchanged.
func (c *TracedRedisCache) buildKey(key string) string {
	if c.prefix != "" {
		return c.prefix + ":" + key
	}
	return key
}
// sanitizeKey shortens a key for logging: keys longer than MaxKeyLength bytes
// are cut to that length and suffixed with "...".
func (c *TracedRedisCache) sanitizeKey(key string) string {
	limit := c.config.MaxKeyLength
	if len(key) > limit {
		return key[:limit] + "..."
	}
	return key
}
// sanitizeKeys applies sanitizeKey to every key and returns the results in
// the same order.
func (c *TracedRedisCache) sanitizeKeys(keys []string) []string {
	cleaned := make([]string, 0, len(keys))
	for i := range keys {
		cleaned = append(cleaned, c.sanitizeKey(keys[i]))
	}
	return cleaned
}
// serialize encodes value as JSON so that it can be stored in Redis and
// decoded again by deserialize. It returns an error when value cannot be
// marshaled (for example channels or functions).
func (c *TracedRedisCache) serialize(value interface{}) (string, error) {
	encoded, err := json.Marshal(value)
	if err != nil {
		return "", fmt.Errorf("serializing cache value: %w", err)
	}
	return string(encoded), nil
}

// deserialize decodes JSON data produced by serialize into dest, which must
// be a non-nil pointer. It returns an error when data is not valid JSON or
// does not fit the destination type.
func (c *TracedRedisCache) deserialize(data string, dest interface{}) error {
	if err := json.Unmarshal([]byte(data), dest); err != nil {
		return fmt.Errorf("deserializing cache value: %w", err)
	}
	return nil
}

View File

@@ -0,0 +1,189 @@
package tracing
import (
"context"
"fmt"
"time"
"go.opentelemetry.io/otel/attribute"
"go.uber.org/zap"
"tyapi-server/internal/domains/user/dto"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/shared/interfaces"
)
// ServiceWrapper provides tracing for service calls.
type ServiceWrapper struct {
	// tracer creates spans and sets their attributes.
	tracer *Tracer
	// logger reports slow and failed calls.
	logger *zap.Logger
}
// NewServiceWrapper returns a wrapper that traces with tracer and logs
// through logger.
func NewServiceWrapper(tracer *Tracer, logger *zap.Logger) *ServiceWrapper {
	wrapper := new(ServiceWrapper)
	wrapper.tracer = tracer
	wrapper.logger = logger
	return wrapper
}
// TraceServiceCall runs fn inside a span named "<serviceName>.<methodName>",
// records its duration, flags calls slower than 100ms, marks the span as
// failed or successful, and returns fn's error unchanged.
func (w *ServiceWrapper) TraceServiceCall(
	ctx context.Context,
	serviceName, methodName string,
	fn func(context.Context) error,
) error {
	// Build the span name.
	spanName := fmt.Sprintf("%s.%s", serviceName, methodName)
	// Start the span; it ends when this function returns.
	ctx, span := w.tracer.StartSpan(ctx, spanName)
	defer span.End()
	// Base attributes.
	w.tracer.AddSpanAttributes(span,
		attribute.String("service.name", serviceName),
		attribute.String("service.method", methodName),
		attribute.String("service.type", "business"),
	)
	// Record the start time.
	startTime := time.Now()
	// Invoke the wrapped call with the span-carrying context.
	err := fn(ctx)
	// Compute and record the elapsed time.
	duration := time.Since(startTime)
	w.tracer.AddSpanAttributes(span,
		attribute.Int64("service.duration_ms", duration.Milliseconds()),
	)
	// Flag and log slow calls (fixed 100ms threshold).
	if duration > 100*time.Millisecond {
		w.tracer.AddSpanAttributes(span,
			attribute.Bool("service.slow_method", true),
		)
		w.logger.Warn("慢方法检测",
			zap.String("service", serviceName),
			zap.String("method", methodName),
			zap.Duration("duration", duration),
			zap.String("trace_id", w.tracer.GetTraceID(ctx)),
		)
	}
	// Record the outcome.
	if err != nil {
		w.tracer.SetSpanError(span, err)
		w.logger.Error("服务方法执行失败",
			zap.String("service", serviceName),
			zap.String("method", methodName),
			zap.Error(err),
			zap.String("trace_id", w.tracer.GetTraceID(ctx)),
		)
	} else {
		w.tracer.SetSpanSuccess(span)
	}
	return err
}
// TracedUserService decorates a UserService so that its calls are traced
// through a ServiceWrapper.
type TracedUserService struct {
// service is the wrapped implementation that does the real work.
service interfaces.UserService
// wrapper supplies the tracing around each delegated call.
wrapper *ServiceWrapper
}
// NewTracedUserService wraps service in a TracedUserService that traces its
// calls through wrapper, and returns it as an interfaces.UserService.
func NewTracedUserService(service interfaces.UserService, wrapper *ServiceWrapper) interfaces.UserService {
	traced := new(TracedUserService)
	traced.service = service
	traced.wrapper = wrapper
	return traced
}
// Name returns the fixed identifier "user-service"; it does not consult the
// wrapped service.
func (t *TracedUserService) Name() string {
	const serviceName = "user-service"
	return serviceName
}
// Initialize initializes the wrapped service inside a "user.initialize" span
// and returns its error.
func (t *TracedUserService) Initialize(ctx context.Context) error {
	initialize := t.service.Initialize
	return t.wrapper.TraceServiceCall(ctx, "user", "initialize", initialize)
}
// HealthCheck delegates straight to the wrapped service. Health checks are
// intentionally not traced.
func (t *TracedUserService) HealthCheck(ctx context.Context) error {
	err := t.service.HealthCheck(ctx)
	return err
}
// Shutdown shuts down the wrapped service inside a "user.shutdown" span and
// returns its error.
func (t *TracedUserService) Shutdown(ctx context.Context) error {
	shutdown := t.service.Shutdown
	return t.wrapper.TraceServiceCall(ctx, "user", "shutdown", shutdown)
}
// Register calls the wrapped service's Register inside a "user.register"
// span. On failure it returns a nil user together with the error.
func (t *TracedUserService) Register(ctx context.Context, req *dto.RegisterRequest) (*entities.User, error) {
	var user *entities.User
	// TraceServiceCall returns the callback's error unchanged, so one check is enough.
	err := t.wrapper.TraceServiceCall(ctx, "user", "register", func(ctx context.Context) error {
		var callErr error
		user, callErr = t.service.Register(ctx, req)
		return callErr
	})
	if err != nil {
		return nil, err
	}
	return user, nil
}
// LoginWithPassword calls the wrapped service's LoginWithPassword inside a
// "user.login_password" span. On failure it returns a nil user together with
// the error.
func (t *TracedUserService) LoginWithPassword(ctx context.Context, req *dto.LoginWithPasswordRequest) (*entities.User, error) {
	var user *entities.User
	// TraceServiceCall returns the callback's error unchanged, so one check is enough.
	err := t.wrapper.TraceServiceCall(ctx, "user", "login_password", func(ctx context.Context) error {
		var callErr error
		user, callErr = t.service.LoginWithPassword(ctx, req)
		return callErr
	})
	if err != nil {
		return nil, err
	}
	return user, nil
}
// LoginWithSMS calls the wrapped service's LoginWithSMS inside a
// "user.login_sms" span. On failure it returns a nil user together with the
// error.
func (t *TracedUserService) LoginWithSMS(ctx context.Context, req *dto.LoginWithSMSRequest) (*entities.User, error) {
	var user *entities.User
	// TraceServiceCall returns the callback's error unchanged, so one check is enough.
	err := t.wrapper.TraceServiceCall(ctx, "user", "login_sms", func(ctx context.Context) error {
		var callErr error
		user, callErr = t.service.LoginWithSMS(ctx, req)
		return callErr
	})
	if err != nil {
		return nil, err
	}
	return user, nil
}
// ChangePassword calls the wrapped service's ChangePassword for userID inside
// a "user.change_password" span and returns its error.
func (t *TracedUserService) ChangePassword(ctx context.Context, userID string, req *dto.ChangePasswordRequest) error {
	changePassword := func(ctx context.Context) error {
		return t.service.ChangePassword(ctx, userID, req)
	}
	return t.wrapper.TraceServiceCall(ctx, "user", "change_password", changePassword)
}
// GetByID calls the wrapped service's GetByID inside a "user.get_by_id" span.
// On failure it returns a nil user together with the error.
func (t *TracedUserService) GetByID(ctx context.Context, id string) (*entities.User, error) {
	var user *entities.User
	// TraceServiceCall returns the callback's error unchanged, so one check is enough.
	err := t.wrapper.TraceServiceCall(ctx, "user", "get_by_id", func(ctx context.Context) error {
		var callErr error
		user, callErr = t.service.GetByID(ctx, id)
		return callErr
	})
	if err != nil {
		return nil, err
	}
	return user, nil
}

View File

@@ -0,0 +1,474 @@
package tracing
import (
"context"
"fmt"
"sync"
"time"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
// TracerConfig holds the settings a Tracer is built from.
type TracerConfig struct {
// ServiceName is reported as the "service.name" resource attribute and is
// also used as the name of the created tracer.
ServiceName string
// ServiceVersion is reported as the "service.version" resource attribute.
ServiceVersion string
// Environment is reported as the "environment" resource attribute.
Environment string
// Endpoint is handed to the OTLP gRPC exporter; when empty, spans go to a
// no-op exporter instead.
Endpoint string
// SampleRate is the ratio passed to the trace-ID-ratio sampler.
SampleRate float64
// Enabled turns tracing on; when false, Initialize does nothing.
Enabled bool
}
// DefaultTracerConfig returns the development defaults: tracing enabled for
// service "tyapi-server" version "1.0.0", a 10% sample rate, and a local OTLP
// collector endpoint.
//
// Endpoint is a bare "host:port" with no URL scheme. Initialize passes it to
// otlptracegrpc.WithEndpoint, which expects host and port only; a value such
// as "http://localhost:4317" is not a valid gRPC dial target and span export
// would fail.
func DefaultTracerConfig() TracerConfig {
	return TracerConfig{
		ServiceName:    "tyapi-server",
		ServiceVersion: "1.0.0",
		Environment:    "development",
		Endpoint:       "localhost:4317",
		SampleRate:     0.1,
		Enabled:        true,
	}
}
// Tracer is the distributed-tracing component. It owns the tracer provider
// built by Initialize and offers helpers for starting and annotating spans.
type Tracer struct {
// config is the configuration supplied to NewTracer.
config TracerConfig
// logger receives lifecycle and error messages.
logger *zap.Logger
// provider is the tracer provider created by Initialize.
provider *sdktrace.TracerProvider
// tracer is the tracer obtained from provider; StartSpan uses it.
tracer trace.Tracer
// mutex is held by Initialize, Shutdown and GetStats.
mutex sync.RWMutex
// initialized is set by a successful Initialize and cleared by Shutdown.
initialized bool
// shutdown stops the provider; it is set by Initialize.
shutdown func(context.Context) error
}
// NewTracer returns an uninitialized Tracer for config. Call Initialize
// before starting spans.
//
// A nil logger is replaced with a no-op logger. Initialize and Shutdown log
// unconditionally, so a nil logger would otherwise panic there.
func NewTracer(config TracerConfig, logger *zap.Logger) *Tracer {
	if logger == nil {
		logger = zap.NewNop()
	}
	return &Tracer{
		config: config,
		logger: logger,
	}
}
// Initialize builds the tracer provider and installs it as the global
// OpenTelemetry provider.
//
// It returns nil without doing anything when the tracer is already
// initialized or tracing is disabled. If an endpoint is configured, spans are
// exported over OTLP gRPC through a batch processor; if the exporter cannot
// be created, or no endpoint is configured, a no-op exporter is used. The
// only error returned is a failure to create the resource.
func (t *Tracer) Initialize(ctx context.Context) error {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	switch {
	case t.initialized:
		return nil
	case !t.config.Enabled:
		t.logger.Info("Tracing is disabled")
		return nil
	}

	// Resource attributes identifying this service.
	res, err := resource.New(ctx,
		resource.WithAttributes(
			attribute.String("service.name", t.config.ServiceName),
			attribute.String("service.version", t.config.ServiceVersion),
			attribute.String("environment", t.config.Environment),
		),
	)
	if err != nil {
		return fmt.Errorf("failed to create resource: %w", err)
	}

	// Pick the span processor based on the configured endpoint.
	var processor sdktrace.SpanProcessor
	if endpoint := t.config.Endpoint; endpoint == "" {
		// No endpoint configured: discard spans.
		processor = sdktrace.NewSimpleSpanProcessor(&noopExporter{})
		t.logger.Info("Using noop exporter (no endpoint configured)")
	} else {
		// OTLP gRPC exporter (works with Jaeger, Tempo, etc.).
		exporter, exporterErr := otlptracegrpc.New(ctx,
			otlptracegrpc.WithEndpoint(endpoint),
			otlptracegrpc.WithInsecure(), // development setting; production should configure TLS
			otlptracegrpc.WithTimeout(time.Second*10),
			otlptracegrpc.WithRetry(otlptracegrpc.RetryConfig{
				Enabled:         true,
				InitialInterval: time.Millisecond * 100,
				MaxInterval:     time.Second * 5,
				MaxElapsedTime:  time.Second * 30,
			}),
		)
		if exporterErr != nil {
			// Tracing setup must not fail because of the exporter; fall back to discarding spans.
			t.logger.Warn("Failed to create OTLP exporter, using noop exporter",
				zap.Error(exporterErr),
				zap.String("endpoint", endpoint))
			processor = sdktrace.NewSimpleSpanProcessor(&noopExporter{})
		} else {
			// Export in batches.
			processor = sdktrace.NewBatchSpanProcessor(exporter,
				sdktrace.WithBatchTimeout(time.Second*5),
				sdktrace.WithMaxExportBatchSize(512),
				sdktrace.WithMaxQueueSize(2048),
				sdktrace.WithExportTimeout(time.Second*30),
			)
			t.logger.Info("OTLP exporter initialized successfully",
				zap.String("endpoint", endpoint))
		}
	}

	provider := sdktrace.NewTracerProvider(
		sdktrace.WithResource(res),
		sdktrace.WithSampler(sdktrace.TraceIDRatioBased(t.config.SampleRate)),
		sdktrace.WithSpanProcessor(processor),
	)

	// Install the provider globally.
	otel.SetTracerProvider(provider)

	t.provider = provider
	t.tracer = provider.Tracer(t.config.ServiceName)
	t.shutdown = provider.Shutdown
	t.initialized = true

	t.logger.Info("Tracing initialized successfully",
		zap.String("service", t.config.ServiceName),
		zap.Float64("sample_rate", t.config.SampleRate))
	return nil
}
// StartSpan starts a span called name as a child of the span in ctx and
// returns the derived context together with the new span.
//
// When the tracer is disabled or not initialized, ctx is returned unchanged
// together with a no-op span. The span already stored in ctx is deliberately
// not returned: callers always End() what StartSpan gives them, and returning
// the parent would let them end, or write attributes to, a span they do not
// own.
func (t *Tracer) StartSpan(ctx context.Context, name string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
	if !t.initialized || !t.config.Enabled {
		// SpanFromContext on a context without a span yields the package's no-op span.
		return ctx, trace.SpanFromContext(context.Background())
	}
	return t.tracer.Start(ctx, name, opts...)
}
// StartHTTPSpan starts a server-kind span for an HTTP request, tagged with
// the HTTP method and route.
//
// The span is normally named "<method> <path>". If ctx already carries the
// "otel_error_request" marker, the span is named "error" instead, which is
// meant to match the error-operation strategy in the Jaeger sampling
// configuration. For unmarked requests the span name is also stored in the
// returned context under "otel_original_operation".
func (t *Tracer) StartHTTPSpan(ctx context.Context, method, path string) (context.Context, trace.Span) {
	flaggedAsError := ctx.Value("otel_error_request") != nil

	spanName := "error"
	if !flaggedAsError {
		spanName = fmt.Sprintf("%s %s", method, path)
	}

	httpAttrs := trace.WithAttributes(
		attribute.String("http.method", method),
		attribute.String("http.route", path),
	)
	ctx, span := t.StartSpan(ctx, spanName, trace.WithSpanKind(trace.SpanKindServer), httpAttrs)

	// Remember the original operation name so it can be found again if an error occurs later.
	if !flaggedAsError {
		ctx = context.WithValue(ctx, "otel_original_operation", spanName)
	}
	return ctx, span
}
// StartDBSpan starts a client-kind span named "db.<operation>.<table>",
// tagged with the operation, the table and a db.system of "postgresql".
func (t *Tracer) StartDBSpan(ctx context.Context, operation, table string) (context.Context, trace.Span) {
	dbAttrs := trace.WithAttributes(
		attribute.String("db.operation", operation),
		attribute.String("db.table", table),
		attribute.String("db.system", "postgresql"),
	)
	return t.StartSpan(ctx, "db."+operation+"."+table, trace.WithSpanKind(trace.SpanKindClient), dbAttrs)
}
// StartCacheSpan starts a client-kind span named "cache.<operation>", tagged
// with the operation and a cache.system of "redis". The key argument is
// currently not recorded on the span.
func (t *Tracer) StartCacheSpan(ctx context.Context, operation, key string) (context.Context, trace.Span) {
	cacheAttrs := trace.WithAttributes(
		attribute.String("cache.operation", operation),
		attribute.String("cache.system", "redis"),
	)
	return t.StartSpan(ctx, "cache."+operation, trace.WithSpanKind(trace.SpanKindClient), cacheAttrs)
}
// StartExternalAPISpan starts a client-kind span named
// "api.<service>.<operation>" for a call to an external API, tagged with the
// service and operation.
func (t *Tracer) StartExternalAPISpan(ctx context.Context, service, operation string) (context.Context, trace.Span) {
	apiAttrs := trace.WithAttributes(
		attribute.String("api.service", service),
		attribute.String("api.operation", operation),
	)
	return t.StartSpan(ctx, "api."+service+"."+operation, trace.WithSpanKind(trace.SpanKindClient), apiAttrs)
}
// AddSpanAttributes sets attrs on span. Spans that are not recording are
// left untouched.
func (t *Tracer) AddSpanAttributes(span trace.Span, attrs ...attribute.KeyValue) {
	if !span.IsRecording() {
		return
	}
	span.SetAttributes(attrs...)
}
// SetSpanError marks span as failed with err.
//
// It sets the span status to Error with err's message, records err on the
// span, and adds the "error.operation" and "operation.type" attributes, which
// are meant to let the Jaeger sampling configuration recognize error spans.
// If a logger is configured, the error is also logged with the span's trace
// and span IDs.
//
// The call is a no-op when err is nil or the span is not recording.
func (t *Tracer) SetSpanError(span trace.Span, err error) {
	// Guard against nil: err.Error() below would otherwise panic.
	if err == nil || !span.IsRecording() {
		return
	}

	span.SetStatus(codes.Error, err.Error())
	span.RecordError(err)

	// Marker attributes identifying this as an error span.
	span.SetAttributes(
		attribute.String("error.operation", "true"),
		attribute.String("operation.type", "error"),
	)

	// Log with the trace and span IDs so the entry can be correlated with the trace.
	if t.logger != nil {
		ctx := trace.ContextWithSpan(context.Background(), span)
		t.logger.Error("操作发生错误",
			zap.Error(err),
			zap.String("trace_id", t.GetTraceID(ctx)),
			zap.String("span_id", t.GetSpanID(ctx)),
		)
	}
}
// SetSpanSuccess sets the status of span to Ok with the description
// "success". Spans that are not recording are left untouched.
func (t *Tracer) SetSpanSuccess(span trace.Span) {
	if !span.IsRecording() {
		return
	}
	span.SetStatus(codes.Ok, "success")
}
// SetHTTPStatus records statusCode on span and derives the span status from
// it.
//
// Codes below 400 mark the span Ok. Codes of 400 and above mark it Error with
// the description "HTTP <code>", add the error marker attributes meant for
// the Jaeger sampling configuration, and, if a logger is configured, log a
// warning with the trace and span IDs. Spans that are not recording are left
// untouched.
func (t *Tracer) SetHTTPStatus(span trace.Span, statusCode int) {
	if !span.IsRecording() {
		return
	}

	span.SetAttributes(attribute.Int("http.status_code", statusCode))

	if statusCode < 400 {
		span.SetStatus(codes.Ok, "success")
		return
	}

	// 4xx and 5xx responses are both treated as errors.
	span.SetStatus(codes.Error, fmt.Sprintf("HTTP %d", statusCode))
	span.SetAttributes(
		attribute.String("error.operation", "true"),
		attribute.String("operation.type", "error"),
	)

	if t.logger == nil {
		return
	}
	spanCtx := trace.ContextWithSpan(context.Background(), span)
	t.logger.Warn("HTTP请求错误",
		zap.Int("status_code", statusCode),
		zap.String("trace_id", t.GetTraceID(spanCtx)),
		zap.String("span_id", t.GetSpanID(spanCtx)),
	)
}
// GetTraceID returns the trace ID of the span carried by ctx, or "" when ctx
// holds no valid span context.
func (t *Tracer) GetTraceID(ctx context.Context) string {
	spanCtx := trace.SpanFromContext(ctx).SpanContext()
	if !spanCtx.IsValid() {
		return ""
	}
	return spanCtx.TraceID().String()
}
// GetSpanID returns the span ID of the span carried by ctx, or "" when ctx
// holds no valid span context.
func (t *Tracer) GetSpanID(ctx context.Context) string {
	spanCtx := trace.SpanFromContext(ctx).SpanContext()
	if !spanCtx.IsValid() {
		return ""
	}
	return spanCtx.SpanID().String()
}
// IsTracing reports whether ctx carries a span that has a valid span context
// and is currently recording.
func (t *Tracer) IsTracing(ctx context.Context) bool {
	current := trace.SpanFromContext(ctx)
	if !current.SpanContext().IsValid() {
		return false
	}
	return current.IsRecording()
}
// Shutdown stops the tracer provider created by Initialize.
//
// It returns nil immediately if the tracer was never initialized. If the
// provider fails to shut down, the error is logged and returned and the
// tracer stays marked as initialized.
func (t *Tracer) Shutdown(ctx context.Context) error {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	if !(t.initialized && t.shutdown != nil) {
		return nil
	}

	if err := t.shutdown(ctx); err != nil {
		t.logger.Error("Failed to shutdown tracer", zap.Error(err))
		return err
	}

	t.initialized = false
	t.logger.Info("Tracer shutdown successfully")
	return nil
}
// GetStats returns a snapshot of the tracer's state: whether it is
// initialized, plus the configured enabled flag, service name, service
// version, environment, sample rate and endpoint.
func (t *Tracer) GetStats() map[string]interface{} {
	t.mutex.RLock()
	defer t.mutex.RUnlock()

	cfg := t.config
	stats := make(map[string]interface{}, 7)
	stats["initialized"] = t.initialized
	stats["enabled"] = cfg.Enabled
	stats["service_name"] = cfg.ServiceName
	stats["service_version"] = cfg.ServiceVersion
	stats["environment"] = cfg.Environment
	stats["sample_rate"] = cfg.SampleRate
	stats["endpoint"] = cfg.Endpoint
	return stats
}
// 实现Service接口
// Name returns the service name of the tracer component, "tracer".
func (t *Tracer) Name() string {
	const componentName = "tracer"
	return componentName
}
// HealthCheck reports an error only when tracing is enabled but the tracer
// has not been initialized. A disabled tracer is always considered healthy.
func (t *Tracer) HealthCheck(ctx context.Context) error {
	if t.config.Enabled && !t.initialized {
		return fmt.Errorf("tracer not initialized")
	}
	return nil
}
// noopExporter is a minimal span exporter that discards everything it is
// given (for demonstration). Initialize falls back to it when no endpoint is
// configured or the OTLP exporter cannot be created.
type noopExporter struct{}
// ExportSpans discards spans and always returns nil.
func (e *noopExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error {
// A real exporter would send the spans to Jaeger or another tracing backend here.
return nil
}
// Shutdown does nothing and always returns nil.
func (e *noopExporter) Shutdown(ctx context.Context) error {
return nil
}
// TraceMiddleware returns a gin middleware that traces each request.
//
// When the tracer is disabled or not initialized, the middleware simply calls
// the next handler. Otherwise it starts an HTTP span from the request method
// and the matched route, exposes the trace ID in the "X-Trace-ID" response
// header, attaches the span context to the request, and after the handlers
// run records the response status code and size. Errors collected on the gin
// context are recorded on the span.
func (t *Tracer) TraceMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if !t.initialized || !t.config.Enabled {
			c.Next()
			return
		}

		ctx, span := t.StartHTTPSpan(c.Request.Context(), c.Request.Method, c.FullPath())
		defer span.End()

		// Expose the trace ID to the client.
		if traceID := t.GetTraceID(ctx); traceID != "" {
			c.Header("X-Trace-ID", traceID)
		}

		// Make the span context visible to downstream handlers.
		c.Request = c.Request.WithContext(ctx)

		c.Next()

		t.SetHTTPStatus(span, c.Writer.Status())
		t.AddSpanAttributes(span,
			attribute.Int("http.status_code", c.Writer.Status()),
			attribute.Int("http.response_size", c.Writer.Size()),
		)

		if len(c.Errors) > 0 {
			// The collected message is runtime text, so pass it as an argument
			// rather than as the format string; a '%' in it would otherwise be
			// read as a formatting verb and garble the recorded error.
			t.SetSpanError(span, fmt.Errorf("%s", c.Errors.String()))
		}
	}
}
// GinTraceMiddleware is the legacy name of TraceMiddleware, kept for
// backward compatibility. It returns the same middleware.
func (t *Tracer) GinTraceMiddleware() gin.HandlerFunc {
	middleware := t.TraceMiddleware()
	return middleware
}
// WithTracing is a convenience helper that starts a span called name on
// tracer and returns the resulting context and span.
func WithTracing(ctx context.Context, tracer *Tracer, name string) (context.Context, trace.Span) {
	spanCtx, span := tracer.StartSpan(ctx, name)
	return spanCtx, span
}
// TraceFunction runs fn inside a span called name. The span is marked failed
// if fn returns an error and successful otherwise; fn's error is returned
// unchanged.
func (t *Tracer) TraceFunction(ctx context.Context, name string, fn func(context.Context) error) error {
	ctx, span := t.StartSpan(ctx, name)
	defer span.End()

	if err := fn(ctx); err != nil {
		t.SetSpanError(span, err)
		return err
	}
	t.SetSpanSuccess(span)
	return nil
}
// TraceFunctionWithResult runs fn inside a span called name and passes
// through both of its return values. The span is marked failed if fn returns
// an error and successful otherwise.
func TraceFunctionWithResult[T any](ctx context.Context, tracer *Tracer, name string, fn func(context.Context) (T, error)) (T, error) {
	ctx, span := tracer.StartSpan(ctx, name)
	defer span.End()

	value, err := fn(ctx)
	switch {
	case err != nil:
		tracer.SetSpanError(span, err)
	default:
		tracer.SetSpanSuccess(span)
	}
	return value, err
}