Files
tyapi-server/internal/infrastructure/task/implementations/task_manager.go

375 lines
12 KiB
Go
Raw Normal View History

2025-09-12 01:15:09 +08:00
package implementations
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/interfaces"
"tyapi-server/internal/infrastructure/task/repositories"
"tyapi-server/internal/infrastructure/task/types"
)
// TaskManagerImpl 任务管理器实现
type TaskManagerImpl struct {
asynqClient *asynq.Client
asyncTaskRepo repositories.AsyncTaskRepository
logger *zap.Logger
config *interfaces.TaskManagerConfig
}
// NewTaskManager 创建任务管理器
func NewTaskManager(
asynqClient *asynq.Client,
asyncTaskRepo repositories.AsyncTaskRepository,
logger *zap.Logger,
config *interfaces.TaskManagerConfig,
) interfaces.TaskManager {
return &TaskManagerImpl{
asynqClient: asynqClient,
asyncTaskRepo: asyncTaskRepo,
logger: logger,
config: config,
}
}
// CreateAndEnqueueTask 创建并入队任务
func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
// 1. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存任务失败: %w", err)
}
// 2. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 3. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
}
// 4. 入队到Asynq
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
// 如果入队失败,更新任务状态为失败
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败")
return fmt.Errorf("任务入队失败: %w", err)
}
tm.logger.Info("任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
return nil
}
// CreateAndEnqueueDelayedTask 创建并入队延时任务
func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
// 1. 设置调度时间
scheduledAt := time.Now().Add(delay)
task.ScheduledAt = &scheduledAt
// 2. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存延时任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存延时任务失败: %w", err)
}
// 3. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 4. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
}
// 5. 入队到Asynq延时队列
if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil {
// 如果入队失败,更新任务状态为失败
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败")
return fmt.Errorf("延时任务入队失败: %w", err)
}
tm.logger.Info("延时任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type),
zap.Duration("delay", delay))
return nil
}
// CancelTask 取消任务
func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error {
task, err := tm.findTask(ctx, taskID)
if err != nil {
return err
}
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("更新任务状态为取消失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务状态失败: %w", err)
}
tm.logger.Info("任务已标记为取消",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
return nil
}
// UpdateTaskSchedule 更新任务调度时间
func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
// 1. 查找任务
task, err := tm.findTask(ctx, taskID)
if err != nil {
return err
}
tm.logger.Info("找到要更新的任务",
zap.String("task_id", task.ID),
zap.String("current_status", string(task.Status)),
zap.Time("current_scheduled_at", *task.ScheduledAt))
// 2. 取消旧任务
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("取消旧任务失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("取消旧任务失败: %w", err)
}
tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID))
// 3. 创建并保存新任务
newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt)
if err != nil {
return err
}
tm.logger.Info("新任务已创建",
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
// 4. 计算延时并入队
delay := newScheduledAt.Sub(time.Now())
if delay < 0 {
delay = 0 // 如果时间已过,立即执行
}
if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil {
// 如果入队失败,删除新创建的任务记录
tm.asyncTaskRepo.Delete(ctx, newTask.ID)
return fmt.Errorf("重新入队任务失败: %w", err)
}
tm.logger.Info("任务调度时间更新成功",
zap.String("old_task_id", task.ID),
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
return nil
}
// GetTaskStatus 获取任务状态
func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
return tm.asyncTaskRepo.GetByID(ctx, taskID)
}
// UpdateTaskStatus 更新任务状态
func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error {
if errorMsg != "" {
return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg)
}
return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status)
}
// RetryTask 重试任务
func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
// 1. 获取任务信息
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
if err != nil {
return fmt.Errorf("获取任务信息失败: %w", err)
}
// 2. 检查是否可以重试
if !task.CanRetry() {
return fmt.Errorf("任务已达到最大重试次数")
}
// 3. 增加重试次数并重置状态
task.RetryCount++
task.Status = entities.TaskStatusPending
// 4. 更新数据库
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
return fmt.Errorf("更新任务重试次数失败: %w", err)
}
// 5. 重新入队
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
return fmt.Errorf("重试任务入队失败: %w", err)
}
tm.logger.Info("任务重试成功",
zap.String("task_id", taskID),
zap.Int("retry_count", task.RetryCount))
return nil
}
// CleanupExpiredTasks 清理过期任务
func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
// TODO: 实现清理逻辑
return nil
}
// updatePayloadTaskID 更新payload中的task_id
func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
// 解析payload
var payload map[string]interface{}
if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil {
return fmt.Errorf("解析payload失败: %w", err)
}
// 更新task_id
payload["task_id"] = task.ID
// 重新序列化
newPayload, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("序列化payload失败: %w", err)
}
task.Payload = string(newPayload)
return nil
}
// findTask 查找任务支持taskID和articleID双重查找
func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
// 先尝试通过任务ID查找
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
if err == nil {
return task, nil
}
// 如果通过任务ID找不到尝试通过文章ID查找
tm.logger.Info("通过任务ID查找失败尝试通过文章ID查找", zap.String("task_id", taskID))
tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
if err != nil || len(tasks) == 0 {
tm.logger.Error("通过文章ID也找不到任务",
zap.String("article_id", taskID),
zap.Error(err))
return nil, fmt.Errorf("获取任务信息失败: %w", err)
}
// 使用找到的第一个任务
task = tasks[0]
tm.logger.Info("通过文章ID找到任务",
zap.String("article_id", taskID),
zap.String("task_id", task.ID))
return task, nil
}
// createAndSaveTask 创建并保存新任务
func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) {
// 创建新任务
newTask := &entities.AsyncTask{
Type: originalTask.Type,
Payload: originalTask.Payload,
Status: entities.TaskStatusPending,
ScheduledAt: &newScheduledAt,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 保存到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
tm.logger.Error("创建新任务失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("创建新任务失败: %w", err)
}
// 更新payload中的task_id
if err := tm.updatePayloadTaskID(newTask); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
tm.logger.Error("更新新任务payload失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新新任务payload失败: %w", err)
}
return newTask, nil
}
// enqueueTaskWithDelay 入队任务到Asynq支持延时
func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
queueName := tm.getQueueName(task.Type)
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
var err error
if delay > 0 {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
asynq.Queue(queueName),
asynq.ProcessIn(delay))
} else {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
}
return err
}
// getQueueName 根据任务类型获取队列名称
func (tm *TaskManagerImpl) getQueueName(taskType string) string {
switch taskType {
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
return "article"
case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
return "api"
case string(types.TaskTypeCompensation):
return "finance"
default:
return "default"
}
}