375 lines
12 KiB
Go
375 lines
12 KiB
Go
package implementations
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"time"
|
||
|
||
"github.com/hibiken/asynq"
|
||
"go.uber.org/zap"
|
||
|
||
"tyapi-server/internal/infrastructure/task/entities"
|
||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||
"tyapi-server/internal/infrastructure/task/repositories"
|
||
"tyapi-server/internal/infrastructure/task/types"
|
||
)
|
||
|
||
// TaskManagerImpl 任务管理器实现
|
||
type TaskManagerImpl struct {
|
||
asynqClient *asynq.Client
|
||
asyncTaskRepo repositories.AsyncTaskRepository
|
||
logger *zap.Logger
|
||
config *interfaces.TaskManagerConfig
|
||
}
|
||
|
||
// NewTaskManager 创建任务管理器
|
||
func NewTaskManager(
|
||
asynqClient *asynq.Client,
|
||
asyncTaskRepo repositories.AsyncTaskRepository,
|
||
logger *zap.Logger,
|
||
config *interfaces.TaskManagerConfig,
|
||
) interfaces.TaskManager {
|
||
return &TaskManagerImpl{
|
||
asynqClient: asynqClient,
|
||
asyncTaskRepo: asyncTaskRepo,
|
||
logger: logger,
|
||
config: config,
|
||
}
|
||
}
|
||
|
||
// CreateAndEnqueueTask 创建并入队任务
|
||
func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
|
||
// 1. 保存任务到数据库(GORM会自动生成UUID)
|
||
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
|
||
tm.logger.Error("保存任务到数据库失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("保存任务失败: %w", err)
|
||
}
|
||
|
||
// 2. 更新payload中的task_id
|
||
if err := tm.updatePayloadTaskID(task); err != nil {
|
||
tm.logger.Error("更新payload中的任务ID失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||
}
|
||
|
||
// 3. 更新数据库中的payload
|
||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||
tm.logger.Error("更新任务payload失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("更新任务payload失败: %w", err)
|
||
}
|
||
|
||
// 4. 入队到Asynq
|
||
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
|
||
// 如果入队失败,更新任务状态为失败
|
||
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败")
|
||
return fmt.Errorf("任务入队失败: %w", err)
|
||
}
|
||
|
||
tm.logger.Info("任务创建并入队成功",
|
||
zap.String("task_id", task.ID),
|
||
zap.String("task_type", task.Type))
|
||
|
||
return nil
|
||
}
|
||
|
||
// CreateAndEnqueueDelayedTask 创建并入队延时任务
|
||
func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
|
||
// 1. 设置调度时间
|
||
scheduledAt := time.Now().Add(delay)
|
||
task.ScheduledAt = &scheduledAt
|
||
|
||
// 2. 保存任务到数据库(GORM会自动生成UUID)
|
||
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
|
||
tm.logger.Error("保存延时任务到数据库失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("保存延时任务失败: %w", err)
|
||
}
|
||
|
||
// 3. 更新payload中的task_id
|
||
if err := tm.updatePayloadTaskID(task); err != nil {
|
||
tm.logger.Error("更新payload中的任务ID失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||
}
|
||
|
||
// 4. 更新数据库中的payload
|
||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||
tm.logger.Error("更新任务payload失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("更新任务payload失败: %w", err)
|
||
}
|
||
|
||
// 5. 入队到Asynq延时队列
|
||
if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil {
|
||
// 如果入队失败,更新任务状态为失败
|
||
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败")
|
||
return fmt.Errorf("延时任务入队失败: %w", err)
|
||
}
|
||
|
||
tm.logger.Info("延时任务创建并入队成功",
|
||
zap.String("task_id", task.ID),
|
||
zap.String("task_type", task.Type),
|
||
zap.Duration("delay", delay))
|
||
|
||
return nil
|
||
}
|
||
|
||
// CancelTask 取消任务
|
||
func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error {
|
||
task, err := tm.findTask(ctx, taskID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
|
||
tm.logger.Error("更新任务状态为取消失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("更新任务状态失败: %w", err)
|
||
}
|
||
|
||
tm.logger.Info("任务已标记为取消",
|
||
zap.String("task_id", task.ID),
|
||
zap.String("task_type", task.Type))
|
||
|
||
return nil
|
||
}
|
||
|
||
// UpdateTaskSchedule 更新任务调度时间
|
||
func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
|
||
// 1. 查找任务
|
||
task, err := tm.findTask(ctx, taskID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
tm.logger.Info("找到要更新的任务",
|
||
zap.String("task_id", task.ID),
|
||
zap.String("current_status", string(task.Status)),
|
||
zap.Time("current_scheduled_at", *task.ScheduledAt))
|
||
|
||
// 2. 取消旧任务
|
||
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
|
||
tm.logger.Error("取消旧任务失败",
|
||
zap.String("task_id", task.ID),
|
||
zap.Error(err))
|
||
return fmt.Errorf("取消旧任务失败: %w", err)
|
||
}
|
||
|
||
tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID))
|
||
|
||
// 3. 创建并保存新任务
|
||
newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
tm.logger.Info("新任务已创建",
|
||
zap.String("new_task_id", newTask.ID),
|
||
zap.Time("new_scheduled_at", newScheduledAt))
|
||
|
||
// 4. 计算延时并入队
|
||
delay := newScheduledAt.Sub(time.Now())
|
||
if delay < 0 {
|
||
delay = 0 // 如果时间已过,立即执行
|
||
}
|
||
|
||
if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil {
|
||
// 如果入队失败,删除新创建的任务记录
|
||
tm.asyncTaskRepo.Delete(ctx, newTask.ID)
|
||
return fmt.Errorf("重新入队任务失败: %w", err)
|
||
}
|
||
|
||
tm.logger.Info("任务调度时间更新成功",
|
||
zap.String("old_task_id", task.ID),
|
||
zap.String("new_task_id", newTask.ID),
|
||
zap.Time("new_scheduled_at", newScheduledAt))
|
||
|
||
return nil
|
||
}
|
||
|
||
// GetTaskStatus 获取任务状态
|
||
func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||
return tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||
}
|
||
|
||
// UpdateTaskStatus 更新任务状态
|
||
func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error {
|
||
if errorMsg != "" {
|
||
return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg)
|
||
}
|
||
return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status)
|
||
}
|
||
|
||
// RetryTask 重试任务
|
||
func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
|
||
// 1. 获取任务信息
|
||
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||
if err != nil {
|
||
return fmt.Errorf("获取任务信息失败: %w", err)
|
||
}
|
||
|
||
// 2. 检查是否可以重试
|
||
if !task.CanRetry() {
|
||
return fmt.Errorf("任务已达到最大重试次数")
|
||
}
|
||
|
||
// 3. 增加重试次数并重置状态
|
||
task.RetryCount++
|
||
task.Status = entities.TaskStatusPending
|
||
|
||
// 4. 更新数据库
|
||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||
return fmt.Errorf("更新任务重试次数失败: %w", err)
|
||
}
|
||
|
||
// 5. 重新入队
|
||
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
|
||
return fmt.Errorf("重试任务入队失败: %w", err)
|
||
}
|
||
|
||
tm.logger.Info("任务重试成功",
|
||
zap.String("task_id", taskID),
|
||
zap.Int("retry_count", task.RetryCount))
|
||
|
||
return nil
|
||
}
|
||
|
||
// CleanupExpiredTasks 清理过期任务
|
||
func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
|
||
// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
|
||
tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
|
||
|
||
// TODO: 实现清理逻辑
|
||
return nil
|
||
}
|
||
|
||
// updatePayloadTaskID 更新payload中的task_id
|
||
func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
|
||
// 解析payload
|
||
var payload map[string]interface{}
|
||
if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil {
|
||
return fmt.Errorf("解析payload失败: %w", err)
|
||
}
|
||
|
||
// 更新task_id
|
||
payload["task_id"] = task.ID
|
||
|
||
// 重新序列化
|
||
newPayload, err := json.Marshal(payload)
|
||
if err != nil {
|
||
return fmt.Errorf("序列化payload失败: %w", err)
|
||
}
|
||
|
||
task.Payload = string(newPayload)
|
||
return nil
|
||
}
|
||
|
||
|
||
// findTask 查找任务(支持taskID和articleID双重查找)
|
||
func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||
// 先尝试通过任务ID查找
|
||
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||
if err == nil {
|
||
return task, nil
|
||
}
|
||
|
||
// 如果通过任务ID找不到,尝试通过文章ID查找
|
||
tm.logger.Info("通过任务ID查找失败,尝试通过文章ID查找", zap.String("task_id", taskID))
|
||
|
||
tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
|
||
if err != nil || len(tasks) == 0 {
|
||
tm.logger.Error("通过文章ID也找不到任务",
|
||
zap.String("article_id", taskID),
|
||
zap.Error(err))
|
||
return nil, fmt.Errorf("获取任务信息失败: %w", err)
|
||
}
|
||
|
||
// 使用找到的第一个任务
|
||
task = tasks[0]
|
||
tm.logger.Info("通过文章ID找到任务",
|
||
zap.String("article_id", taskID),
|
||
zap.String("task_id", task.ID))
|
||
|
||
return task, nil
|
||
}
|
||
|
||
// createAndSaveTask 创建并保存新任务
|
||
func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) {
|
||
// 创建新任务
|
||
newTask := &entities.AsyncTask{
|
||
Type: originalTask.Type,
|
||
Payload: originalTask.Payload,
|
||
Status: entities.TaskStatusPending,
|
||
ScheduledAt: &newScheduledAt,
|
||
CreatedAt: time.Now(),
|
||
UpdatedAt: time.Now(),
|
||
}
|
||
|
||
// 保存到数据库(GORM会自动生成UUID)
|
||
if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
|
||
tm.logger.Error("创建新任务失败",
|
||
zap.String("new_task_id", newTask.ID),
|
||
zap.Error(err))
|
||
return nil, fmt.Errorf("创建新任务失败: %w", err)
|
||
}
|
||
|
||
// 更新payload中的task_id
|
||
if err := tm.updatePayloadTaskID(newTask); err != nil {
|
||
tm.logger.Error("更新payload中的任务ID失败",
|
||
zap.String("new_task_id", newTask.ID),
|
||
zap.Error(err))
|
||
return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||
}
|
||
|
||
// 更新数据库中的payload
|
||
if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
|
||
tm.logger.Error("更新新任务payload失败",
|
||
zap.String("new_task_id", newTask.ID),
|
||
zap.Error(err))
|
||
return nil, fmt.Errorf("更新新任务payload失败: %w", err)
|
||
}
|
||
|
||
return newTask, nil
|
||
}
|
||
|
||
// enqueueTaskWithDelay 入队任务到Asynq(支持延时)
|
||
func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
|
||
queueName := tm.getQueueName(task.Type)
|
||
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
|
||
|
||
var err error
|
||
if delay > 0 {
|
||
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
|
||
asynq.Queue(queueName),
|
||
asynq.ProcessIn(delay))
|
||
} else {
|
||
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
// getQueueName 根据任务类型获取队列名称
|
||
func (tm *TaskManagerImpl) getQueueName(taskType string) string {
|
||
switch taskType {
|
||
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
|
||
return "article"
|
||
case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
|
||
return "api"
|
||
case string(types.TaskTypeCompensation):
|
||
return "finance"
|
||
default:
|
||
return "default"
|
||
}
|
||
}
|