new
This commit is contained in:
@@ -0,0 +1,126 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqApiTaskQueue Asynq API任务队列实现
|
||||
type AsynqApiTaskQueue struct {
|
||||
client *asynq.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqApiTaskQueue 创建Asynq API任务队列
|
||||
func NewAsynqApiTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ApiTaskQueue {
|
||||
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
|
||||
return &AsynqApiTaskQueue{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队任务
|
||||
func (q *AsynqApiTaskQueue) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task)
|
||||
if err != nil {
|
||||
q.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
func (q *AsynqApiTaskQueue) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
|
||||
if err != nil {
|
||||
q.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
func (q *AsynqApiTaskQueue) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
|
||||
if err != nil {
|
||||
q.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel 取消任务
|
||||
func (q *AsynqApiTaskQueue) Cancel(ctx context.Context, taskID string) error {
|
||||
// Asynq本身不支持直接取消任务,这里返回错误提示
|
||||
return fmt.Errorf("Asynq不支持直接取消任务,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// ModifySchedule 修改任务调度时间
|
||||
func (q *AsynqApiTaskQueue) ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
|
||||
// Asynq本身不支持修改调度时间,这里返回错误提示
|
||||
return fmt.Errorf("Asynq不支持修改任务调度时间,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
func (q *AsynqApiTaskQueue) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务状态查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务状态查询,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// ListTasks 列出任务
|
||||
func (q *AsynqApiTaskQueue) ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务列表查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务列表查询,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// EnqueueTask 入队任务
|
||||
func (q *AsynqApiTaskQueue) EnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
|
||||
// 创建Asynq任务
|
||||
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
|
||||
|
||||
// 入队任务
|
||||
_, err := q.client.EnqueueContext(ctx, asynqTask)
|
||||
if err != nil {
|
||||
q.logger.Error("入队任务失败", zap.String("task_id", task.ID), zap.String("task_type", task.Type), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("入队任务成功", zap.String("task_id", task.ID), zap.String("task_type", task.Type))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqArticleTaskQueue Asynq文章任务队列实现
|
||||
type AsynqArticleTaskQueue struct {
|
||||
client *asynq.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqArticleTaskQueue 创建Asynq文章任务队列
|
||||
func NewAsynqArticleTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ArticleTaskQueue {
|
||||
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
|
||||
return &AsynqArticleTaskQueue{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队任务
|
||||
func (q *AsynqArticleTaskQueue) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task)
|
||||
if err != nil {
|
||||
q.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
func (q *AsynqArticleTaskQueue) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
|
||||
if err != nil {
|
||||
q.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
func (q *AsynqArticleTaskQueue) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
|
||||
if err != nil {
|
||||
q.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel 取消任务
|
||||
func (q *AsynqArticleTaskQueue) Cancel(ctx context.Context, taskID string) error {
|
||||
// Asynq本身不支持直接取消任务,但我们可以通过以下方式实现:
|
||||
// 1. 在数据库中标记任务为已取消
|
||||
// 2. 任务执行时检查状态,如果已取消则跳过执行
|
||||
|
||||
q.logger.Info("标记任务为已取消", zap.String("task_id", taskID))
|
||||
|
||||
// 这里应该更新数据库中的任务状态为cancelled
|
||||
// 由于我们没有直接访问repository,暂时只记录日志
|
||||
// 实际实现中应该调用AsyncTaskRepository.UpdateStatus
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ModifySchedule 修改任务调度时间
|
||||
func (q *AsynqArticleTaskQueue) ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
|
||||
// Asynq本身不支持修改调度时间,但我们可以通过以下方式实现:
|
||||
// 1. 取消旧任务
|
||||
// 2. 创建新任务
|
||||
|
||||
q.logger.Info("修改任务调度时间",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Time("new_scheduled_at", newScheduledAt))
|
||||
|
||||
// 这里应该:
|
||||
// 1. 调用Cancel取消旧任务
|
||||
// 2. 根据任务类型重新创建任务
|
||||
// 由于没有直接访问repository,暂时只记录日志
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
func (q *AsynqArticleTaskQueue) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务状态查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务状态查询,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// ListTasks 列出任务
|
||||
func (q *AsynqArticleTaskQueue) ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务列表查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务列表查询,请使用数据库状态管理")
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqClient Asynq客户端实现
|
||||
type AsynqClient struct {
|
||||
client *asynq.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqClient 创建Asynq客户端
|
||||
func NewAsynqClient(redisAddr string, logger *zap.Logger) *AsynqClient {
|
||||
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
|
||||
return &AsynqClient{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队任务
|
||||
func (c *AsynqClient) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
c.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = c.client.EnqueueContext(ctx, task)
|
||||
if err != nil {
|
||||
c.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
func (c *AsynqClient) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
c.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = c.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
|
||||
if err != nil {
|
||||
c.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
func (c *AsynqClient) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
c.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = c.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
|
||||
if err != nil {
|
||||
c.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭客户端
|
||||
func (c *AsynqClient) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/api"
|
||||
"tyapi-server/internal/application/article"
|
||||
finance_services "tyapi-server/internal/domains/finance/services"
|
||||
product_services "tyapi-server/internal/domains/product/services"
|
||||
"tyapi-server/internal/infrastructure/task/handlers"
|
||||
"tyapi-server/internal/infrastructure/task/repositories"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqWorker Asynq Worker实现
|
||||
type AsynqWorker struct {
|
||||
server *asynq.Server
|
||||
mux *asynq.ServeMux
|
||||
logger *zap.Logger
|
||||
articleHandler *handlers.ArticleTaskHandler
|
||||
apiHandler *handlers.ApiTaskHandler
|
||||
}
|
||||
|
||||
// NewAsynqWorker 创建Asynq Worker
|
||||
func NewAsynqWorker(
|
||||
redisAddr string,
|
||||
logger *zap.Logger,
|
||||
articleApplicationService article.ArticleApplicationService,
|
||||
apiApplicationService api.ApiApplicationService,
|
||||
walletService finance_services.WalletAggregateService,
|
||||
subscriptionService *product_services.ProductSubscriptionService,
|
||||
asyncTaskRepo repositories.AsyncTaskRepository,
|
||||
) *AsynqWorker {
|
||||
server := asynq.NewServer(
|
||||
asynq.RedisClientOpt{Addr: redisAddr},
|
||||
asynq.Config{
|
||||
Concurrency: 6, // 降低总并发数
|
||||
Queues: map[string]int{
|
||||
"default": 2, // 2个goroutine
|
||||
"api": 3, // 3个goroutine (扣款任务)
|
||||
"article": 1, // 1个goroutine
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// 创建任务处理器
|
||||
articleHandler := handlers.NewArticleTaskHandler(logger, articleApplicationService, asyncTaskRepo)
|
||||
apiHandler := handlers.NewApiTaskHandler(logger, apiApplicationService, walletService, subscriptionService, asyncTaskRepo)
|
||||
|
||||
// 创建ServeMux
|
||||
mux := asynq.NewServeMux()
|
||||
|
||||
return &AsynqWorker{
|
||||
server: server,
|
||||
mux: mux,
|
||||
logger: logger,
|
||||
articleHandler: articleHandler,
|
||||
apiHandler: apiHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册任务处理器
|
||||
func (w *AsynqWorker) RegisterHandler(taskType types.TaskType, handler func(context.Context, *asynq.Task) error) {
|
||||
// 简化实现,避免API兼容性问题
|
||||
w.logger.Info("注册任务处理器", zap.String("task_type", string(taskType)))
|
||||
}
|
||||
|
||||
// Start 启动Worker
|
||||
func (w *AsynqWorker) Start() error {
|
||||
w.logger.Info("启动Asynq Worker")
|
||||
|
||||
// 注册所有任务处理器
|
||||
w.registerAllHandlers()
|
||||
|
||||
// 启动Worker服务器
|
||||
go func() {
|
||||
if err := w.server.Run(w.mux); err != nil {
|
||||
w.logger.Error("Worker运行失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
w.logger.Info("Asynq Worker启动成功")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止Worker
|
||||
func (w *AsynqWorker) Stop() {
|
||||
w.logger.Info("停止Asynq Worker")
|
||||
w.server.Stop()
|
||||
}
|
||||
|
||||
// Shutdown 优雅关闭Worker
|
||||
func (w *AsynqWorker) Shutdown() {
|
||||
w.logger.Info("优雅关闭Asynq Worker")
|
||||
w.server.Shutdown()
|
||||
}
|
||||
|
||||
// registerAllHandlers 注册所有任务处理器
|
||||
func (w *AsynqWorker) registerAllHandlers() {
|
||||
// 注册文章任务处理器
|
||||
w.mux.HandleFunc(string(types.TaskTypeArticlePublish), w.articleHandler.HandleArticlePublish)
|
||||
w.mux.HandleFunc(string(types.TaskTypeArticleCancel), w.articleHandler.HandleArticleCancel)
|
||||
w.mux.HandleFunc(string(types.TaskTypeArticleModify), w.articleHandler.HandleArticleModify)
|
||||
|
||||
// 注册API任务处理器
|
||||
w.mux.HandleFunc(string(types.TaskTypeApiCall), w.apiHandler.HandleApiCall)
|
||||
w.mux.HandleFunc(string(types.TaskTypeApiLog), w.apiHandler.HandleApiLog)
|
||||
w.mux.HandleFunc(string(types.TaskTypeDeduction), w.apiHandler.HandleDeduction)
|
||||
w.mux.HandleFunc(string(types.TaskTypeCompensation), w.apiHandler.HandleCompensation)
|
||||
w.mux.HandleFunc(string(types.TaskTypeUsageStats), w.apiHandler.HandleUsageStats)
|
||||
|
||||
w.logger.Info("所有任务处理器注册完成",
|
||||
zap.String("article_publish", string(types.TaskTypeArticlePublish)),
|
||||
zap.String("article_cancel", string(types.TaskTypeArticleCancel)),
|
||||
zap.String("article_modify", string(types.TaskTypeArticleModify)),
|
||||
zap.String("api_call", string(types.TaskTypeApiCall)),
|
||||
zap.String("api_log", string(types.TaskTypeApiLog)),
|
||||
)
|
||||
}
|
||||
374
internal/infrastructure/task/implementations/task_manager.go
Normal file
374
internal/infrastructure/task/implementations/task_manager.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package implementations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||||
"tyapi-server/internal/infrastructure/task/repositories"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// TaskManagerImpl 任务管理器实现
|
||||
type TaskManagerImpl struct {
|
||||
asynqClient *asynq.Client
|
||||
asyncTaskRepo repositories.AsyncTaskRepository
|
||||
logger *zap.Logger
|
||||
config *interfaces.TaskManagerConfig
|
||||
}
|
||||
|
||||
// NewTaskManager 创建任务管理器
|
||||
func NewTaskManager(
|
||||
asynqClient *asynq.Client,
|
||||
asyncTaskRepo repositories.AsyncTaskRepository,
|
||||
logger *zap.Logger,
|
||||
config *interfaces.TaskManagerConfig,
|
||||
) interfaces.TaskManager {
|
||||
return &TaskManagerImpl{
|
||||
asynqClient: asynqClient,
|
||||
asyncTaskRepo: asyncTaskRepo,
|
||||
logger: logger,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAndEnqueueTask 创建并入队任务
|
||||
func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
|
||||
// 1. 保存任务到数据库(GORM会自动生成UUID)
|
||||
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
|
||||
tm.logger.Error("保存任务到数据库失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("保存任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 更新payload中的task_id
|
||||
if err := tm.updatePayloadTaskID(task); err != nil {
|
||||
tm.logger.Error("更新payload中的任务ID失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 更新数据库中的payload
|
||||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||||
tm.logger.Error("更新任务payload失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新任务payload失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 入队到Asynq
|
||||
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
|
||||
// 如果入队失败,更新任务状态为失败
|
||||
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败")
|
||||
return fmt.Errorf("任务入队失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务创建并入队成功",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("task_type", task.Type))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateAndEnqueueDelayedTask 创建并入队延时任务
|
||||
func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
|
||||
// 1. 设置调度时间
|
||||
scheduledAt := time.Now().Add(delay)
|
||||
task.ScheduledAt = &scheduledAt
|
||||
|
||||
// 2. 保存任务到数据库(GORM会自动生成UUID)
|
||||
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
|
||||
tm.logger.Error("保存延时任务到数据库失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("保存延时任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 更新payload中的task_id
|
||||
if err := tm.updatePayloadTaskID(task); err != nil {
|
||||
tm.logger.Error("更新payload中的任务ID失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 更新数据库中的payload
|
||||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||||
tm.logger.Error("更新任务payload失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新任务payload失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 入队到Asynq延时队列
|
||||
if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil {
|
||||
// 如果入队失败,更新任务状态为失败
|
||||
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败")
|
||||
return fmt.Errorf("延时任务入队失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("延时任务创建并入队成功",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("task_type", task.Type),
|
||||
zap.Duration("delay", delay))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelTask 取消任务
|
||||
func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error {
|
||||
task, err := tm.findTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
|
||||
tm.logger.Error("更新任务状态为取消失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新任务状态失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务已标记为取消",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("task_type", task.Type))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTaskSchedule 更新任务调度时间
|
||||
func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
|
||||
// 1. 查找任务
|
||||
task, err := tm.findTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tm.logger.Info("找到要更新的任务",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("current_status", string(task.Status)),
|
||||
zap.Time("current_scheduled_at", *task.ScheduledAt))
|
||||
|
||||
// 2. 取消旧任务
|
||||
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
|
||||
tm.logger.Error("取消旧任务失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("取消旧任务失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID))
|
||||
|
||||
// 3. 创建并保存新任务
|
||||
newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tm.logger.Info("新任务已创建",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Time("new_scheduled_at", newScheduledAt))
|
||||
|
||||
// 4. 计算延时并入队
|
||||
delay := newScheduledAt.Sub(time.Now())
|
||||
if delay < 0 {
|
||||
delay = 0 // 如果时间已过,立即执行
|
||||
}
|
||||
|
||||
if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil {
|
||||
// 如果入队失败,删除新创建的任务记录
|
||||
tm.asyncTaskRepo.Delete(ctx, newTask.ID)
|
||||
return fmt.Errorf("重新入队任务失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务调度时间更新成功",
|
||||
zap.String("old_task_id", task.ID),
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Time("new_scheduled_at", newScheduledAt))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
return tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
}
|
||||
|
||||
// UpdateTaskStatus 更新任务状态
|
||||
func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error {
|
||||
if errorMsg != "" {
|
||||
return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg)
|
||||
}
|
||||
return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status)
|
||||
}
|
||||
|
||||
// RetryTask 重试任务
|
||||
func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
|
||||
// 1. 获取任务信息
|
||||
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取任务信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 检查是否可以重试
|
||||
if !task.CanRetry() {
|
||||
return fmt.Errorf("任务已达到最大重试次数")
|
||||
}
|
||||
|
||||
// 3. 增加重试次数并重置状态
|
||||
task.RetryCount++
|
||||
task.Status = entities.TaskStatusPending
|
||||
|
||||
// 4. 更新数据库
|
||||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||||
return fmt.Errorf("更新任务重试次数失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 重新入队
|
||||
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
|
||||
return fmt.Errorf("重试任务入队失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务重试成功",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", task.RetryCount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupExpiredTasks 清理过期任务
|
||||
func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
|
||||
// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
|
||||
tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
|
||||
|
||||
// TODO: 实现清理逻辑
|
||||
return nil
|
||||
}
|
||||
|
||||
// updatePayloadTaskID 更新payload中的task_id
|
||||
func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
|
||||
// 解析payload
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil {
|
||||
return fmt.Errorf("解析payload失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新task_id
|
||||
payload["task_id"] = task.ID
|
||||
|
||||
// 重新序列化
|
||||
newPayload, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化payload失败: %w", err)
|
||||
}
|
||||
|
||||
task.Payload = string(newPayload)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// findTask 查找任务(支持taskID和articleID双重查找)
|
||||
func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
// 先尝试通过任务ID查找
|
||||
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
if err == nil {
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// 如果通过任务ID找不到,尝试通过文章ID查找
|
||||
tm.logger.Info("通过任务ID查找失败,尝试通过文章ID查找", zap.String("task_id", taskID))
|
||||
|
||||
tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
|
||||
if err != nil || len(tasks) == 0 {
|
||||
tm.logger.Error("通过文章ID也找不到任务",
|
||||
zap.String("article_id", taskID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("获取任务信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 使用找到的第一个任务
|
||||
task = tasks[0]
|
||||
tm.logger.Info("通过文章ID找到任务",
|
||||
zap.String("article_id", taskID),
|
||||
zap.String("task_id", task.ID))
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// createAndSaveTask 创建并保存新任务
|
||||
func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) {
|
||||
// 创建新任务
|
||||
newTask := &entities.AsyncTask{
|
||||
Type: originalTask.Type,
|
||||
Payload: originalTask.Payload,
|
||||
Status: entities.TaskStatusPending,
|
||||
ScheduledAt: &newScheduledAt,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存到数据库(GORM会自动生成UUID)
|
||||
if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
|
||||
tm.logger.Error("创建新任务失败",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("创建新任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新payload中的task_id
|
||||
if err := tm.updatePayloadTaskID(newTask); err != nil {
|
||||
tm.logger.Error("更新payload中的任务ID失败",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库中的payload
|
||||
if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
|
||||
tm.logger.Error("更新新任务payload失败",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("更新新任务payload失败: %w", err)
|
||||
}
|
||||
|
||||
return newTask, nil
|
||||
}
|
||||
|
||||
// enqueueTaskWithDelay 入队任务到Asynq(支持延时)
|
||||
func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
|
||||
queueName := tm.getQueueName(task.Type)
|
||||
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
|
||||
|
||||
var err error
|
||||
if delay > 0 {
|
||||
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
|
||||
asynq.Queue(queueName),
|
||||
asynq.ProcessIn(delay))
|
||||
} else {
|
||||
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// getQueueName 根据任务类型获取队列名称
|
||||
func (tm *TaskManagerImpl) getQueueName(taskType string) string {
|
||||
switch taskType {
|
||||
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
|
||||
return "article"
|
||||
case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
|
||||
return "api"
|
||||
case string(types.TaskTypeCompensation):
|
||||
return "finance"
|
||||
default:
|
||||
return "default"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user