375 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
		
		
			
		
	
	
			375 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
|  | package implementations | |||
|  | 
 | |||
|  | import ( | |||
|  | 	"context" | |||
|  | 	"encoding/json" | |||
|  | 	"fmt" | |||
|  | 	"time" | |||
|  | 
 | |||
|  | 	"github.com/hibiken/asynq" | |||
|  | 	"go.uber.org/zap" | |||
|  | 
 | |||
|  | 	"tyapi-server/internal/infrastructure/task/entities" | |||
|  | 	"tyapi-server/internal/infrastructure/task/interfaces" | |||
|  | 	"tyapi-server/internal/infrastructure/task/repositories" | |||
|  | 	"tyapi-server/internal/infrastructure/task/types" | |||
|  | ) | |||
|  | 
 | |||
|  | // TaskManagerImpl 任务管理器实现 | |||
|  | type TaskManagerImpl struct { | |||
|  | 	asynqClient    *asynq.Client | |||
|  | 	asyncTaskRepo  repositories.AsyncTaskRepository | |||
|  | 	logger         *zap.Logger | |||
|  | 	config         *interfaces.TaskManagerConfig | |||
|  | } | |||
|  | 
 | |||
|  | // NewTaskManager 创建任务管理器 | |||
|  | func NewTaskManager( | |||
|  | 	asynqClient *asynq.Client, | |||
|  | 	asyncTaskRepo repositories.AsyncTaskRepository, | |||
|  | 	logger *zap.Logger, | |||
|  | 	config *interfaces.TaskManagerConfig, | |||
|  | ) interfaces.TaskManager { | |||
|  | 	return &TaskManagerImpl{ | |||
|  | 		asynqClient:   asynqClient, | |||
|  | 		asyncTaskRepo: asyncTaskRepo, | |||
|  | 		logger:        logger, | |||
|  | 		config:        config, | |||
|  | 	} | |||
|  | } | |||
|  | 
 | |||
|  | // CreateAndEnqueueTask 创建并入队任务 | |||
|  | func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error { | |||
|  | 	// 1. 保存任务到数据库(GORM会自动生成UUID) | |||
|  | 	if err := tm.asyncTaskRepo.Create(ctx, task); err != nil { | |||
|  | 		tm.logger.Error("保存任务到数据库失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("保存任务失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 2. 更新payload中的task_id | |||
|  | 	if err := tm.updatePayloadTaskID(task); err != nil { | |||
|  | 		tm.logger.Error("更新payload中的任务ID失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("更新payload中的任务ID失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 3. 更新数据库中的payload | |||
|  | 	if err := tm.asyncTaskRepo.Update(ctx, task); err != nil { | |||
|  | 		tm.logger.Error("更新任务payload失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("更新任务payload失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 4. 入队到Asynq | |||
|  | 	if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil { | |||
|  | 		// 如果入队失败,更新任务状态为失败 | |||
|  | 		tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败") | |||
|  | 		return fmt.Errorf("任务入队失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("任务创建并入队成功",  | |||
|  | 		zap.String("task_id", task.ID), | |||
|  | 		zap.String("task_type", task.Type)) | |||
|  | 
 | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // CreateAndEnqueueDelayedTask 创建并入队延时任务 | |||
|  | func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error { | |||
|  | 	// 1. 设置调度时间 | |||
|  | 	scheduledAt := time.Now().Add(delay) | |||
|  | 	task.ScheduledAt = &scheduledAt | |||
|  | 
 | |||
|  | 	// 2. 保存任务到数据库(GORM会自动生成UUID) | |||
|  | 	if err := tm.asyncTaskRepo.Create(ctx, task); err != nil { | |||
|  | 		tm.logger.Error("保存延时任务到数据库失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("保存延时任务失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 3. 更新payload中的task_id | |||
|  | 	if err := tm.updatePayloadTaskID(task); err != nil { | |||
|  | 		tm.logger.Error("更新payload中的任务ID失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("更新payload中的任务ID失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 4. 更新数据库中的payload | |||
|  | 	if err := tm.asyncTaskRepo.Update(ctx, task); err != nil { | |||
|  | 		tm.logger.Error("更新任务payload失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("更新任务payload失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 5. 入队到Asynq延时队列 | |||
|  | 	if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil { | |||
|  | 		// 如果入队失败,更新任务状态为失败 | |||
|  | 		tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败") | |||
|  | 		return fmt.Errorf("延时任务入队失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("延时任务创建并入队成功",  | |||
|  | 		zap.String("task_id", task.ID), | |||
|  | 		zap.String("task_type", task.Type), | |||
|  | 		zap.Duration("delay", delay)) | |||
|  | 
 | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // CancelTask 取消任务 | |||
|  | func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error { | |||
|  | 	task, err := tm.findTask(ctx, taskID) | |||
|  | 	if err != nil { | |||
|  | 		return err | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil { | |||
|  | 		tm.logger.Error("更新任务状态为取消失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("更新任务状态失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("任务已标记为取消",  | |||
|  | 		zap.String("task_id", task.ID), | |||
|  | 		zap.String("task_type", task.Type)) | |||
|  | 
 | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // UpdateTaskSchedule 更新任务调度时间 | |||
|  | func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error { | |||
|  | 	// 1. 查找任务 | |||
|  | 	task, err := tm.findTask(ctx, taskID) | |||
|  | 	if err != nil { | |||
|  | 		return err | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("找到要更新的任务",  | |||
|  | 		zap.String("task_id", task.ID), | |||
|  | 		zap.String("current_status", string(task.Status)), | |||
|  | 		zap.Time("current_scheduled_at", *task.ScheduledAt)) | |||
|  | 
 | |||
|  | 	// 2. 取消旧任务 | |||
|  | 	if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil { | |||
|  | 		tm.logger.Error("取消旧任务失败",  | |||
|  | 			zap.String("task_id", task.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return fmt.Errorf("取消旧任务失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID)) | |||
|  | 
 | |||
|  | 	// 3. 创建并保存新任务 | |||
|  | 	newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt) | |||
|  | 	if err != nil { | |||
|  | 		return err | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("新任务已创建",  | |||
|  | 		zap.String("new_task_id", newTask.ID), | |||
|  | 		zap.Time("new_scheduled_at", newScheduledAt)) | |||
|  | 
 | |||
|  | 	// 4. 计算延时并入队 | |||
|  | 	delay := newScheduledAt.Sub(time.Now()) | |||
|  | 	if delay < 0 { | |||
|  | 		delay = 0 // 如果时间已过,立即执行 | |||
|  | 	} | |||
|  | 
 | |||
|  | 	if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil { | |||
|  | 		// 如果入队失败,删除新创建的任务记录 | |||
|  | 		tm.asyncTaskRepo.Delete(ctx, newTask.ID) | |||
|  | 		return fmt.Errorf("重新入队任务失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("任务调度时间更新成功",  | |||
|  | 		zap.String("old_task_id", task.ID), | |||
|  | 		zap.String("new_task_id", newTask.ID), | |||
|  | 		zap.Time("new_scheduled_at", newScheduledAt)) | |||
|  | 
 | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // GetTaskStatus 获取任务状态 | |||
|  | func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) { | |||
|  | 	return tm.asyncTaskRepo.GetByID(ctx, taskID) | |||
|  | } | |||
|  | 
 | |||
|  | // UpdateTaskStatus 更新任务状态 | |||
|  | func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error { | |||
|  | 	if errorMsg != "" { | |||
|  | 		return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg) | |||
|  | 	} | |||
|  | 	return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status) | |||
|  | } | |||
|  | 
 | |||
|  | // RetryTask 重试任务 | |||
|  | func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error { | |||
|  | 	// 1. 获取任务信息 | |||
|  | 	task, err := tm.asyncTaskRepo.GetByID(ctx, taskID) | |||
|  | 	if err != nil { | |||
|  | 		return fmt.Errorf("获取任务信息失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 2. 检查是否可以重试 | |||
|  | 	if !task.CanRetry() { | |||
|  | 		return fmt.Errorf("任务已达到最大重试次数") | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 3. 增加重试次数并重置状态 | |||
|  | 	task.RetryCount++ | |||
|  | 	task.Status = entities.TaskStatusPending | |||
|  | 
 | |||
|  | 	// 4. 更新数据库 | |||
|  | 	if err := tm.asyncTaskRepo.Update(ctx, task); err != nil { | |||
|  | 		return fmt.Errorf("更新任务重试次数失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 5. 重新入队 | |||
|  | 	if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil { | |||
|  | 		return fmt.Errorf("重试任务入队失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	tm.logger.Info("任务重试成功",  | |||
|  | 		zap.String("task_id", taskID), | |||
|  | 		zap.Int("retry_count", task.RetryCount)) | |||
|  | 
 | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // CleanupExpiredTasks 清理过期任务 | |||
|  | func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error { | |||
|  | 	// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务 | |||
|  | 	tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan)) | |||
|  | 	 | |||
|  | 	// TODO: 实现清理逻辑 | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | // updatePayloadTaskID 更新payload中的task_id | |||
|  | func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error { | |||
|  | 	// 解析payload | |||
|  | 	var payload map[string]interface{} | |||
|  | 	if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil { | |||
|  | 		return fmt.Errorf("解析payload失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 更新task_id | |||
|  | 	payload["task_id"] = task.ID | |||
|  | 
 | |||
|  | 	// 重新序列化 | |||
|  | 	newPayload, err := json.Marshal(payload) | |||
|  | 	if err != nil { | |||
|  | 		return fmt.Errorf("序列化payload失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	task.Payload = string(newPayload) | |||
|  | 	return nil | |||
|  | } | |||
|  | 
 | |||
|  | 
 | |||
|  | // findTask 查找任务(支持taskID和articleID双重查找) | |||
|  | func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) { | |||
|  | 	// 先尝试通过任务ID查找 | |||
|  | 	task, err := tm.asyncTaskRepo.GetByID(ctx, taskID) | |||
|  | 	if err == nil { | |||
|  | 		return task, nil | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 如果通过任务ID找不到,尝试通过文章ID查找 | |||
|  | 	tm.logger.Info("通过任务ID查找失败,尝试通过文章ID查找", zap.String("task_id", taskID)) | |||
|  | 	 | |||
|  | 	tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID) | |||
|  | 	if err != nil || len(tasks) == 0 { | |||
|  | 		tm.logger.Error("通过文章ID也找不到任务",  | |||
|  | 			zap.String("article_id", taskID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return nil, fmt.Errorf("获取任务信息失败: %w", err) | |||
|  | 	} | |||
|  | 	 | |||
|  | 	// 使用找到的第一个任务 | |||
|  | 	task = tasks[0] | |||
|  | 	tm.logger.Info("通过文章ID找到任务",  | |||
|  | 		zap.String("article_id", taskID), | |||
|  | 		zap.String("task_id", task.ID)) | |||
|  | 	 | |||
|  | 	return task, nil | |||
|  | } | |||
|  | 
 | |||
|  | // createAndSaveTask 创建并保存新任务 | |||
|  | func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) { | |||
|  | 	// 创建新任务 | |||
|  | 	newTask := &entities.AsyncTask{ | |||
|  | 		Type:        originalTask.Type, | |||
|  | 		Payload:     originalTask.Payload, | |||
|  | 		Status:      entities.TaskStatusPending, | |||
|  | 		ScheduledAt: &newScheduledAt, | |||
|  | 		CreatedAt:   time.Now(), | |||
|  | 		UpdatedAt:   time.Now(), | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 保存到数据库(GORM会自动生成UUID) | |||
|  | 	if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil { | |||
|  | 		tm.logger.Error("创建新任务失败",  | |||
|  | 			zap.String("new_task_id", newTask.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return nil, fmt.Errorf("创建新任务失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 更新payload中的task_id | |||
|  | 	if err := tm.updatePayloadTaskID(newTask); err != nil { | |||
|  | 		tm.logger.Error("更新payload中的任务ID失败",  | |||
|  | 			zap.String("new_task_id", newTask.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	// 更新数据库中的payload | |||
|  | 	if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil { | |||
|  | 		tm.logger.Error("更新新任务payload失败",  | |||
|  | 			zap.String("new_task_id", newTask.ID), | |||
|  | 			zap.Error(err)) | |||
|  | 		return nil, fmt.Errorf("更新新任务payload失败: %w", err) | |||
|  | 	} | |||
|  | 
 | |||
|  | 	return newTask, nil | |||
|  | } | |||
|  | 
 | |||
|  | // enqueueTaskWithDelay 入队任务到Asynq(支持延时) | |||
|  | func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error { | |||
|  | 	queueName := tm.getQueueName(task.Type) | |||
|  | 	asynqTask := asynq.NewTask(task.Type, []byte(task.Payload)) | |||
|  | 	 | |||
|  | 	var err error | |||
|  | 	if delay > 0 { | |||
|  | 		_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,  | |||
|  | 			asynq.Queue(queueName), | |||
|  | 			asynq.ProcessIn(delay)) | |||
|  | 	} else { | |||
|  | 		_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName)) | |||
|  | 	} | |||
|  | 	 | |||
|  | 	return err | |||
|  | } | |||
|  | 
 | |||
|  | // getQueueName 根据任务类型获取队列名称 | |||
|  | func (tm *TaskManagerImpl) getQueueName(taskType string) string { | |||
|  | 	switch taskType { | |||
|  | 	case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify): | |||
|  | 		return "article" | |||
|  | 	case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats): | |||
|  | 		return "api" | |||
|  | 	case string(types.TaskTypeCompensation): | |||
|  | 		return "finance" | |||
|  | 	default: | |||
|  | 		return "default" | |||
|  | 	} | |||
|  | } |