Files
tyapi-server/internal/shared/hooks/hook_system.go

588 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package hooks
import (
"context"
"fmt"
"reflect"
"sort"
"sync"
"time"
"go.uber.org/zap"
)
// HookPriority orders hooks registered for the same event: Register keeps
// them sorted so that a larger value runs earlier. The named levels below are
// spaced 25 apart; Register does not restrict the value, so intermediate
// priorities may be used as well.
type HookPriority int

// Predefined priority levels, from the one that runs last to the one that
// runs first.
const (
	PriorityLowest  HookPriority = 0
	PriorityLow     HookPriority = 25
	PriorityNormal  HookPriority = 50 // the level NewHookBuilder starts with
	PriorityHigh    HookPriority = 75
	PriorityHighest HookPriority = 100
)

// HookFunc is the signature of a hook callback. It receives a context bounded
// by the hook's timeout and the payload passed to Trigger; returning a non-nil
// error marks the execution as failed.
type HookFunc func(ctx context.Context, data interface{}) error

// Hook describes one callback registered for an event.
type Hook struct {
	// Name identifies the hook; Register requires it to be non-empty and
	// unique per event.
	Name string
	// Func is the callback to run; Register rejects nil.
	Func HookFunc
	// Priority decides execution order; higher values run earlier.
	Priority HookPriority
	// Async runs the hook in a background goroutine; Trigger then reports it
	// as successful without waiting for its outcome.
	Async bool
	// Timeout bounds one execution; Register replaces 0 with the system's
	// default timeout.
	Timeout time.Duration
}

// HookResult reports the outcome of one hook execution performed by Trigger.
type HookResult struct {
	HookName string        `json:"hook_name"`
	Success  bool          `json:"success"`
	Duration time.Duration `json:"duration"`
	Error    string        `json:"error,omitempty"` // set only when Success is false
}
// HookConfig holds the system-wide settings of a HookSystem.
type HookConfig struct {
	// DefaultTimeout is given to hooks that are registered without their own
	// timeout.
	DefaultTimeout time.Duration
	// TrackDuration enables accumulation of TotalDuration and AverageDuration
	// in each hook's HookStats.
	TrackDuration bool
	// ErrorStrategy decides how Trigger reacts when a hook fails.
	ErrorStrategy ErrorStrategy
}

// ErrorStrategy selects how Trigger handles failing hooks.
type ErrorStrategy int

const (
	// ContinueOnError runs every hook, logs a warning if any failed, and
	// reports no error to the caller.
	ContinueOnError ErrorStrategy = iota
	// StopOnError stops at the first failing hook and returns its error.
	StopOnError
	// CollectErrors runs every hook and returns one error that lists all
	// failures.
	CollectErrors
)

// DefaultHookConfig returns the default configuration: a 30-second per-hook
// timeout, duration tracking enabled, and the ContinueOnError strategy.
func DefaultHookConfig() HookConfig {
	var cfg HookConfig
	cfg.DefaultTimeout = 30 * time.Second
	cfg.TrackDuration = true
	cfg.ErrorStrategy = ContinueOnError
	return cfg
}
// HookSystem keeps hooks per event name and runs them when an event is
// triggered. The hooks and stats maps are guarded by mutex.
type HookSystem struct {
	// hooks maps an event name to its hooks, kept sorted by descending
	// priority by Register.
	hooks map[string][]*Hook
	// config holds the timeout, duration-tracking and error-strategy settings.
	config HookConfig
	logger *zap.Logger
	// mutex guards hooks and stats.
	mutex sync.RWMutex
	// stats maps "<event>.<hook name>" to that hook's execution statistics.
	stats map[string]*HookStats
}
// HookStats accumulates execution statistics for one registered hook.
type HookStats struct {
	TotalExecutions int           `json:"total_executions"`
	Successes       int           `json:"successes"`
	Failures        int           `json:"failures"`
	TotalDuration   time.Duration `json:"total_duration"`   // accumulated only when HookConfig.TrackDuration is set
	AverageDuration time.Duration `json:"average_duration"` // TotalDuration / TotalExecutions
	LastExecution   time.Time     `json:"last_execution"`
	LastError       string        `json:"last_error,omitempty"` // error text of the most recent failure
}
// NewHookSystem creates an empty HookSystem that uses config and logs
// through logger.
//
// A nil logger is replaced with a no-op logger, because registration, trigger
// and lifecycle methods all log unconditionally and would otherwise panic.
// A zero or negative config.DefaultTimeout is replaced with the value from
// DefaultHookConfig. Hooks registered without their own timeout inherit
// DefaultTimeout, and a non-positive timeout would make them expire almost
// immediately.
func NewHookSystem(config HookConfig, logger *zap.Logger) *HookSystem {
	if logger == nil {
		logger = zap.NewNop()
	}
	if config.DefaultTimeout <= 0 {
		config.DefaultTimeout = DefaultHookConfig().DefaultTimeout
	}
	return &HookSystem{
		hooks:  make(map[string][]*Hook),
		config: config,
		logger: logger,
		stats:  make(map[string]*HookStats),
	}
}
// Register adds hook to the hooks run for event.
//
// Hooks are kept sorted by descending Priority. Hooks with equal priority run
// in the order they were registered. A zero or negative hook.Timeout is
// replaced, on the caller's Hook, with the system's DefaultTimeout. The hook
// is stored by pointer and read by Trigger without further locking, so it
// should not be modified after registration.
//
// It returns an error if hook is nil, has an empty Name or a nil Func, or if
// a hook with the same name is already registered for event.
//
// NOTE(review): statistics are keyed as "<event>.<name>", so event "a.b" with
// hook "c" and event "a" with hook "b.c" share one stats entry.
func (hs *HookSystem) Register(event string, hook *Hook) error {
	if hook == nil {
		return fmt.Errorf("hook cannot be nil")
	}
	if hook.Name == "" {
		return fmt.Errorf("hook name cannot be empty")
	}
	if hook.Func == nil {
		return fmt.Errorf("hook function cannot be nil")
	}
	if hook.Timeout <= 0 {
		hook.Timeout = hs.config.DefaultTimeout
	}

	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	hooks := hs.hooks[event]
	for _, existing := range hooks {
		if existing.Name == hook.Name {
			return fmt.Errorf("hook %s already registered for event %s", hook.Name, event)
		}
	}

	hooks = append(hooks, hook)
	// SliceStable keeps registration order among equal priorities;
	// sort.Slice gives no ordering guarantee for equal elements.
	sort.SliceStable(hooks, func(i, j int) bool {
		return hooks[i].Priority > hooks[j].Priority
	})
	hs.hooks[event] = hooks

	// Start the hook with empty statistics.
	hs.stats[fmt.Sprintf("%s.%s", event, hook.Name)] = &HookStats{}

	hs.logger.Info("Registered hook",
		zap.String("event", event),
		zap.String("hook_name", hook.Name),
		zap.Int("priority", int(hook.Priority)),
		zap.Bool("async", hook.Async))
	return nil
}
// RegisterFunc is a shorthand for Register. It wraps fn in a synchronous Hook
// named name with the given priority and the system's default timeout.
func (hs *HookSystem) RegisterFunc(event, name string, priority HookPriority, fn HookFunc) error {
	return hs.Register(event, &Hook{
		Name:     name,
		Func:     fn,
		Priority: priority,
		Timeout:  hs.config.DefaultTimeout,
	})
}
// RegisterAsyncFunc is like RegisterFunc, but the hook runs in a background
// goroutine, and Trigger reports it as successful without waiting for it.
func (hs *HookSystem) RegisterAsyncFunc(event, name string, priority HookPriority, fn HookFunc) error {
	return hs.Register(event, &Hook{
		Name:     name,
		Func:     fn,
		Priority: priority,
		Async:    true,
		Timeout:  hs.config.DefaultTimeout,
	})
}
// Unregister removes the hook named hookName from event and deletes its
// statistics. When the last hook of an event is removed, the event entry is
// dropped too, so GetEvents no longer reports an event with no hooks.
//
// It returns an error if no hook with that name is registered for event.
func (hs *HookSystem) Unregister(event, hookName string) error {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	hooks := hs.hooks[event]
	idx := -1
	for i, hook := range hooks {
		if hook.Name == hookName {
			idx = i
			break
		}
	}
	if idx < 0 {
		return fmt.Errorf("hook %s not found for event %s", hookName, event)
	}

	// Shift the tail left and clear the vacated last slot, so the backing
	// array does not keep the removed *Hook reachable. Readers only see
	// copies taken under the lock, so editing in place is safe.
	last := len(hooks) - 1
	copy(hooks[idx:], hooks[idx+1:])
	hooks[last] = nil
	hooks = hooks[:last]

	if len(hooks) == 0 {
		delete(hs.hooks, event)
	} else {
		hs.hooks[event] = hooks
	}

	delete(hs.stats, fmt.Sprintf("%s.%s", event, hookName))

	hs.logger.Info("Unregistered hook",
		zap.String("event", event),
		zap.String("hook_name", hookName))
	return nil
}
// Trigger runs every hook registered for event, in priority order, passing
// data to each. It returns one HookResult per hook that was executed.
//
// Failures are handled according to hs.config.ErrorStrategy:
//   - StopOnError: stop at the first failing hook and return its error.
//   - CollectErrors: run all hooks and return one error listing every failure.
//   - ContinueOnError: run all hooks, log a warning if any failed, and return
//     a nil error. Any other strategy value also yields a nil error.
//
// When no hooks are registered for event, Trigger returns (nil, nil).
func (hs *HookSystem) Trigger(ctx context.Context, event string, data interface{}) ([]HookResult, error) {
	// Work on a snapshot so hooks may be (un)registered while this runs.
	snapshot := hs.GetHooks(event)
	if len(snapshot) == 0 {
		hs.logger.Debug("No hooks registered for event", zap.String("event", event))
		return nil, nil
	}

	hs.logger.Debug("Triggering event",
		zap.String("event", event),
		zap.Int("hook_count", len(snapshot)))

	strategy := hs.config.ErrorStrategy
	results := make([]HookResult, 0, len(snapshot))
	var failures []error

	for _, h := range snapshot {
		res := hs.executeHook(ctx, event, h, data)
		results = append(results, res)
		if res.Success {
			continue
		}
		failures = append(failures, fmt.Errorf("hook %s failed: %s", h.Name, res.Error))
		if strategy == StopOnError {
			break
		}
	}

	if len(failures) == 0 {
		return results, nil
	}

	switch strategy {
	case StopOnError:
		return results, failures[0]
	case CollectErrors:
		return results, fmt.Errorf("multiple hook errors: %v", failures)
	case ContinueOnError:
		hs.logger.Warn("Some hooks failed but continuing execution",
			zap.String("event", event),
			zap.Int("error_count", len(failures)))
	}
	return results, nil
}
// executeHook runs a single hook for event, records its statistics, and
// returns its result.
//
// Synchronous hooks run to completion, bounded by hook.Timeout. The result
// reflects their outcome, and Duration covers the whole execution.
//
// Asynchronous hooks are started in a background goroutine and reported as
// successful immediately. Duration then covers only the launch. A later
// failure is logged but affects neither the returned result nor the stats.
func (hs *HookSystem) executeHook(ctx context.Context, event string, hook *Hook, data interface{}) HookResult {
	hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
	start := time.Now()
	result := HookResult{HookName: hook.Name}

	if hook.Async {
		// NOTE(review): the goroutine derives its context from ctx. If the
		// caller cancels ctx once Trigger returns (e.g. a request-scoped
		// context), the async hook is cancelled as well. Goroutines are not
		// tracked, so Shutdown does not wait for them.
		go func() {
			if err := hs.doExecuteHook(ctx, hook, data); err != nil {
				hs.logger.Error("Async hook execution failed",
					zap.String("event", event),
					zap.String("hook_name", hook.Name),
					zap.Error(err))
			}
		}()
		result.Success = true
	} else if err := hs.doExecuteHook(ctx, hook, data); err != nil {
		result.Error = err.Error()
		hs.logger.Error("Hook execution failed",
			zap.String("event", event),
			zap.String("hook_name", hook.Name),
			zap.Error(err))
	} else {
		result.Success = true
		hs.logger.Debug("Hook executed successfully",
			zap.String("event", event),
			zap.String("hook_name", hook.Name))
	}

	// Duration must be set on the value that is returned. Assigning it in a
	// deferred closure would only modify the local after the return value had
	// already been copied, so callers would always see 0.
	result.Duration = time.Since(start)
	hs.updateStats(hookKey, result)
	return result
}
// doExecuteHook calls hook.Func with a context bounded by hook.Timeout. It
// waits until the function returns, the timeout elapses, or ctx is cancelled.
//
// A panic inside hook.Func is recovered and returned as an error. If the wait
// is abandoned, the goroutine running hook.Func is not stopped. It keeps
// running until hook.Func returns (hook.Func can notice the abandonment only
// through its context, which is cancelled on return), and its result is
// discarded. errChan is buffered so that late send never blocks.
func (hs *HookSystem) doExecuteHook(ctx context.Context, hook *Hook, data interface{}) error {
	hookCtx, cancel := context.WithTimeout(ctx, hook.Timeout)
	defer cancel()

	errChan := make(chan error, 1)
	go func() {
		// completed stays false if hook.Func panics (including panic(nil)) or
		// calls runtime.Goexit. The deferred func then always sends exactly one
		// error, instead of sending nothing and leaving the caller waiting for
		// the full timeout.
		var err error
		completed := false
		defer func() {
			if !completed {
				err = fmt.Errorf("hook panicked: %v", recover())
			}
			errChan <- err
		}()
		err = hook.Func(hookCtx, data)
		completed = true
	}()

	select {
	case err := <-errChan:
		return err
	case <-hookCtx.Done():
		// Distinguish the caller cancelling (or its own, earlier deadline
		// expiring) from this hook exceeding its own timeout.
		if parentErr := ctx.Err(); parentErr != nil {
			return fmt.Errorf("hook execution cancelled: %w", parentErr)
		}
		return fmt.Errorf("hook execution timeout after %v", hook.Timeout)
	}
}
// updateStats folds result into the statistics stored under hookKey. It
// creates the entry if none exists. Durations are accumulated only when
// hs.config.TrackDuration is set.
func (hs *HookSystem) updateStats(hookKey string, result HookResult) {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	s, found := hs.stats[hookKey]
	if !found {
		s = new(HookStats)
		hs.stats[hookKey] = s
	}

	s.TotalExecutions++
	s.LastExecution = time.Now()
	if !result.Success {
		s.Failures++
		s.LastError = result.Error
	} else {
		s.Successes++
	}

	if !hs.config.TrackDuration {
		return
	}
	s.TotalDuration += result.Duration
	s.AverageDuration = s.TotalDuration / time.Duration(s.TotalExecutions)
}
// GetHooks returns a snapshot of the hooks registered for event, in execution
// order. The slice is a fresh copy (non-nil even when empty), but the *Hook
// values are shared with the system.
func (hs *HookSystem) GetHooks(event string) []*Hook {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	registered := hs.hooks[event]
	return append(make([]*Hook, 0, len(registered)), registered...)
}
// GetEvents returns the names of all events that have an entry in the hook
// registry, sorted alphabetically.
func (hs *HookSystem) GetEvents() []string {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	names := make([]string, len(hs.hooks))
	i := 0
	for name := range hs.hooks {
		names[i] = name
		i++
	}
	sort.Strings(names)
	return names
}
// GetStats returns a copy of every hook's statistics, keyed by
// "<event>.<hook name>". Each entry is copied, so callers may read or modify
// the result without holding the system's lock.
func (hs *HookSystem) GetStats() map[string]*HookStats {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	snapshot := make(map[string]*HookStats, len(hs.stats))
	for key, live := range hs.stats {
		cp := new(HookStats)
		*cp = *live
		snapshot[key] = cp
	}
	return snapshot
}
// GetEventStats returns a copy of the statistics of the hooks currently
// registered for event, keyed by hook name.
//
// Each hook's stats are looked up by its exact "<event>.<hook name>" key
// rather than by the "<event>." prefix. With prefix matching, a query for
// event "user" would also return stats of hooks registered for
// "user.created".
func (hs *HookSystem) GetEventStats(event string) map[string]*HookStats {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	hooks := hs.hooks[event]
	eventStats := make(map[string]*HookStats, len(hooks))
	for _, hook := range hooks {
		stat, ok := hs.stats[fmt.Sprintf("%s.%s", event, hook.Name)]
		if !ok {
			continue
		}
		statCopy := *stat
		eventStats[hook.Name] = &statCopy
	}
	return eventStats
}
// Clear removes every registered hook and all collected statistics.
func (hs *HookSystem) Clear() {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	hs.hooks, hs.stats = map[string][]*Hook{}, map[string]*HookStats{}
	hs.logger.Info("Cleared all hooks")
}
// ClearEvent removes every hook registered for event, together with their
// statistics.
//
// Stats are deleted by each hook's exact "<event>.<hook name>" key. Deleting
// by the "<event>." prefix would also wipe stats of other events, e.g.
// clearing "user" would erase the stats of hooks of "user.created".
func (hs *HookSystem) ClearEvent(event string) {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	for _, hook := range hs.hooks[event] {
		delete(hs.stats, fmt.Sprintf("%s.%s", event, hook.Name))
	}
	delete(hs.hooks, event)

	hs.logger.Info("Cleared hooks for event", zap.String("event", event))
}
// Count returns the total number of hooks registered across all events.
func (hs *HookSystem) Count() (n int) {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	for _, list := range hs.hooks {
		n += len(list)
	}
	return n
}
// EventCount returns the number of hooks registered for event (0 if none).
func (hs *HookSystem) EventCount(event string) int {
	hs.mutex.RLock()
	n := len(hs.hooks[event])
	hs.mutex.RUnlock()
	return n
}
// The methods below let HookSystem be managed as a service (presumably
// satisfying a Service interface declared elsewhere in the project). Apart
// from logging, they perform no work.

// Name returns the service identifier "hook-system".
func (hs *HookSystem) Name() string {
	const serviceName = "hook-system"
	return serviceName
}

// Initialize logs that the hook system was initialized and always returns nil.
func (hs *HookSystem) Initialize(ctx context.Context) error {
	hs.logger.Info("Hook system initialized")
	return nil
}

// Start logs that the hook system was started and always returns nil.
func (hs *HookSystem) Start(ctx context.Context) error {
	hs.logger.Info("Hook system started")
	return nil
}

// HealthCheck always reports the system as healthy; it checks nothing.
func (hs *HookSystem) HealthCheck(ctx context.Context) error {
	return nil
}

// Shutdown logs that the hook system is shutting down and always returns nil.
// It does not wait for asynchronous hooks that may still be running.
func (hs *HookSystem) Shutdown(ctx context.Context) error {
	hs.logger.Info("Hook system shutdown")
	return nil
}
// Convenience registrars: each one registers fn as a synchronous hook, via
// RegisterFunc, for a fixed event name.

// OnUserCreated registers fn for the "user.created" event.
func (hs *HookSystem) OnUserCreated(name string, priority HookPriority, fn HookFunc) error {
	const event = "user.created"
	return hs.RegisterFunc(event, name, priority, fn)
}

// OnUserUpdated registers fn for the "user.updated" event.
func (hs *HookSystem) OnUserUpdated(name string, priority HookPriority, fn HookFunc) error {
	const event = "user.updated"
	return hs.RegisterFunc(event, name, priority, fn)
}

// OnUserDeleted registers fn for the "user.deleted" event.
func (hs *HookSystem) OnUserDeleted(name string, priority HookPriority, fn HookFunc) error {
	const event = "user.deleted"
	return hs.RegisterFunc(event, name, priority, fn)
}

// OnOrderCreated registers fn for the "order.created" event.
func (hs *HookSystem) OnOrderCreated(name string, priority HookPriority, fn HookFunc) error {
	const event = "order.created"
	return hs.RegisterFunc(event, name, priority, fn)
}

// OnOrderCompleted registers fn for the "order.completed" event.
func (hs *HookSystem) OnOrderCompleted(name string, priority HookPriority, fn HookFunc) error {
	const event = "order.completed"
	return hs.RegisterFunc(event, name, priority, fn)
}
// TriggerUserCreated fires the "user.created" event with user as payload.
func (hs *HookSystem) TriggerUserCreated(ctx context.Context, user interface{}) ([]HookResult, error) {
	const event = "user.created"
	return hs.Trigger(ctx, event, user)
}

// TriggerUserUpdated fires the "user.updated" event with user as payload.
func (hs *HookSystem) TriggerUserUpdated(ctx context.Context, user interface{}) ([]HookResult, error) {
	const event = "user.updated"
	return hs.Trigger(ctx, event, user)
}

// TriggerUserDeleted fires the "user.deleted" event with user as payload.
func (hs *HookSystem) TriggerUserDeleted(ctx context.Context, user interface{}) ([]HookResult, error) {
	const event = "user.deleted"
	return hs.Trigger(ctx, event, user)
}
// HookBuilder assembles a Hook through chained setter calls.
type HookBuilder struct {
	hook *Hook
}

// NewHookBuilder starts building a synchronous hook named name that runs fn,
// with PriorityNormal and a 30-second timeout.
func NewHookBuilder(name string, fn HookFunc) *HookBuilder {
	return &HookBuilder{
		hook: &Hook{
			Name:     name,
			Func:     fn,
			Priority: PriorityNormal,
			Async:    false,
			Timeout:  30 * time.Second,
		},
	}
}

// WithPriority sets the hook's priority and returns the builder.
func (hb *HookBuilder) WithPriority(priority HookPriority) *HookBuilder {
	hb.hook.Priority = priority
	return hb
}

// WithTimeout sets the per-execution timeout and returns the builder.
func (hb *HookBuilder) WithTimeout(timeout time.Duration) *HookBuilder {
	hb.hook.Timeout = timeout
	return hb
}

// Async marks the hook to run in a background goroutine and returns the
// builder.
func (hb *HookBuilder) Async() *HookBuilder {
	hb.hook.Async = true
	return hb
}

// Build returns a new *Hook holding the builder's current settings.
//
// Each call returns a distinct copy. Using the builder afterwards, or calling
// Build again, therefore never mutates a Hook that was already handed out. A
// registered hook's fields are read during Trigger without synchronization
// with the builder.
func (hb *HookBuilder) Build() *Hook {
	built := *hb.hook
	return &built
}
// TypedHookFunc is a hook callback that receives its payload as a concrete
// type T instead of interface{}.
type TypedHookFunc[T any] func(ctx context.Context, data T) error

// RegisterTypedFunc registers fn as a synchronous hook for event. fn is
// wrapped in a HookFunc that type-asserts the triggered data to T.
//
// If the data passed to Trigger is not a T (a nil interface value never is),
// the hook fails with an error naming the expected and the actual type, and
// fn is not called. A nil fn is rejected at registration time.
func RegisterTypedFunc[T any](hs *HookSystem, event, name string, priority HookPriority, fn TypedHookFunc[T]) error {
	if fn == nil {
		return fmt.Errorf("hook function cannot be nil")
	}

	// reflect.Type.String, unlike Name, is non-empty for unnamed types such
	// as *User, []int or map[string]any, so the error always names the
	// expected type. It is computed once at registration.
	expected := reflect.TypeOf((*T)(nil)).Elem().String()

	hookFunc := func(ctx context.Context, data interface{}) error {
		typedData, ok := data.(T)
		if !ok {
			return fmt.Errorf("invalid data type for hook %s, expected %s, got %T", name, expected, data)
		}
		return fn(ctx, typedData)
	}
	return hs.RegisterFunc(event, name, priority, hookFunc)
}