588 lines
14 KiB
Go
588 lines
14 KiB
Go
package hooks
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"reflect"
|
||
"sort"
|
||
"sync"
|
||
"time"
|
||
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// HookPriority orders the hooks registered for one event: hooks with a
// numerically higher priority run before hooks with a lower one.
type HookPriority int

// Predefined priority levels, spaced 25 apart from 0 to 100. Ordering is purely
// numeric, so values in between can be used as well.
const (
	PriorityLowest  HookPriority = iota * 25 // 0: runs last
	PriorityLow                              // 25
	PriorityNormal                           // 50
	PriorityHigh                             // 75
	PriorityHighest                          // 100: runs first
)
|
||
|
||
// HookFunc is the callback executed when a hook fires.
//
// ctx is derived from the context passed to Trigger and additionally carries
// the hook's Timeout; data is the payload passed to Trigger unchanged.
// Returning a non-nil error marks that execution as failed.
type HookFunc func(ctx context.Context, data interface{}) error
|
||
|
||
// Hook describes one callback attached to an event.
type Hook struct {
	// Name identifies the hook; Register requires it to be non-empty and
	// unique within an event.
	Name string

	// Func is the callback to run; Register rejects a nil Func.
	Func HookFunc

	// Priority orders hooks of the same event; higher values run first.
	Priority HookPriority

	// Async runs the hook on a background goroutine. Trigger does not wait
	// for it and reports it as successful.
	Async bool

	// Timeout bounds one execution of Func. Register replaces a zero value
	// with HookConfig.DefaultTimeout.
	Timeout time.Duration
}
|
||
|
||
// HookResult reports the outcome of one hook execution within a Trigger call.
type HookResult struct {
	// HookName is the Name of the executed hook.
	HookName string `json:"hook_name"`

	// Success is true when the hook returned nil. Async hooks are always
	// reported as successful because Trigger does not wait for them.
	Success bool `json:"success"`

	// Duration is the time spent executing the hook. time.Duration is an
	// int64, so encoding/json emits it as a number of nanoseconds.
	// NOTE(review): confirm executeHook fills this in on the value it
	// returns, not only on the copy it records in the stats.
	Duration time.Duration `json:"duration"`

	// Error holds the error text of a failed execution; empty on success.
	Error string `json:"error,omitempty"`
}
|
||
|
||
// HookConfig configures a HookSystem.
type HookConfig struct {
	// DefaultTimeout is the timeout given to hooks that are registered with
	// a zero Timeout.
	DefaultTimeout time.Duration

	// TrackDuration enables accumulation of TotalDuration/AverageDuration in
	// the per-hook statistics.
	TrackDuration bool

	// ErrorStrategy decides how Trigger reacts to failing hooks.
	ErrorStrategy ErrorStrategy
}
|
||
|
||
// ErrorStrategy selects how Trigger reacts when a hook fails.
type ErrorStrategy int

// Supported error strategies. The values are spelled out explicitly;
// ContinueOnError is the zero value of ErrorStrategy.
const (
	// ContinueOnError runs every hook, logs a warning with the number of
	// failures and makes Trigger return a nil error.
	ContinueOnError ErrorStrategy = 0

	// StopOnError stops at the first failing hook; Trigger returns that
	// hook's error.
	StopOnError ErrorStrategy = 1

	// CollectErrors runs every hook; Trigger returns a single error listing
	// all failures.
	CollectErrors ErrorStrategy = 2
)
|
||
|
||
// DefaultHookConfig 默认钩子配置
|
||
func DefaultHookConfig() HookConfig {
|
||
return HookConfig{
|
||
DefaultTimeout: 30 * time.Second,
|
||
TrackDuration: true,
|
||
ErrorStrategy: ContinueOnError,
|
||
}
|
||
}
|
||
|
||
// HookSystem stores hooks per event name, runs them on Trigger and keeps
// per-hook execution statistics.
type HookSystem struct {
	// hooks maps an event name to its hooks; Register keeps each slice
	// sorted by descending Priority.
	hooks map[string][]*Hook

	// config holds the defaults and error strategy given to NewHookSystem.
	config HookConfig

	// logger receives registration, execution and lifecycle log entries.
	logger *zap.Logger

	// mutex guards hooks and stats.
	mutex sync.RWMutex

	// stats maps "<event>.<hook name>" to that hook's statistics.
	// NOTE(review): this key is ambiguous when names contain "." (event
	// "a.b" + hook "c" collides with event "a" + hook "b.c").
	stats map[string]*HookStats
}
|
||
|
||
// HookStats aggregates execution statistics for a single hook.
type HookStats struct {
	// TotalExecutions counts every recorded execution (Successes + Failures).
	TotalExecutions int `json:"total_executions"`

	// Successes counts executions that reported success.
	Successes int `json:"successes"`

	// Failures counts executions that reported failure.
	Failures int `json:"failures"`

	// TotalDuration sums execution durations; only updated when
	// HookConfig.TrackDuration is true.
	TotalDuration time.Duration `json:"total_duration"`

	// AverageDuration is TotalDuration / TotalExecutions; only updated when
	// HookConfig.TrackDuration is true.
	AverageDuration time.Duration `json:"average_duration"`

	// LastExecution is the wall-clock time at which the latest execution
	// was recorded.
	LastExecution time.Time `json:"last_execution"`

	// LastError is the error text of the most recent failed execution.
	LastError string `json:"last_error,omitempty"`
}
|
||
|
||
// NewHookSystem 创建钩子系统
|
||
func NewHookSystem(config HookConfig, logger *zap.Logger) *HookSystem {
|
||
return &HookSystem{
|
||
hooks: make(map[string][]*Hook),
|
||
config: config,
|
||
logger: logger,
|
||
stats: make(map[string]*HookStats),
|
||
}
|
||
}
|
||
|
||
// Register attaches hook to event.
//
// It returns an error in these cases:
//   - hook is nil;
//   - hook.Name is empty or hook.Func is nil;
//   - a hook with the same Name is already registered for event.
//
// A zero hook.Timeout is replaced by the system's DefaultTimeout. This writes
// to the caller's *Hook, and the write happens before validation finishes.
//
// Hooks of an event are kept ordered by descending Priority. A stable sort is
// used so hooks that share a priority keep their registration order;
// sort.Slice does not guarantee that. Registering also resets the hook's
// stats entry under "<event>.<hook name>".
func (hs *HookSystem) Register(event string, hook *Hook) error {
	if hook == nil {
		return fmt.Errorf("hook cannot be nil")
	}
	if hook.Name == "" {
		return fmt.Errorf("hook name cannot be empty")
	}
	if hook.Func == nil {
		return fmt.Errorf("hook function cannot be nil")
	}
	if hook.Timeout == 0 {
		hook.Timeout = hs.config.DefaultTimeout
	}

	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	// Names must be unique per event; they are also part of the stats key.
	for _, existing := range hs.hooks[event] {
		if existing.Name == hook.Name {
			return fmt.Errorf("hook %s already registered for event %s", hook.Name, event)
		}
	}

	hooks := append(hs.hooks[event], hook)
	sort.SliceStable(hooks, func(i, j int) bool {
		return hooks[i].Priority > hooks[j].Priority
	})
	hs.hooks[event] = hooks

	hookKey := fmt.Sprintf("%s.%s", event, hook.Name)
	hs.stats[hookKey] = &HookStats{}

	hs.logger.Info("Registered hook",
		zap.String("event", event),
		zap.String("hook_name", hook.Name),
		zap.Int("priority", int(hook.Priority)),
		zap.Bool("async", hook.Async))

	return nil
}
|
||
|
||
// RegisterFunc 注册钩子函数(简化版)
|
||
func (hs *HookSystem) RegisterFunc(event, name string, priority HookPriority, fn HookFunc) error {
|
||
hook := &Hook{
|
||
Name: name,
|
||
Func: fn,
|
||
Priority: priority,
|
||
Async: false,
|
||
Timeout: hs.config.DefaultTimeout,
|
||
}
|
||
|
||
return hs.Register(event, hook)
|
||
}
|
||
|
||
// RegisterAsyncFunc 注册异步钩子函数
|
||
func (hs *HookSystem) RegisterAsyncFunc(event, name string, priority HookPriority, fn HookFunc) error {
|
||
hook := &Hook{
|
||
Name: name,
|
||
Func: fn,
|
||
Priority: priority,
|
||
Async: true,
|
||
Timeout: hs.config.DefaultTimeout,
|
||
}
|
||
|
||
return hs.Register(event, hook)
|
||
}
|
||
|
||
// Unregister removes the hook named hookName from event and deletes its
// statistics. It returns an error if no such hook is registered.
//
// When the last hook of an event is removed, the event entry itself is
// deleted. Otherwise GetEvents would keep listing an event that has no hooks.
func (hs *HookSystem) Unregister(event, hookName string) error {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	hooks := hs.hooks[event]
	for i, hook := range hooks {
		if hook.Name != hookName {
			continue
		}

		// Shift the tail left in place, then clear the vacated last slot.
		// This stops the backing array from keeping the removed hook
		// reachable. Readers in this file copy the slice under the lock, so
		// editing it in place is safe.
		copy(hooks[i:], hooks[i+1:])
		hooks[len(hooks)-1] = nil
		hooks = hooks[:len(hooks)-1]

		if len(hooks) == 0 {
			delete(hs.hooks, event)
		} else {
			hs.hooks[event] = hooks
		}

		hookKey := fmt.Sprintf("%s.%s", event, hookName)
		delete(hs.stats, hookKey)

		hs.logger.Info("Unregistered hook",
			zap.String("event", event),
			zap.String("hook_name", hookName))

		return nil
	}

	return fmt.Errorf("hook %s not found for event %s", hookName, event)
}
|
||
|
||
// Trigger runs the hooks registered for event in descending priority order.
// It passes ctx and data to each hook and returns one HookResult per hook
// that was executed.
//
// The hook list is copied under the read lock before anything runs, so hooks
// registered or removed concurrently do not affect this call. With no hooks
// registered it returns (nil, nil). Async hooks are reported as successful.
//
// Failures are handled according to HookConfig.ErrorStrategy:
//   - StopOnError: execution stops at the first failing hook and its error
//     is returned.
//   - CollectErrors: every hook runs and one error listing all failures is
//     returned.
//   - ContinueOnError: every hook runs, a warning with the failure count is
//     logged and the returned error is nil.
//
// Any other strategy value runs every hook and returns nil without logging.
func (hs *HookSystem) Trigger(ctx context.Context, event string, data interface{}) ([]HookResult, error) {
	hs.mutex.RLock()
	registered := hs.hooks[event]
	snapshot := make([]*Hook, len(registered))
	copy(snapshot, registered)
	hs.mutex.RUnlock()

	if len(snapshot) == 0 {
		hs.logger.Debug("No hooks registered for event", zap.String("event", event))
		return nil, nil
	}

	hs.logger.Debug("Triggering event",
		zap.String("event", event),
		zap.Int("hook_count", len(snapshot)))

	strategy := hs.config.ErrorStrategy
	results := make([]HookResult, 0, len(snapshot))
	var failures []error

	for _, h := range snapshot {
		res := hs.executeHook(ctx, event, h, data)
		results = append(results, res)
		if res.Success {
			continue
		}

		failures = append(failures, fmt.Errorf("hook %s failed: %s", h.Name, res.Error))
		if strategy == StopOnError {
			break
		}
	}

	if len(failures) == 0 {
		return results, nil
	}

	switch strategy {
	case StopOnError:
		return results, failures[0]
	case CollectErrors:
		return results, fmt.Errorf("multiple hook errors: %v", failures)
	case ContinueOnError:
		hs.logger.Warn("Some hooks failed but continuing execution",
			zap.String("event", event),
			zap.Int("error_count", len(failures)))
	}

	return results, nil
}
|
||
|
||
// executeHook runs one hook for event and returns the HookResult reported to
// Trigger.
//
// Synchronous hooks run inline through doExecuteHook. The returned result
// holds the outcome and the measured Duration, and the same result is folded
// into the stats stored under "<event>.<hook name>".
//
// Asynchronous hooks are started on a new goroutine. They are reported to the
// caller immediately as successful with a zero Duration, because their
// outcome is not known yet. When such a hook finishes, the goroutine records
// its real outcome and duration in the stats and logs any failure. The error
// is therefore not discarded.
func (hs *HookSystem) executeHook(ctx context.Context, event string, hook *Hook, data interface{}) HookResult {
	hookKey := fmt.Sprintf("%s.%s", event, hook.Name)

	// run executes the hook once and builds its result. Duration is set here,
	// before the value is returned. Assigning it in a deferred closure after
	// `return result` would leave the caller's copy at zero.
	run := func() (HookResult, error) {
		start := time.Now()
		err := hs.doExecuteHook(ctx, hook, data)
		result := HookResult{
			HookName: hook.Name,
			Success:  err == nil,
			Duration: time.Since(start),
		}
		if err != nil {
			result.Error = err.Error()
		}
		return result, err
	}

	if hook.Async {
		// NOTE(review): the goroutine uses the caller's ctx. Cancelling that
		// context after Trigger returns also cancels async hooks still running.
		go func() {
			result, err := run()
			if err != nil {
				hs.logger.Error("Async hook execution failed",
					zap.String("event", event),
					zap.String("hook_name", hook.Name),
					zap.Error(err))
			}
			hs.updateStats(hookKey, result)
		}()
		return HookResult{HookName: hook.Name, Success: true}
	}

	result, err := run()
	if err != nil {
		hs.logger.Error("Hook execution failed",
			zap.String("event", event),
			zap.String("hook_name", hook.Name),
			zap.Error(err))
	} else {
		hs.logger.Debug("Hook executed successfully",
			zap.String("event", event),
			zap.String("hook_name", hook.Name))
	}

	hs.updateStats(hookKey, result)
	return result
}
|
||
|
||
// doExecuteHook 实际执行钩子
|
||
func (hs *HookSystem) doExecuteHook(ctx context.Context, hook *Hook, data interface{}) error {
|
||
// 设置超时上下文
|
||
hookCtx, cancel := context.WithTimeout(ctx, hook.Timeout)
|
||
defer cancel()
|
||
|
||
// 在goroutine中执行,以便处理超时
|
||
errChan := make(chan error, 1)
|
||
go func() {
|
||
defer func() {
|
||
if r := recover(); r != nil {
|
||
errChan <- fmt.Errorf("hook panicked: %v", r)
|
||
}
|
||
}()
|
||
|
||
errChan <- hook.Func(hookCtx, data)
|
||
}()
|
||
|
||
select {
|
||
case err := <-errChan:
|
||
return err
|
||
case <-hookCtx.Done():
|
||
return fmt.Errorf("hook execution timeout after %v", hook.Timeout)
|
||
}
|
||
}
|
||
|
||
// updateStats folds one execution result into the statistics stored under
// hookKey. It creates the entry if it does not exist yet.
//
// Duration totals and the running average are only maintained when
// HookConfig.TrackDuration is enabled.
func (hs *HookSystem) updateStats(hookKey string, result HookResult) {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	entry, ok := hs.stats[hookKey]
	if !ok {
		entry = new(HookStats)
		hs.stats[hookKey] = entry
	}

	entry.TotalExecutions++
	entry.LastExecution = time.Now()

	switch {
	case result.Success:
		entry.Successes++
	default:
		entry.Failures++
		entry.LastError = result.Error
	}

	if !hs.config.TrackDuration {
		return
	}
	entry.TotalDuration += result.Duration
	entry.AverageDuration = entry.TotalDuration / time.Duration(entry.TotalExecutions)
}
|
||
|
||
// GetHooks 获取事件的所有钩子
|
||
func (hs *HookSystem) GetHooks(event string) []*Hook {
|
||
hs.mutex.RLock()
|
||
defer hs.mutex.RUnlock()
|
||
|
||
hooks := make([]*Hook, len(hs.hooks[event]))
|
||
copy(hooks, hs.hooks[event])
|
||
return hooks
|
||
}
|
||
|
||
// GetEvents returns the names of all events present in the registry, sorted
// alphabetically. The result is never nil.
func (hs *HookSystem) GetEvents() []string {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	names := make([]string, len(hs.hooks))
	idx := 0
	for name := range hs.hooks {
		names[idx] = name
		idx++
	}

	sort.Strings(names)
	return names
}
|
||
|
||
// GetStats 获取钩子统计信息
|
||
func (hs *HookSystem) GetStats() map[string]*HookStats {
|
||
hs.mutex.RLock()
|
||
defer hs.mutex.RUnlock()
|
||
|
||
stats := make(map[string]*HookStats)
|
||
for key, stat := range hs.stats {
|
||
statCopy := *stat
|
||
stats[key] = &statCopy
|
||
}
|
||
|
||
return stats
|
||
}
|
||
|
||
// GetEventStats returns copies of the statistics of the hooks currently
// registered for event, keyed by hook name. Hooks without a stats entry are
// omitted, and an unknown event yields an empty map.
//
// Each entry is looked up by its exact key "<event>.<hook name>". A prefix
// scan for "<event>." would also match hooks of other events whose names
// extend this one; for example, GetEventStats("user") would have returned the
// "user.created" hooks as "created.<name>".
func (hs *HookSystem) GetEventStats(event string) map[string]*HookStats {
	hs.mutex.RLock()
	defer hs.mutex.RUnlock()

	hooks := hs.hooks[event]
	eventStats := make(map[string]*HookStats, len(hooks))
	for _, hook := range hooks {
		stat, ok := hs.stats[fmt.Sprintf("%s.%s", event, hook.Name)]
		if !ok {
			continue
		}
		statCopy := *stat
		eventStats[hook.Name] = &statCopy
	}

	return eventStats
}
|
||
|
||
// Clear 清除所有钩子
|
||
func (hs *HookSystem) Clear() {
|
||
hs.mutex.Lock()
|
||
defer hs.mutex.Unlock()
|
||
|
||
hs.hooks = make(map[string][]*Hook)
|
||
hs.stats = make(map[string]*HookStats)
|
||
|
||
hs.logger.Info("Cleared all hooks")
|
||
}
|
||
|
||
// ClearEvent removes every hook registered for event, together with the
// statistics of those hooks.
//
// Stats are deleted by their exact key "<event>.<hook name>". A prefix scan
// for "<event>." would also wipe the statistics of unrelated events whose
// names extend this one; for example, ClearEvent("user") would have hit every
// "user.created" hook.
func (hs *HookSystem) ClearEvent(event string) {
	hs.mutex.Lock()
	defer hs.mutex.Unlock()

	for _, hook := range hs.hooks[event] {
		delete(hs.stats, fmt.Sprintf("%s.%s", event, hook.Name))
	}
	delete(hs.hooks, event)

	hs.logger.Info("Cleared hooks for event", zap.String("event", event))
}
|
||
|
||
// Count 获取钩子总数
|
||
func (hs *HookSystem) Count() int {
|
||
hs.mutex.RLock()
|
||
defer hs.mutex.RUnlock()
|
||
|
||
total := 0
|
||
for _, hooks := range hs.hooks {
|
||
total += len(hooks)
|
||
}
|
||
|
||
return total
|
||
}
|
||
|
||
// EventCount 获取特定事件的钩子数量
|
||
func (hs *HookSystem) EventCount(event string) int {
|
||
hs.mutex.RLock()
|
||
defer hs.mutex.RUnlock()
|
||
|
||
return len(hs.hooks[event])
|
||
}
|
||
|
||
// 实现Service接口
|
||
|
||
// Name 返回服务名称
|
||
func (hs *HookSystem) Name() string {
|
||
return "hook-system"
|
||
}
|
||
|
||
// Initialize 初始化钩子系统
|
||
func (hs *HookSystem) Initialize(ctx context.Context) error {
|
||
hs.logger.Info("Hook system initialized")
|
||
return nil
|
||
}
|
||
|
||
// Start 启动钩子系统
|
||
func (hs *HookSystem) Start(ctx context.Context) error {
|
||
hs.logger.Info("Hook system started")
|
||
return nil
|
||
}
|
||
|
||
// HealthCheck 健康检查
|
||
func (hs *HookSystem) HealthCheck(ctx context.Context) error {
|
||
return nil
|
||
}
|
||
|
||
// Shutdown 关闭钩子系统
|
||
func (hs *HookSystem) Shutdown(ctx context.Context) error {
|
||
hs.logger.Info("Hook system shutdown")
|
||
return nil
|
||
}
|
||
|
||
// 便捷方法
|
||
|
||
// OnUserCreated 用户创建事件钩子
|
||
func (hs *HookSystem) OnUserCreated(name string, priority HookPriority, fn HookFunc) error {
|
||
return hs.RegisterFunc("user.created", name, priority, fn)
|
||
}
|
||
|
||
// OnUserUpdated 用户更新事件钩子
|
||
func (hs *HookSystem) OnUserUpdated(name string, priority HookPriority, fn HookFunc) error {
|
||
return hs.RegisterFunc("user.updated", name, priority, fn)
|
||
}
|
||
|
||
// OnUserDeleted 用户删除事件钩子
|
||
func (hs *HookSystem) OnUserDeleted(name string, priority HookPriority, fn HookFunc) error {
|
||
return hs.RegisterFunc("user.deleted", name, priority, fn)
|
||
}
|
||
|
||
// OnOrderCreated 订单创建事件钩子
|
||
func (hs *HookSystem) OnOrderCreated(name string, priority HookPriority, fn HookFunc) error {
|
||
return hs.RegisterFunc("order.created", name, priority, fn)
|
||
}
|
||
|
||
// OnOrderCompleted 订单完成事件钩子
|
||
func (hs *HookSystem) OnOrderCompleted(name string, priority HookPriority, fn HookFunc) error {
|
||
return hs.RegisterFunc("order.completed", name, priority, fn)
|
||
}
|
||
|
||
// TriggerUserCreated 触发用户创建事件
|
||
func (hs *HookSystem) TriggerUserCreated(ctx context.Context, user interface{}) ([]HookResult, error) {
|
||
return hs.Trigger(ctx, "user.created", user)
|
||
}
|
||
|
||
// TriggerUserUpdated 触发用户更新事件
|
||
func (hs *HookSystem) TriggerUserUpdated(ctx context.Context, user interface{}) ([]HookResult, error) {
|
||
return hs.Trigger(ctx, "user.updated", user)
|
||
}
|
||
|
||
// TriggerUserDeleted 触发用户删除事件
|
||
func (hs *HookSystem) TriggerUserDeleted(ctx context.Context, user interface{}) ([]HookResult, error) {
|
||
return hs.Trigger(ctx, "user.deleted", user)
|
||
}
|
||
|
||
// HookBuilder 钩子构建器
|
||
type HookBuilder struct {
|
||
hook *Hook
|
||
}
|
||
|
||
// NewHookBuilder 创建钩子构建器
|
||
func NewHookBuilder(name string, fn HookFunc) *HookBuilder {
|
||
return &HookBuilder{
|
||
hook: &Hook{
|
||
Name: name,
|
||
Func: fn,
|
||
Priority: PriorityNormal,
|
||
Async: false,
|
||
Timeout: 30 * time.Second,
|
||
},
|
||
}
|
||
}
|
||
|
||
// WithPriority 设置优先级
|
||
func (hb *HookBuilder) WithPriority(priority HookPriority) *HookBuilder {
|
||
hb.hook.Priority = priority
|
||
return hb
|
||
}
|
||
|
||
// WithTimeout 设置超时时间
|
||
func (hb *HookBuilder) WithTimeout(timeout time.Duration) *HookBuilder {
|
||
hb.hook.Timeout = timeout
|
||
return hb
|
||
}
|
||
|
||
// Async 设置为异步执行
|
||
func (hb *HookBuilder) Async() *HookBuilder {
|
||
hb.hook.Async = true
|
||
return hb
|
||
}
|
||
|
||
// Build 构建钩子
|
||
func (hb *HookBuilder) Build() *Hook {
|
||
return hb.hook
|
||
}
|
||
|
||
// TypedHookFunc 类型化钩子函数
|
||
type TypedHookFunc[T any] func(ctx context.Context, data T) error
|
||
|
||
// RegisterTypedFunc 注册类型化钩子函数
|
||
func RegisterTypedFunc[T any](hs *HookSystem, event, name string, priority HookPriority, fn TypedHookFunc[T]) error {
|
||
hookFunc := func(ctx context.Context, data interface{}) error {
|
||
typedData, ok := data.(T)
|
||
if !ok {
|
||
return fmt.Errorf("invalid data type for hook %s, expected %s", name, reflect.TypeOf((*T)(nil)).Elem().Name())
|
||
}
|
||
return fn(ctx, typedData)
|
||
}
|
||
|
||
return hs.RegisterFunc(event, name, priority, hookFunc)
|
||
}
|