This commit is contained in:
2025-12-06 13:56:00 +08:00
parent 05b6623e75
commit 89367fb2ee
22 changed files with 2186 additions and 74 deletions

View File

@@ -53,6 +53,33 @@ func (f *TaskFactory) CreateArticlePublishTask(articleID string, publishAt time.
return task, nil
}
// CreateAnnouncementPublishTask 创建公告发布任务
func (f *TaskFactory) CreateAnnouncementPublishTask(announcementID string, publishAt time.Time, userID string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeAnnouncementPublish),
Status: TaskStatusPending,
ScheduledAt: &publishAt,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务ID将在保存后更新
payloadWithID := map[string]interface{}{
"announcement_id": announcementID,
"publish_at": publishAt,
"user_id": userID,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateArticleCancelTask 创建文章取消任务
func (f *TaskFactory) CreateArticleCancelTask(articleID string, userID string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
@@ -240,6 +267,32 @@ func (f *TaskFactory) CreateAndEnqueueArticlePublishTask(ctx context.Context, ar
return fmt.Errorf("TaskManager类型不匹配")
}
// CreateAndEnqueueAnnouncementPublishTask 创建并入队公告发布任务
func (f *TaskFactory) CreateAndEnqueueAnnouncementPublishTask(ctx context.Context, announcementID string, publishAt time.Time, userID string) error {
if f.taskManager == nil {
return fmt.Errorf("TaskManager未初始化")
}
task, err := f.CreateAnnouncementPublishTask(announcementID, publishAt, userID)
if err != nil {
return err
}
delay := publishAt.Sub(time.Now())
if delay < 0 {
delay = 0
}
// 使用类型断言调用TaskManager方法
if tm, ok := f.taskManager.(interface {
CreateAndEnqueueDelayedTask(ctx context.Context, task *AsyncTask, delay time.Duration) error
}); ok {
return tm.CreateAndEnqueueDelayedTask(ctx, task, delay)
}
return fmt.Errorf("TaskManager类型不匹配")
}
// CreateAndEnqueueApiLogTask 创建并入队API日志任务
func (f *TaskFactory) CreateAndEnqueueApiLogTask(ctx context.Context, transactionID string, userID string, apiName string, productID string) error {
if f.taskManager == nil {

View File

@@ -17,17 +17,24 @@ import (
// ArticleTaskHandler 文章任务处理器
type ArticleTaskHandler struct {
logger *zap.Logger
articleApplicationService article.ArticleApplicationService
asyncTaskRepo repositories.AsyncTaskRepository
logger *zap.Logger
articleApplicationService article.ArticleApplicationService
announcementApplicationService article.AnnouncementApplicationService
asyncTaskRepo repositories.AsyncTaskRepository
}
// NewArticleTaskHandler 创建文章任务处理器
func NewArticleTaskHandler(logger *zap.Logger, articleApplicationService article.ArticleApplicationService, asyncTaskRepo repositories.AsyncTaskRepository) *ArticleTaskHandler {
func NewArticleTaskHandler(
logger *zap.Logger,
articleApplicationService article.ArticleApplicationService,
announcementApplicationService article.AnnouncementApplicationService,
asyncTaskRepo repositories.AsyncTaskRepository,
) *ArticleTaskHandler {
return &ArticleTaskHandler{
logger: logger,
articleApplicationService: articleApplicationService,
asyncTaskRepo: asyncTaskRepo,
logger: logger,
articleApplicationService: articleApplicationService,
announcementApplicationService: announcementApplicationService,
asyncTaskRepo: asyncTaskRepo,
}
}
@@ -112,6 +119,47 @@ func (h *ArticleTaskHandler) HandleArticleModify(ctx context.Context, t *asynq.T
return nil
}
// HandleAnnouncementPublish 处理公告发布任务
func (h *ArticleTaskHandler) HandleAnnouncementPublish(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理公告发布任务")
var payload AnnouncementPublishPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析公告发布任务载荷失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
return err
}
h.logger.Info("处理公告发布任务",
zap.String("announcement_id", payload.AnnouncementID),
zap.Time("publish_at", payload.PublishAt))
// 检查任务是否已被取消
if err := h.checkTaskStatus(ctx, t); err != nil {
h.logger.Info("任务已被取消,跳过执行", zap.String("announcement_id", payload.AnnouncementID))
return nil // 静默返回,不报错
}
// 调用公告应用服务发布公告
if h.announcementApplicationService != nil {
err := h.announcementApplicationService.PublishAnnouncementByID(ctx, payload.AnnouncementID)
if err != nil {
h.logger.Error("公告发布失败", zap.String("announcement_id", payload.AnnouncementID), zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "公告发布失败: "+err.Error())
return err
}
} else {
h.logger.Warn("公告应用服务未初始化,跳过发布", zap.String("announcement_id", payload.AnnouncementID))
h.updateTaskStatus(ctx, t, "failed", "公告应用服务未初始化")
return nil
}
// 更新任务状态为成功
h.updateTaskStatus(ctx, t, "completed", "")
h.logger.Info("公告发布任务处理完成", zap.String("announcement_id", payload.AnnouncementID))
return nil
}
// ArticlePublishPayload 文章发布任务载荷
type ArticlePublishPayload struct {
ArticleID string `json:"article_id"`
@@ -157,9 +205,9 @@ func (p *ArticleCancelPayload) FromJSON(data []byte) error {
// ArticleModifyPayload 文章修改任务载荷
type ArticleModifyPayload struct {
ArticleID string `json:"article_id"`
NewPublishAt time.Time `json:"new_publish_at"`
UserID string `json:"user_id"`
ArticleID string `json:"article_id"`
NewPublishAt time.Time `json:"new_publish_at"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
@@ -177,6 +225,28 @@ func (p *ArticleModifyPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// AnnouncementPublishPayload 公告发布任务载荷
type AnnouncementPublishPayload struct {
AnnouncementID string `json:"announcement_id"`
PublishAt time.Time `json:"publish_at"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
func (p *AnnouncementPublishPayload) GetType() types.TaskType {
return types.TaskTypeAnnouncementPublish
}
// ToJSON 序列化为JSON
func (p *AnnouncementPublishPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *AnnouncementPublishPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// updateTaskStatus 更新任务状态
func (h *ArticleTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task, status string, errorMsg string) {
// 从任务载荷中提取任务ID
@@ -189,9 +259,11 @@ func (h *ArticleTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task
// 尝试从payload中获取任务ID
taskID, ok := payload["task_id"].(string)
if !ok {
// 如果没有task_id尝试从article_id生成
// 如果没有task_id尝试从article_id或announcement_id生成
if articleID, ok := payload["article_id"].(string); ok {
taskID = fmt.Sprintf("article-publish-%s", articleID)
} else if announcementID, ok := payload["announcement_id"].(string); ok {
taskID = fmt.Sprintf("announcement-publish-%s", announcementID)
} else {
h.logger.Error("无法从任务载荷中获取任务ID")
return
@@ -205,7 +277,7 @@ func (h *ArticleTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task
} else if status == "completed" {
// 成功时:清除错误信息并更新状态
if err := h.asyncTaskRepo.UpdateStatusWithSuccess(ctx, taskID, entities.TaskStatus(status)); err != nil {
h.logger.Error("更新任务状态失败",
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", status),
zap.Error(err))
@@ -213,14 +285,14 @@ func (h *ArticleTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task
} else {
// 其他状态:只更新状态
if err := h.asyncTaskRepo.UpdateStatus(ctx, taskID, entities.TaskStatus(status)); err != nil {
h.logger.Error("更新任务状态失败",
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", status),
zap.Error(err))
}
}
h.logger.Info("任务状态已更新",
h.logger.Info("任务状态已更新",
zap.String("task_id", taskID),
zap.String("status", status),
zap.String("error_msg", errorMsg))
@@ -237,29 +309,29 @@ func (h *ArticleTaskHandler) handleTaskFailure(ctx context.Context, taskID strin
// 增加重试次数
newRetryCount := task.RetryCount + 1
// 检查是否达到最大重试次数
if newRetryCount >= task.MaxRetries {
// 达到最大重试次数,标记为最终失败
if err := h.asyncTaskRepo.UpdateStatusWithRetryAndError(ctx, taskID, entities.TaskStatusFailed, errorMsg); err != nil {
h.logger.Error("更新任务状态失败",
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", "failed"),
zap.Error(err))
}
h.logger.Info("任务最终失败,已达到最大重试次数",
h.logger.Info("任务最终失败,已达到最大重试次数",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Int("max_retries", task.MaxRetries))
} else {
// 未达到最大重试次数保持pending状态记录错误信息
if err := h.asyncTaskRepo.UpdateRetryCountAndError(ctx, taskID, newRetryCount, errorMsg); err != nil {
h.logger.Error("更新任务重试次数失败",
h.logger.Error("更新任务重试次数失败",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Error(err))
}
h.logger.Info("任务失败,准备重试",
h.logger.Info("任务失败,准备重试",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Int("max_retries", task.MaxRetries))
@@ -278,9 +350,11 @@ func (h *ArticleTaskHandler) checkTaskStatus(ctx context.Context, t *asynq.Task)
// 尝试从payload中获取任务ID
taskID, ok := payload["task_id"].(string)
if !ok {
// 如果没有task_id尝试从article_id生成
// 如果没有task_id尝试从article_id或announcement_id生成
if articleID, ok := payload["article_id"].(string); ok {
taskID = fmt.Sprintf("article-publish-%s", articleID)
} else if announcementID, ok := payload["announcement_id"].(string); ok {
taskID = fmt.Sprintf("announcement-publish-%s", announcementID)
} else {
h.logger.Error("无法从任务载荷中获取任务ID")
return fmt.Errorf("无法获取任务ID")
@@ -301,4 +375,4 @@ func (h *ArticleTaskHandler) checkTaskStatus(ctx context.Context, t *asynq.Task)
}
return nil
}
}

View File

@@ -29,6 +29,7 @@ func NewAsynqWorker(
redisAddr string,
logger *zap.Logger,
articleApplicationService article.ArticleApplicationService,
announcementApplicationService article.AnnouncementApplicationService,
apiApplicationService api.ApiApplicationService,
walletService finance_services.WalletAggregateService,
subscriptionService *product_services.ProductSubscriptionService,
@@ -39,15 +40,16 @@ func NewAsynqWorker(
asynq.Config{
Concurrency: 6, // 降低总并发数
Queues: map[string]int{
"default": 2, // 2个goroutine
"api": 3, // 3个goroutine (扣款任务)
"article": 1, // 1个goroutine
"default": 2, // 2个goroutine
"api": 3, // 3个goroutine (扣款任务)
"article": 1, // 1个goroutine
"announcement": 1, // 1个goroutine
},
},
)
// 创建任务处理器
articleHandler := handlers.NewArticleTaskHandler(logger, articleApplicationService, asyncTaskRepo)
articleHandler := handlers.NewArticleTaskHandler(logger, articleApplicationService, announcementApplicationService, asyncTaskRepo)
apiHandler := handlers.NewApiTaskHandler(logger, apiApplicationService, walletService, subscriptionService, asyncTaskRepo)
// 创建ServeMux
@@ -105,6 +107,9 @@ func (w *AsynqWorker) registerAllHandlers() {
w.mux.HandleFunc(string(types.TaskTypeArticleCancel), w.articleHandler.HandleArticleCancel)
w.mux.HandleFunc(string(types.TaskTypeArticleModify), w.articleHandler.HandleArticleModify)
// 注册公告任务处理器
w.mux.HandleFunc(string(types.TaskTypeAnnouncementPublish), w.articleHandler.HandleAnnouncementPublish)
// 注册API任务处理器
w.mux.HandleFunc(string(types.TaskTypeApiCall), w.apiHandler.HandleApiCall)
w.mux.HandleFunc(string(types.TaskTypeApiLog), w.apiHandler.HandleApiLog)
@@ -116,6 +121,7 @@ func (w *AsynqWorker) registerAllHandlers() {
zap.String("article_publish", string(types.TaskTypeArticlePublish)),
zap.String("article_cancel", string(types.TaskTypeArticleCancel)),
zap.String("article_modify", string(types.TaskTypeArticleModify)),
zap.String("announcement_publish", string(types.TaskTypeAnnouncementPublish)),
zap.String("api_call", string(types.TaskTypeApiCall)),
zap.String("api_log", string(types.TaskTypeApiLog)),
)

View File

@@ -17,10 +17,10 @@ import (
// TaskManagerImpl 任务管理器实现
type TaskManagerImpl struct {
asynqClient *asynq.Client
asyncTaskRepo repositories.AsyncTaskRepository
logger *zap.Logger
config *interfaces.TaskManagerConfig
asynqClient *asynq.Client
asyncTaskRepo repositories.AsyncTaskRepository
logger *zap.Logger
config *interfaces.TaskManagerConfig
}
// NewTaskManager 创建任务管理器
@@ -42,7 +42,7 @@ func NewTaskManager(
func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
// 1. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存任务到数据库失败",
tm.logger.Error("保存任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存任务失败: %w", err)
@@ -50,7 +50,7 @@ func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entit
// 2. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
@@ -58,7 +58,7 @@ func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entit
// 3. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
@@ -71,7 +71,7 @@ func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entit
return fmt.Errorf("任务入队失败: %w", err)
}
tm.logger.Info("任务创建并入队成功",
tm.logger.Info("任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
@@ -86,7 +86,7 @@ func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task
// 2. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存延时任务到数据库失败",
tm.logger.Error("保存延时任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存延时任务失败: %w", err)
@@ -94,7 +94,7 @@ func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task
// 3. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
@@ -102,7 +102,7 @@ func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task
// 4. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
@@ -115,7 +115,7 @@ func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task
return fmt.Errorf("延时任务入队失败: %w", err)
}
tm.logger.Info("延时任务创建并入队成功",
tm.logger.Info("延时任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type),
zap.Duration("delay", delay))
@@ -131,13 +131,13 @@ func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error
}
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("更新任务状态为取消失败",
tm.logger.Error("更新任务状态为取消失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务状态失败: %w", err)
}
tm.logger.Info("任务已标记为取消",
tm.logger.Info("任务已标记为取消",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
@@ -152,14 +152,14 @@ func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string
return err
}
tm.logger.Info("找到要更新的任务",
tm.logger.Info("找到要更新的任务",
zap.String("task_id", task.ID),
zap.String("current_status", string(task.Status)),
zap.Time("current_scheduled_at", *task.ScheduledAt))
// 2. 取消旧任务
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("取消旧任务失败",
tm.logger.Error("取消旧任务失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("取消旧任务失败: %w", err)
@@ -173,7 +173,7 @@ func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string
return err
}
tm.logger.Info("新任务已创建",
tm.logger.Info("新任务已创建",
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
@@ -189,7 +189,7 @@ func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string
return fmt.Errorf("重新入队任务失败: %w", err)
}
tm.logger.Info("任务调度时间更新成功",
tm.logger.Info("任务调度时间更新成功",
zap.String("old_task_id", task.ID),
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
@@ -237,7 +237,7 @@ func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
return fmt.Errorf("重试任务入队失败: %w", err)
}
tm.logger.Info("任务重试成功",
tm.logger.Info("任务重试成功",
zap.String("task_id", taskID),
zap.Int("retry_count", task.RetryCount))
@@ -248,7 +248,7 @@ func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
// TODO: 实现清理逻辑
return nil
}
@@ -274,8 +274,7 @@ func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
return nil
}
// findTask 查找任务支持taskID和articleID双重查找
// findTask 查找任务支持taskID、articleID和announcementID三重查找
func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
// 先尝试通过任务ID查找
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
@@ -285,21 +284,34 @@ func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entiti
// 如果通过任务ID找不到尝试通过文章ID查找
tm.logger.Info("通过任务ID查找失败尝试通过文章ID查找", zap.String("task_id", taskID))
tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
if err != nil || len(tasks) == 0 {
tm.logger.Error("通过文章ID也找不到任务",
if err == nil && len(tasks) > 0 {
// 使用找到的第一个任务
task = tasks[0]
tm.logger.Info("通过文章ID找到任务",
zap.String("article_id", taskID),
zap.String("task_id", task.ID))
return task, nil
}
// 如果通过文章ID也找不到尝试通过公告ID查找
tm.logger.Info("通过文章ID查找失败尝试通过公告ID查找", zap.String("task_id", taskID))
announcementTasks, err := tm.asyncTaskRepo.GetByAnnouncementID(ctx, taskID)
if err != nil || len(announcementTasks) == 0 {
tm.logger.Error("通过公告ID也找不到任务",
zap.String("announcement_id", taskID),
zap.Error(err))
return nil, fmt.Errorf("获取任务信息失败: %w", err)
}
// 使用找到的第一个任务
task = tasks[0]
tm.logger.Info("通过文章ID找到任务",
zap.String("article_id", taskID),
task = announcementTasks[0]
tm.logger.Info("通过公告ID找到任务",
zap.String("announcement_id", taskID),
zap.String("task_id", task.ID))
return task, nil
}
@@ -317,7 +329,7 @@ func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *
// 保存到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
tm.logger.Error("创建新任务失败",
tm.logger.Error("创建新任务失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("创建新任务失败: %w", err)
@@ -325,7 +337,7 @@ func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *
// 更新payload中的task_id
if err := tm.updatePayloadTaskID(newTask); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
tm.logger.Error("更新payload中的任务ID失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
@@ -333,7 +345,7 @@ func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *
// 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
tm.logger.Error("更新新任务payload失败",
tm.logger.Error("更新新任务payload失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新新任务payload失败: %w", err)
@@ -346,23 +358,24 @@ func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *
func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
queueName := tm.getQueueName(task.Type)
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
var err error
if delay > 0 {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
asynq.Queue(queueName),
asynq.ProcessIn(delay))
} else {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
}
return err
}
// getQueueName 根据任务类型获取队列名称
func (tm *TaskManagerImpl) getQueueName(taskType string) string {
switch taskType {
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify),
string(types.TaskTypeAnnouncementPublish):
return "article"
case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
return "api"

View File

@@ -42,6 +42,9 @@ type AsyncTaskRepository interface {
GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
CancelArticlePublishTask(ctx context.Context, articleID string) error
UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
// 公告任务专用方法
GetByAnnouncementID(ctx context.Context, announcementID string) ([]*entities.AsyncTask, error)
}
// AsyncTaskRepositoryImpl 异步任务仓库实现
@@ -219,8 +222,8 @@ func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string)
func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
var task entities.AsyncTask
err := r.db.WithContext(ctx).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
First(&task).Error
@@ -234,7 +237,7 @@ func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, art
func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
err := r.db.WithContext(ctx).
Where("payload LIKE ? AND status IN ?",
Where("payload LIKE ? AND status IN ?",
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Find(&tasks).Error
@@ -248,8 +251,8 @@ func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID
func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Update("status", entities.TaskStatusCancelled).Error
@@ -259,9 +262,38 @@ func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context,
func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Update("scheduled_at", newScheduledAt).Error
}
}
// GetPendingArticlePublishTaskByArticleID 根据公告ID获取待执行的公告发布任务
func (r *AsyncTaskRepositoryImpl) GetPendingAnnouncementPublishTaskByAnnouncementID(ctx context.Context, announcementID string) (*entities.AsyncTask, error) {
var task entities.AsyncTask
err := r.db.WithContext(ctx).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeAnnouncementPublish,
"%\"announcement_id\":\""+announcementID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
First(&task).Error
if err != nil {
return nil, err
}
return &task, nil
}
// GetByAnnouncementID 根据公告ID获取所有相关任务
func (r *AsyncTaskRepositoryImpl) GetByAnnouncementID(ctx context.Context, announcementID string) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
err := r.db.WithContext(ctx).
Where("payload LIKE ? AND status IN ?",
"%\"announcement_id\":\""+announcementID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Find(&tasks).Error
if err != nil {
return nil, err
}
return tasks, nil
}

View File

@@ -9,12 +9,15 @@ const (
TaskTypeArticleCancel TaskType = "article_cancel"
TaskTypeArticleModify TaskType = "article_modify"
// 公告相关任务
TaskTypeAnnouncementPublish TaskType = "announcement_publish"
// API相关任务
TaskTypeApiCall TaskType = "api_call"
TaskTypeApiLog TaskType = "api_log"
// 财务相关任务
TaskTypeDeduction TaskType = "deduction"
TaskTypeDeduction TaskType = "deduction"
TaskTypeCompensation TaskType = "compensation"
// 产品相关任务
@@ -26,4 +29,4 @@ type TaskPayload interface {
GetType() TaskType
ToJSON() ([]byte, error)
FromJSON(data []byte) error
}
}