375 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			375 lines
		
	
	
		
			12 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package implementations
 | ||
| 
 | ||
| import (
 | ||
| 	"context"
 | ||
| 	"encoding/json"
 | ||
| 	"fmt"
 | ||
| 	"time"
 | ||
| 
 | ||
| 	"github.com/hibiken/asynq"
 | ||
| 	"go.uber.org/zap"
 | ||
| 
 | ||
| 	"tyapi-server/internal/infrastructure/task/entities"
 | ||
| 	"tyapi-server/internal/infrastructure/task/interfaces"
 | ||
| 	"tyapi-server/internal/infrastructure/task/repositories"
 | ||
| 	"tyapi-server/internal/infrastructure/task/types"
 | ||
| )
 | ||
| 
 | ||
| // TaskManagerImpl 任务管理器实现
 | ||
| type TaskManagerImpl struct {
 | ||
| 	asynqClient    *asynq.Client
 | ||
| 	asyncTaskRepo  repositories.AsyncTaskRepository
 | ||
| 	logger         *zap.Logger
 | ||
| 	config         *interfaces.TaskManagerConfig
 | ||
| }
 | ||
| 
 | ||
| // NewTaskManager 创建任务管理器
 | ||
| func NewTaskManager(
 | ||
| 	asynqClient *asynq.Client,
 | ||
| 	asyncTaskRepo repositories.AsyncTaskRepository,
 | ||
| 	logger *zap.Logger,
 | ||
| 	config *interfaces.TaskManagerConfig,
 | ||
| ) interfaces.TaskManager {
 | ||
| 	return &TaskManagerImpl{
 | ||
| 		asynqClient:   asynqClient,
 | ||
| 		asyncTaskRepo: asyncTaskRepo,
 | ||
| 		logger:        logger,
 | ||
| 		config:        config,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // CreateAndEnqueueTask 创建并入队任务
 | ||
| func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
 | ||
| 	// 1. 保存任务到数据库(GORM会自动生成UUID)
 | ||
| 	if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
 | ||
| 		tm.logger.Error("保存任务到数据库失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("保存任务失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 2. 更新payload中的task_id
 | ||
| 	if err := tm.updatePayloadTaskID(task); err != nil {
 | ||
| 		tm.logger.Error("更新payload中的任务ID失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("更新payload中的任务ID失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 更新数据库中的payload
 | ||
| 	if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
 | ||
| 		tm.logger.Error("更新任务payload失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("更新任务payload失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 4. 入队到Asynq
 | ||
| 	if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
 | ||
| 		// 如果入队失败,更新任务状态为失败
 | ||
| 		tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败")
 | ||
| 		return fmt.Errorf("任务入队失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("任务创建并入队成功", 
 | ||
| 		zap.String("task_id", task.ID),
 | ||
| 		zap.String("task_type", task.Type))
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // CreateAndEnqueueDelayedTask 创建并入队延时任务
 | ||
| func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
 | ||
| 	// 1. 设置调度时间
 | ||
| 	scheduledAt := time.Now().Add(delay)
 | ||
| 	task.ScheduledAt = &scheduledAt
 | ||
| 
 | ||
| 	// 2. 保存任务到数据库(GORM会自动生成UUID)
 | ||
| 	if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
 | ||
| 		tm.logger.Error("保存延时任务到数据库失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("保存延时任务失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 更新payload中的task_id
 | ||
| 	if err := tm.updatePayloadTaskID(task); err != nil {
 | ||
| 		tm.logger.Error("更新payload中的任务ID失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("更新payload中的任务ID失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 4. 更新数据库中的payload
 | ||
| 	if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
 | ||
| 		tm.logger.Error("更新任务payload失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("更新任务payload失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 5. 入队到Asynq延时队列
 | ||
| 	if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil {
 | ||
| 		// 如果入队失败,更新任务状态为失败
 | ||
| 		tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败")
 | ||
| 		return fmt.Errorf("延时任务入队失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("延时任务创建并入队成功", 
 | ||
| 		zap.String("task_id", task.ID),
 | ||
| 		zap.String("task_type", task.Type),
 | ||
| 		zap.Duration("delay", delay))
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // CancelTask 取消任务
 | ||
| func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error {
 | ||
| 	task, err := tm.findTask(ctx, taskID)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
 | ||
| 		tm.logger.Error("更新任务状态为取消失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("更新任务状态失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("任务已标记为取消", 
 | ||
| 		zap.String("task_id", task.ID),
 | ||
| 		zap.String("task_type", task.Type))
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // UpdateTaskSchedule 更新任务调度时间
 | ||
| func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
 | ||
| 	// 1. 查找任务
 | ||
| 	task, err := tm.findTask(ctx, taskID)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("找到要更新的任务", 
 | ||
| 		zap.String("task_id", task.ID),
 | ||
| 		zap.String("current_status", string(task.Status)),
 | ||
| 		zap.Time("current_scheduled_at", *task.ScheduledAt))
 | ||
| 
 | ||
| 	// 2. 取消旧任务
 | ||
| 	if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
 | ||
| 		tm.logger.Error("取消旧任务失败", 
 | ||
| 			zap.String("task_id", task.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return fmt.Errorf("取消旧任务失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID))
 | ||
| 
 | ||
| 	// 3. 创建并保存新任务
 | ||
| 	newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("新任务已创建", 
 | ||
| 		zap.String("new_task_id", newTask.ID),
 | ||
| 		zap.Time("new_scheduled_at", newScheduledAt))
 | ||
| 
 | ||
| 	// 4. 计算延时并入队
 | ||
| 	delay := newScheduledAt.Sub(time.Now())
 | ||
| 	if delay < 0 {
 | ||
| 		delay = 0 // 如果时间已过,立即执行
 | ||
| 	}
 | ||
| 
 | ||
| 	if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil {
 | ||
| 		// 如果入队失败,删除新创建的任务记录
 | ||
| 		tm.asyncTaskRepo.Delete(ctx, newTask.ID)
 | ||
| 		return fmt.Errorf("重新入队任务失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("任务调度时间更新成功", 
 | ||
| 		zap.String("old_task_id", task.ID),
 | ||
| 		zap.String("new_task_id", newTask.ID),
 | ||
| 		zap.Time("new_scheduled_at", newScheduledAt))
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // GetTaskStatus 获取任务状态
 | ||
| func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
 | ||
| 	return tm.asyncTaskRepo.GetByID(ctx, taskID)
 | ||
| }
 | ||
| 
 | ||
| // UpdateTaskStatus 更新任务状态
 | ||
| func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error {
 | ||
| 	if errorMsg != "" {
 | ||
| 		return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg)
 | ||
| 	}
 | ||
| 	return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status)
 | ||
| }
 | ||
| 
 | ||
| // RetryTask 重试任务
 | ||
| func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
 | ||
| 	// 1. 获取任务信息
 | ||
| 	task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("获取任务信息失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 2. 检查是否可以重试
 | ||
| 	if !task.CanRetry() {
 | ||
| 		return fmt.Errorf("任务已达到最大重试次数")
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 增加重试次数并重置状态
 | ||
| 	task.RetryCount++
 | ||
| 	task.Status = entities.TaskStatusPending
 | ||
| 
 | ||
| 	// 4. 更新数据库
 | ||
| 	if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
 | ||
| 		return fmt.Errorf("更新任务重试次数失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 5. 重新入队
 | ||
| 	if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
 | ||
| 		return fmt.Errorf("重试任务入队失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	tm.logger.Info("任务重试成功", 
 | ||
| 		zap.String("task_id", taskID),
 | ||
| 		zap.Int("retry_count", task.RetryCount))
 | ||
| 
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // CleanupExpiredTasks 清理过期任务
 | ||
| func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
 | ||
| 	// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
 | ||
| 	tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
 | ||
| 	
 | ||
| 	// TODO: 实现清理逻辑
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // updatePayloadTaskID 更新payload中的task_id
 | ||
| func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
 | ||
| 	// 解析payload
 | ||
| 	var payload map[string]interface{}
 | ||
| 	if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil {
 | ||
| 		return fmt.Errorf("解析payload失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 更新task_id
 | ||
| 	payload["task_id"] = task.ID
 | ||
| 
 | ||
| 	// 重新序列化
 | ||
| 	newPayload, err := json.Marshal(payload)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("序列化payload失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	task.Payload = string(newPayload)
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| 
 | ||
| // findTask 查找任务(支持taskID和articleID双重查找)
 | ||
| func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
 | ||
| 	// 先尝试通过任务ID查找
 | ||
| 	task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
 | ||
| 	if err == nil {
 | ||
| 		return task, nil
 | ||
| 	}
 | ||
| 
 | ||
| 	// 如果通过任务ID找不到,尝试通过文章ID查找
 | ||
| 	tm.logger.Info("通过任务ID查找失败,尝试通过文章ID查找", zap.String("task_id", taskID))
 | ||
| 	
 | ||
| 	tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
 | ||
| 	if err != nil || len(tasks) == 0 {
 | ||
| 		tm.logger.Error("通过文章ID也找不到任务", 
 | ||
| 			zap.String("article_id", taskID),
 | ||
| 			zap.Error(err))
 | ||
| 		return nil, fmt.Errorf("获取任务信息失败: %w", err)
 | ||
| 	}
 | ||
| 	
 | ||
| 	// 使用找到的第一个任务
 | ||
| 	task = tasks[0]
 | ||
| 	tm.logger.Info("通过文章ID找到任务", 
 | ||
| 		zap.String("article_id", taskID),
 | ||
| 		zap.String("task_id", task.ID))
 | ||
| 	
 | ||
| 	return task, nil
 | ||
| }
 | ||
| 
 | ||
| // createAndSaveTask 创建并保存新任务
 | ||
| func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) {
 | ||
| 	// 创建新任务
 | ||
| 	newTask := &entities.AsyncTask{
 | ||
| 		Type:        originalTask.Type,
 | ||
| 		Payload:     originalTask.Payload,
 | ||
| 		Status:      entities.TaskStatusPending,
 | ||
| 		ScheduledAt: &newScheduledAt,
 | ||
| 		CreatedAt:   time.Now(),
 | ||
| 		UpdatedAt:   time.Now(),
 | ||
| 	}
 | ||
| 
 | ||
| 	// 保存到数据库(GORM会自动生成UUID)
 | ||
| 	if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
 | ||
| 		tm.logger.Error("创建新任务失败", 
 | ||
| 			zap.String("new_task_id", newTask.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return nil, fmt.Errorf("创建新任务失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 更新payload中的task_id
 | ||
| 	if err := tm.updatePayloadTaskID(newTask); err != nil {
 | ||
| 		tm.logger.Error("更新payload中的任务ID失败", 
 | ||
| 			zap.String("new_task_id", newTask.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 更新数据库中的payload
 | ||
| 	if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
 | ||
| 		tm.logger.Error("更新新任务payload失败", 
 | ||
| 			zap.String("new_task_id", newTask.ID),
 | ||
| 			zap.Error(err))
 | ||
| 		return nil, fmt.Errorf("更新新任务payload失败: %w", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	return newTask, nil
 | ||
| }
 | ||
| 
 | ||
| // enqueueTaskWithDelay 入队任务到Asynq(支持延时)
 | ||
| func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
 | ||
| 	queueName := tm.getQueueName(task.Type)
 | ||
| 	asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
 | ||
| 	
 | ||
| 	var err error
 | ||
| 	if delay > 0 {
 | ||
| 		_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, 
 | ||
| 			asynq.Queue(queueName),
 | ||
| 			asynq.ProcessIn(delay))
 | ||
| 	} else {
 | ||
| 		_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
 | ||
| 	}
 | ||
| 	
 | ||
| 	return err
 | ||
| }
 | ||
| 
 | ||
| // getQueueName 根据任务类型获取队列名称
 | ||
| func (tm *TaskManagerImpl) getQueueName(taskType string) string {
 | ||
| 	switch taskType {
 | ||
| 	case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
 | ||
| 		return "article"
 | ||
| 	case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
 | ||
| 		return "api"
 | ||
| 	case string(types.TaskTypeCompensation):
 | ||
| 		return "finance"
 | ||
| 	default:
 | ||
| 		return "default"
 | ||
| 	}
 | ||
| }
 |