Files
tyapi-server/internal/infrastructure/task/implementations/task_manager.go
2025-09-12 01:15:09 +08:00

375 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package implementations
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/interfaces"
"tyapi-server/internal/infrastructure/task/repositories"
"tyapi-server/internal/infrastructure/task/types"
)
// TaskManagerImpl 任务管理器实现
type TaskManagerImpl struct {
asynqClient *asynq.Client
asyncTaskRepo repositories.AsyncTaskRepository
logger *zap.Logger
config *interfaces.TaskManagerConfig
}
// NewTaskManager 创建任务管理器
func NewTaskManager(
asynqClient *asynq.Client,
asyncTaskRepo repositories.AsyncTaskRepository,
logger *zap.Logger,
config *interfaces.TaskManagerConfig,
) interfaces.TaskManager {
return &TaskManagerImpl{
asynqClient: asynqClient,
asyncTaskRepo: asyncTaskRepo,
logger: logger,
config: config,
}
}
// CreateAndEnqueueTask 创建并入队任务
func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
// 1. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存任务失败: %w", err)
}
// 2. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 3. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
}
// 4. 入队到Asynq
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
// 如果入队失败,更新任务状态为失败
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败")
return fmt.Errorf("任务入队失败: %w", err)
}
tm.logger.Info("任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
return nil
}
// CreateAndEnqueueDelayedTask 创建并入队延时任务
func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
// 1. 设置调度时间
scheduledAt := time.Now().Add(delay)
task.ScheduledAt = &scheduledAt
// 2. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存延时任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存延时任务失败: %w", err)
}
// 3. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 4. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
}
// 5. 入队到Asynq延时队列
if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil {
// 如果入队失败,更新任务状态为失败
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败")
return fmt.Errorf("延时任务入队失败: %w", err)
}
tm.logger.Info("延时任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type),
zap.Duration("delay", delay))
return nil
}
// CancelTask 取消任务
func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error {
task, err := tm.findTask(ctx, taskID)
if err != nil {
return err
}
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("更新任务状态为取消失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务状态失败: %w", err)
}
tm.logger.Info("任务已标记为取消",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
return nil
}
// UpdateTaskSchedule 更新任务调度时间
func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
// 1. 查找任务
task, err := tm.findTask(ctx, taskID)
if err != nil {
return err
}
tm.logger.Info("找到要更新的任务",
zap.String("task_id", task.ID),
zap.String("current_status", string(task.Status)),
zap.Time("current_scheduled_at", *task.ScheduledAt))
// 2. 取消旧任务
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("取消旧任务失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("取消旧任务失败: %w", err)
}
tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID))
// 3. 创建并保存新任务
newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt)
if err != nil {
return err
}
tm.logger.Info("新任务已创建",
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
// 4. 计算延时并入队
delay := newScheduledAt.Sub(time.Now())
if delay < 0 {
delay = 0 // 如果时间已过,立即执行
}
if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil {
// 如果入队失败,删除新创建的任务记录
tm.asyncTaskRepo.Delete(ctx, newTask.ID)
return fmt.Errorf("重新入队任务失败: %w", err)
}
tm.logger.Info("任务调度时间更新成功",
zap.String("old_task_id", task.ID),
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
return nil
}
// GetTaskStatus 获取任务状态
func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
return tm.asyncTaskRepo.GetByID(ctx, taskID)
}
// UpdateTaskStatus 更新任务状态
func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error {
if errorMsg != "" {
return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg)
}
return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status)
}
// RetryTask 重试任务
func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
// 1. 获取任务信息
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
if err != nil {
return fmt.Errorf("获取任务信息失败: %w", err)
}
// 2. 检查是否可以重试
if !task.CanRetry() {
return fmt.Errorf("任务已达到最大重试次数")
}
// 3. 增加重试次数并重置状态
task.RetryCount++
task.Status = entities.TaskStatusPending
// 4. 更新数据库
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
return fmt.Errorf("更新任务重试次数失败: %w", err)
}
// 5. 重新入队
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
return fmt.Errorf("重试任务入队失败: %w", err)
}
tm.logger.Info("任务重试成功",
zap.String("task_id", taskID),
zap.Int("retry_count", task.RetryCount))
return nil
}
// CleanupExpiredTasks 清理过期任务
func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
// TODO: 实现清理逻辑
return nil
}
// updatePayloadTaskID 更新payload中的task_id
func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
// 解析payload
var payload map[string]interface{}
if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil {
return fmt.Errorf("解析payload失败: %w", err)
}
// 更新task_id
payload["task_id"] = task.ID
// 重新序列化
newPayload, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("序列化payload失败: %w", err)
}
task.Payload = string(newPayload)
return nil
}
// findTask 查找任务支持taskID和articleID双重查找
func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
// 先尝试通过任务ID查找
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
if err == nil {
return task, nil
}
// 如果通过任务ID找不到尝试通过文章ID查找
tm.logger.Info("通过任务ID查找失败尝试通过文章ID查找", zap.String("task_id", taskID))
tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
if err != nil || len(tasks) == 0 {
tm.logger.Error("通过文章ID也找不到任务",
zap.String("article_id", taskID),
zap.Error(err))
return nil, fmt.Errorf("获取任务信息失败: %w", err)
}
// 使用找到的第一个任务
task = tasks[0]
tm.logger.Info("通过文章ID找到任务",
zap.String("article_id", taskID),
zap.String("task_id", task.ID))
return task, nil
}
// createAndSaveTask 创建并保存新任务
func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) {
// 创建新任务
newTask := &entities.AsyncTask{
Type: originalTask.Type,
Payload: originalTask.Payload,
Status: entities.TaskStatusPending,
ScheduledAt: &newScheduledAt,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 保存到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
tm.logger.Error("创建新任务失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("创建新任务失败: %w", err)
}
// 更新payload中的task_id
if err := tm.updatePayloadTaskID(newTask); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
tm.logger.Error("更新新任务payload失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新新任务payload失败: %w", err)
}
return newTask, nil
}
// enqueueTaskWithDelay 入队任务到Asynq支持延时
func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
queueName := tm.getQueueName(task.Type)
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
var err error
if delay > 0 {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
asynq.Queue(queueName),
asynq.ProcessIn(delay))
} else {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
}
return err
}
// getQueueName 根据任务类型获取队列名称
func (tm *TaskManagerImpl) getQueueName(taskType string) string {
switch taskType {
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
return "article"
case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
return "api"
case string(types.TaskTypeCompensation):
return "finance"
default:
return "default"
}
}