package implementations import ( "context" "encoding/json" "fmt" "time" "github.com/hibiken/asynq" "go.uber.org/zap" "tyapi-server/internal/infrastructure/task/entities" "tyapi-server/internal/infrastructure/task/interfaces" "tyapi-server/internal/infrastructure/task/repositories" "tyapi-server/internal/infrastructure/task/types" ) // TaskManagerImpl 任务管理器实现 type TaskManagerImpl struct { asynqClient *asynq.Client asyncTaskRepo repositories.AsyncTaskRepository logger *zap.Logger config *interfaces.TaskManagerConfig } // NewTaskManager 创建任务管理器 func NewTaskManager( asynqClient *asynq.Client, asyncTaskRepo repositories.AsyncTaskRepository, logger *zap.Logger, config *interfaces.TaskManagerConfig, ) interfaces.TaskManager { return &TaskManagerImpl{ asynqClient: asynqClient, asyncTaskRepo: asyncTaskRepo, logger: logger, config: config, } } // CreateAndEnqueueTask 创建并入队任务 func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error { // 1. 保存任务到数据库(GORM会自动生成UUID) if err := tm.asyncTaskRepo.Create(ctx, task); err != nil { tm.logger.Error("保存任务到数据库失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("保存任务失败: %w", err) } // 2. 更新payload中的task_id if err := tm.updatePayloadTaskID(task); err != nil { tm.logger.Error("更新payload中的任务ID失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("更新payload中的任务ID失败: %w", err) } // 3. 更新数据库中的payload if err := tm.asyncTaskRepo.Update(ctx, task); err != nil { tm.logger.Error("更新任务payload失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("更新任务payload失败: %w", err) } // 4. 入队到Asynq if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil { // 如果入队失败,更新任务状态为失败 tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败") return fmt.Errorf("任务入队失败: %w", err) } tm.logger.Info("任务创建并入队成功", zap.String("task_id", task.ID), zap.String("task_type", task.Type)) return nil } // CreateAndEnqueueDelayedTask 创建并入队延时任务 func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error { // 1. 设置调度时间 scheduledAt := time.Now().Add(delay) task.ScheduledAt = &scheduledAt // 2. 保存任务到数据库(GORM会自动生成UUID) if err := tm.asyncTaskRepo.Create(ctx, task); err != nil { tm.logger.Error("保存延时任务到数据库失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("保存延时任务失败: %w", err) } // 3. 更新payload中的task_id if err := tm.updatePayloadTaskID(task); err != nil { tm.logger.Error("更新payload中的任务ID失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("更新payload中的任务ID失败: %w", err) } // 4. 更新数据库中的payload if err := tm.asyncTaskRepo.Update(ctx, task); err != nil { tm.logger.Error("更新任务payload失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("更新任务payload失败: %w", err) } // 5. 入队到Asynq延时队列 if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil { // 如果入队失败,更新任务状态为失败 tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败") return fmt.Errorf("延时任务入队失败: %w", err) } tm.logger.Info("延时任务创建并入队成功", zap.String("task_id", task.ID), zap.String("task_type", task.Type), zap.Duration("delay", delay)) return nil } // CancelTask 取消任务 func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error { task, err := tm.findTask(ctx, taskID) if err != nil { return err } if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil { tm.logger.Error("更新任务状态为取消失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("更新任务状态失败: %w", err) } tm.logger.Info("任务已标记为取消", zap.String("task_id", task.ID), zap.String("task_type", task.Type)) return nil } // UpdateTaskSchedule 更新任务调度时间 func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error { // 1. 查找任务 task, err := tm.findTask(ctx, taskID) if err != nil { return err } tm.logger.Info("找到要更新的任务", zap.String("task_id", task.ID), zap.String("current_status", string(task.Status)), zap.Time("current_scheduled_at", *task.ScheduledAt)) // 2. 取消旧任务 if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil { tm.logger.Error("取消旧任务失败", zap.String("task_id", task.ID), zap.Error(err)) return fmt.Errorf("取消旧任务失败: %w", err) } tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID)) // 3. 创建并保存新任务 newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt) if err != nil { return err } tm.logger.Info("新任务已创建", zap.String("new_task_id", newTask.ID), zap.Time("new_scheduled_at", newScheduledAt)) // 4. 计算延时并入队 delay := newScheduledAt.Sub(time.Now()) if delay < 0 { delay = 0 // 如果时间已过,立即执行 } if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil { // 如果入队失败,删除新创建的任务记录 tm.asyncTaskRepo.Delete(ctx, newTask.ID) return fmt.Errorf("重新入队任务失败: %w", err) } tm.logger.Info("任务调度时间更新成功", zap.String("old_task_id", task.ID), zap.String("new_task_id", newTask.ID), zap.Time("new_scheduled_at", newScheduledAt)) return nil } // GetTaskStatus 获取任务状态 func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) { return tm.asyncTaskRepo.GetByID(ctx, taskID) } // UpdateTaskStatus 更新任务状态 func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error { if errorMsg != "" { return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg) } return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status) } // RetryTask 重试任务 func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error { // 1. 获取任务信息 task, err := tm.asyncTaskRepo.GetByID(ctx, taskID) if err != nil { return fmt.Errorf("获取任务信息失败: %w", err) } // 2. 检查是否可以重试 if !task.CanRetry() { return fmt.Errorf("任务已达到最大重试次数") } // 3. 增加重试次数并重置状态 task.RetryCount++ task.Status = entities.TaskStatusPending // 4. 更新数据库 if err := tm.asyncTaskRepo.Update(ctx, task); err != nil { return fmt.Errorf("更新任务重试次数失败: %w", err) } // 5. 重新入队 if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil { return fmt.Errorf("重试任务入队失败: %w", err) } tm.logger.Info("任务重试成功", zap.String("task_id", taskID), zap.Int("retry_count", task.RetryCount)) return nil } // CleanupExpiredTasks 清理过期任务 func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error { // 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务 tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan)) // TODO: 实现清理逻辑 return nil } // updatePayloadTaskID 更新payload中的task_id func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error { // 解析payload var payload map[string]interface{} if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil { return fmt.Errorf("解析payload失败: %w", err) } // 更新task_id payload["task_id"] = task.ID // 重新序列化 newPayload, err := json.Marshal(payload) if err != nil { return fmt.Errorf("序列化payload失败: %w", err) } task.Payload = string(newPayload) return nil } // findTask 查找任务(支持taskID和articleID双重查找) func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) { // 先尝试通过任务ID查找 task, err := tm.asyncTaskRepo.GetByID(ctx, taskID) if err == nil { return task, nil } // 如果通过任务ID找不到,尝试通过文章ID查找 tm.logger.Info("通过任务ID查找失败,尝试通过文章ID查找", zap.String("task_id", taskID)) tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID) if err != nil || len(tasks) == 0 { tm.logger.Error("通过文章ID也找不到任务", zap.String("article_id", taskID), zap.Error(err)) return nil, fmt.Errorf("获取任务信息失败: %w", err) } // 使用找到的第一个任务 task = tasks[0] tm.logger.Info("通过文章ID找到任务", zap.String("article_id", taskID), zap.String("task_id", task.ID)) return task, nil } // createAndSaveTask 创建并保存新任务 func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) { // 创建新任务 newTask := &entities.AsyncTask{ Type: originalTask.Type, Payload: originalTask.Payload, Status: entities.TaskStatusPending, ScheduledAt: &newScheduledAt, CreatedAt: time.Now(), UpdatedAt: time.Now(), } // 保存到数据库(GORM会自动生成UUID) if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil { tm.logger.Error("创建新任务失败", zap.String("new_task_id", newTask.ID), zap.Error(err)) return nil, fmt.Errorf("创建新任务失败: %w", err) } // 更新payload中的task_id if err := tm.updatePayloadTaskID(newTask); err != nil { tm.logger.Error("更新payload中的任务ID失败", zap.String("new_task_id", newTask.ID), zap.Error(err)) return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err) } // 更新数据库中的payload if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil { tm.logger.Error("更新新任务payload失败", zap.String("new_task_id", newTask.ID), zap.Error(err)) return nil, fmt.Errorf("更新新任务payload失败: %w", err) } return newTask, nil } // enqueueTaskWithDelay 入队任务到Asynq(支持延时) func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error { queueName := tm.getQueueName(task.Type) asynqTask := asynq.NewTask(task.Type, []byte(task.Payload)) var err error if delay > 0 { _, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName), asynq.ProcessIn(delay)) } else { _, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName)) } return err } // getQueueName 根据任务类型获取队列名称 func (tm *TaskManagerImpl) getQueueName(taskType string) string { switch taskType { case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify): return "article" case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats): return "api" case string(types.TaskTypeCompensation): return "finance" default: return "default" } }