new
This commit is contained in:
1
internal/infrastructure/task/README.md
Normal file
1
internal/infrastructure/task/README.md
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"tyapi-server/internal/domains/article/repositories"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ArticlePublisher 文章发布接口
|
||||
type ArticlePublisher interface {
|
||||
PublishArticleByID(ctx context.Context, articleID string) error
|
||||
}
|
||||
|
||||
// ArticleTaskHandler 文章任务处理器
|
||||
type ArticleTaskHandler struct {
|
||||
publisher ArticlePublisher
|
||||
scheduledTaskRepo repositories.ScheduledTaskRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewArticleTaskHandler 创建文章任务处理器
|
||||
func NewArticleTaskHandler(
|
||||
publisher ArticlePublisher,
|
||||
scheduledTaskRepo repositories.ScheduledTaskRepository,
|
||||
logger *zap.Logger,
|
||||
) *ArticleTaskHandler {
|
||||
return &ArticleTaskHandler{
|
||||
publisher: publisher,
|
||||
scheduledTaskRepo: scheduledTaskRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleArticlePublish 处理文章定时发布任务
|
||||
func (h *ArticleTaskHandler) HandleArticlePublish(ctx context.Context, t *asynq.Task) error {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析任务载荷失败", zap.Error(err))
|
||||
return fmt.Errorf("解析任务载荷失败: %w", err)
|
||||
}
|
||||
|
||||
articleID, ok := payload["article_id"].(string)
|
||||
if !ok {
|
||||
h.logger.Error("任务载荷中缺少文章ID")
|
||||
return fmt.Errorf("任务载荷中缺少文章ID")
|
||||
}
|
||||
|
||||
// 获取任务状态记录
|
||||
task, err := h.scheduledTaskRepo.GetByTaskID(ctx, t.ResultWriter().TaskID())
|
||||
if err != nil {
|
||||
h.logger.Error("获取任务状态记录失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(err))
|
||||
// 继续执行,不阻断任务
|
||||
} else {
|
||||
// 检查任务是否已取消
|
||||
if task.IsCancelled() {
|
||||
h.logger.Info("任务已取消,跳过执行", zap.String("task_id", t.ResultWriter().TaskID()))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 标记任务为正在执行
|
||||
task.MarkAsRunning()
|
||||
if err := h.scheduledTaskRepo.Update(ctx, task); err != nil {
|
||||
h.logger.Warn("更新任务状态失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// 执行文章发布
|
||||
if err := h.publisher.PublishArticleByID(ctx, articleID); err != nil {
|
||||
// 更新任务状态为失败
|
||||
if task.ID != "" {
|
||||
task.MarkAsFailed(err.Error())
|
||||
if updateErr := h.scheduledTaskRepo.Update(ctx, task); updateErr != nil {
|
||||
h.logger.Warn("更新任务失败状态失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(updateErr))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Error("定时发布文章失败",
|
||||
zap.String("article_id", articleID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("定时发布文章失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新任务状态为已完成
|
||||
if task.ID != "" {
|
||||
task.MarkAsCompleted()
|
||||
if err := h.scheduledTaskRepo.Update(ctx, task); err != nil {
|
||||
h.logger.Warn("更新任务完成状态失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("定时发布文章成功", zap.String("article_id", articleID))
|
||||
return nil
|
||||
}
|
||||
@@ -1,133 +0,0 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
"tyapi-server/internal/domains/article/entities"
|
||||
"tyapi-server/internal/domains/article/repositories"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AsynqClient Asynq 客户端
|
||||
type AsynqClient struct {
|
||||
client *asynq.Client
|
||||
logger *zap.Logger
|
||||
scheduledTaskRepo repositories.ScheduledTaskRepository
|
||||
}
|
||||
|
||||
// NewAsynqClient 创建 Asynq 客户端
|
||||
func NewAsynqClient(redisAddr string, scheduledTaskRepo repositories.ScheduledTaskRepository, logger *zap.Logger) *AsynqClient {
|
||||
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
|
||||
return &AsynqClient{
|
||||
client: client,
|
||||
logger: logger,
|
||||
scheduledTaskRepo: scheduledTaskRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// Close 关闭客户端
|
||||
func (c *AsynqClient) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// ScheduleArticlePublish 调度文章定时发布任务
|
||||
func (c *AsynqClient) ScheduleArticlePublish(ctx context.Context, articleID string, publishTime time.Time) (string, error) {
|
||||
payload := map[string]interface{}{
|
||||
"article_id": articleID,
|
||||
}
|
||||
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
c.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return "", fmt.Errorf("创建任务失败: %w", err)
|
||||
}
|
||||
|
||||
task := asynq.NewTask(TaskTypeArticlePublish, payloadBytes)
|
||||
|
||||
// 计算延迟时间
|
||||
delay := publishTime.Sub(time.Now())
|
||||
if delay <= 0 {
|
||||
return "", fmt.Errorf("定时发布时间不能早于当前时间")
|
||||
}
|
||||
|
||||
// 设置任务选项
|
||||
opts := []asynq.Option{
|
||||
asynq.ProcessIn(delay),
|
||||
asynq.MaxRetry(3),
|
||||
asynq.Timeout(5 * time.Minute),
|
||||
}
|
||||
|
||||
info, err := c.client.Enqueue(task, opts...)
|
||||
if err != nil {
|
||||
c.logger.Error("调度定时发布任务失败",
|
||||
zap.String("article_id", articleID),
|
||||
zap.Time("publish_time", publishTime),
|
||||
zap.Error(err))
|
||||
return "", fmt.Errorf("调度任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建任务状态记录
|
||||
scheduledTask := entities.ScheduledTask{
|
||||
TaskID: info.ID,
|
||||
TaskType: TaskTypeArticlePublish,
|
||||
ArticleID: articleID,
|
||||
Status: entities.TaskStatusPending,
|
||||
ScheduledAt: publishTime,
|
||||
}
|
||||
|
||||
if _, err := c.scheduledTaskRepo.Create(ctx, scheduledTask); err != nil {
|
||||
c.logger.Error("创建任务状态记录失败", zap.String("task_id", info.ID), zap.Error(err))
|
||||
// 不返回错误,因为Asynq任务已经创建成功
|
||||
}
|
||||
|
||||
c.logger.Info("定时发布任务调度成功",
|
||||
zap.String("article_id", articleID),
|
||||
zap.Time("publish_time", publishTime),
|
||||
zap.String("task_id", info.ID))
|
||||
|
||||
return info.ID, nil
|
||||
}
|
||||
|
||||
// CancelScheduledTask 取消已调度的任务
|
||||
func (c *AsynqClient) CancelScheduledTask(ctx context.Context, taskID string) error {
|
||||
c.logger.Info("标记定时任务为已取消",
|
||||
zap.String("task_id", taskID))
|
||||
|
||||
// 标记数据库中的任务状态为已取消
|
||||
if err := c.scheduledTaskRepo.MarkAsCancelled(ctx, taskID); err != nil {
|
||||
c.logger.Warn("标记任务状态为已取消失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
// 不返回错误,因为Asynq任务可能已经执行完成
|
||||
}
|
||||
|
||||
// Asynq不支持直接取消任务,我们通过数据库状态来标记
|
||||
// 任务执行时会检查文章状态,如果已取消则跳过执行
|
||||
return nil
|
||||
}
|
||||
|
||||
// RescheduleArticlePublish 重新调度文章定时发布任务
|
||||
func (c *AsynqClient) RescheduleArticlePublish(ctx context.Context, articleID string, oldTaskID string, newPublishTime time.Time) (string, error) {
|
||||
// 1. 取消旧任务(标记为已取消)
|
||||
if err := c.CancelScheduledTask(ctx, oldTaskID); err != nil {
|
||||
c.logger.Warn("取消旧任务失败",
|
||||
zap.String("old_task_id", oldTaskID),
|
||||
zap.Error(err))
|
||||
}
|
||||
|
||||
// 2. 创建新任务
|
||||
newTaskID, err := c.ScheduleArticlePublish(ctx, articleID, newPublishTime)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("重新调度任务失败: %w", err)
|
||||
}
|
||||
|
||||
c.logger.Info("重新调度定时发布任务成功",
|
||||
zap.String("article_id", articleID),
|
||||
zap.String("old_task_id", oldTaskID),
|
||||
zap.String("new_task_id", newTaskID),
|
||||
zap.Time("new_publish_time", newPublishTime))
|
||||
|
||||
return newTaskID, nil
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AsynqWorker Asynq Worker
|
||||
type AsynqWorker struct {
|
||||
server *asynq.Server
|
||||
mux *asynq.ServeMux
|
||||
taskHandler *ArticleTaskHandler
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqWorker 创建 Asynq Worker
|
||||
func NewAsynqWorker(
|
||||
redisAddr string,
|
||||
taskHandler *ArticleTaskHandler,
|
||||
logger *zap.Logger,
|
||||
) *AsynqWorker {
|
||||
server := asynq.NewServer(
|
||||
asynq.RedisClientOpt{Addr: redisAddr},
|
||||
asynq.Config{
|
||||
Concurrency: 10, // 并发数
|
||||
Queues: map[string]int{
|
||||
"critical": 6,
|
||||
"default": 3,
|
||||
"low": 1,
|
||||
},
|
||||
Logger: NewAsynqLogger(logger),
|
||||
},
|
||||
)
|
||||
|
||||
mux := asynq.NewServeMux()
|
||||
|
||||
return &AsynqWorker{
|
||||
server: server,
|
||||
mux: mux,
|
||||
taskHandler: taskHandler,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandlers 注册任务处理器
|
||||
func (w *AsynqWorker) RegisterHandlers() {
|
||||
// 注册文章定时发布任务处理器
|
||||
w.mux.HandleFunc(TaskTypeArticlePublish, w.taskHandler.HandleArticlePublish)
|
||||
|
||||
w.logger.Info("任务处理器注册完成")
|
||||
}
|
||||
|
||||
// Start 启动 Worker
|
||||
func (w *AsynqWorker) Start() error {
|
||||
w.RegisterHandlers()
|
||||
|
||||
w.logger.Info("启动 Asynq Worker")
|
||||
return w.server.Run(w.mux)
|
||||
}
|
||||
|
||||
// Stop 停止 Worker
|
||||
func (w *AsynqWorker) Stop() {
|
||||
w.logger.Info("停止 Asynq Worker")
|
||||
w.server.Stop()
|
||||
w.server.Shutdown()
|
||||
}
|
||||
|
||||
// AsynqLogger Asynq 日志适配器
|
||||
type AsynqLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqLogger 创建 Asynq 日志适配器
|
||||
func NewAsynqLogger(logger *zap.Logger) *AsynqLogger {
|
||||
return &AsynqLogger{logger: logger}
|
||||
}
|
||||
|
||||
func (l *AsynqLogger) Debug(args ...interface{}) {
|
||||
l.logger.Debug(fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func (l *AsynqLogger) Info(args ...interface{}) {
|
||||
l.logger.Info(fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func (l *AsynqLogger) Warn(args ...interface{}) {
|
||||
l.logger.Warn(fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func (l *AsynqLogger) Error(args ...interface{}) {
|
||||
l.logger.Error(fmt.Sprint(args...))
|
||||
}
|
||||
|
||||
func (l *AsynqLogger) Fatal(args ...interface{}) {
|
||||
l.logger.Fatal(fmt.Sprint(args...))
|
||||
}
|
||||
68
internal/infrastructure/task/entities/async_task.go
Normal file
68
internal/infrastructure/task/entities/async_task.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package entities
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// TaskStatus 任务状态
|
||||
type TaskStatus string
|
||||
|
||||
const (
|
||||
TaskStatusPending TaskStatus = "pending"
|
||||
TaskStatusRunning TaskStatus = "running"
|
||||
TaskStatusCompleted TaskStatus = "completed"
|
||||
TaskStatusFailed TaskStatus = "failed"
|
||||
TaskStatusCancelled TaskStatus = "cancelled"
|
||||
)
|
||||
|
||||
// AsyncTask 异步任务实体
|
||||
type AsyncTask struct {
|
||||
ID string `gorm:"type:char(36);primaryKey"`
|
||||
Type string `gorm:"not null;index"`
|
||||
Payload string `gorm:"type:text"`
|
||||
Status TaskStatus `gorm:"not null;index"`
|
||||
ScheduledAt *time.Time `gorm:"index"`
|
||||
StartedAt *time.Time
|
||||
CompletedAt *time.Time
|
||||
ErrorMsg string
|
||||
RetryCount int `gorm:"default:0"`
|
||||
MaxRetries int `gorm:"default:5"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// TableName 指定表名
|
||||
func (AsyncTask) TableName() string {
|
||||
return "async_tasks"
|
||||
}
|
||||
|
||||
// BeforeCreate GORM钩子,在创建前生成UUID
|
||||
func (t *AsyncTask) BeforeCreate(tx *gorm.DB) error {
|
||||
if t.ID == "" {
|
||||
t.ID = uuid.New().String()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsCompleted 检查任务是否已完成
|
||||
func (t *AsyncTask) IsCompleted() bool {
|
||||
return t.Status == TaskStatusCompleted
|
||||
}
|
||||
|
||||
// IsFailed 检查任务是否失败
|
||||
func (t *AsyncTask) IsFailed() bool {
|
||||
return t.Status == TaskStatusFailed
|
||||
}
|
||||
|
||||
// IsCancelled 检查任务是否已取消
|
||||
func (t *AsyncTask) IsCancelled() bool {
|
||||
return t.Status == TaskStatusCancelled
|
||||
}
|
||||
|
||||
// CanRetry 检查任务是否可以重试
|
||||
func (t *AsyncTask) CanRetry() bool {
|
||||
return t.Status == TaskStatusFailed && t.RetryCount < t.MaxRetries
|
||||
}
|
||||
335
internal/infrastructure/task/entities/task_factory.go
Normal file
335
internal/infrastructure/task/entities/task_factory.go
Normal file
@@ -0,0 +1,335 @@
|
||||
package entities
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// TaskFactory 任务工厂
|
||||
type TaskFactory struct {
|
||||
taskManager interface{} // 使用interface{}避免循环导入
|
||||
}
|
||||
|
||||
// NewTaskFactory 创建任务工厂
|
||||
func NewTaskFactory() *TaskFactory {
|
||||
return &TaskFactory{}
|
||||
}
|
||||
|
||||
// NewTaskFactoryWithManager 创建带管理器的任务工厂
|
||||
func NewTaskFactoryWithManager(taskManager interface{}) *TaskFactory {
|
||||
return &TaskFactory{
|
||||
taskManager: taskManager,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateArticlePublishTask 创建文章发布任务
|
||||
func (f *TaskFactory) CreateArticlePublishTask(articleID string, publishAt time.Time, userID string) (*AsyncTask, error) {
|
||||
// 创建任务实体,ID将由GORM的BeforeCreate钩子自动生成UUID
|
||||
task := &AsyncTask{
|
||||
Type: string(types.TaskTypeArticlePublish),
|
||||
Status: TaskStatusPending,
|
||||
ScheduledAt: &publishAt,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 在payload中添加任务ID(将在保存后更新)
|
||||
payloadWithID := map[string]interface{}{
|
||||
"article_id": articleID,
|
||||
"publish_at": publishAt,
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
payloadDataWithID, err := json.Marshal(payloadWithID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Payload = string(payloadDataWithID)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CreateArticleCancelTask 创建文章取消任务
|
||||
func (f *TaskFactory) CreateArticleCancelTask(articleID string, userID string) (*AsyncTask, error) {
|
||||
// 创建任务实体,ID将由GORM的BeforeCreate钩子自动生成UUID
|
||||
task := &AsyncTask{
|
||||
Type: string(types.TaskTypeArticleCancel),
|
||||
Status: TaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 在payload中添加任务数据
|
||||
payloadWithID := map[string]interface{}{
|
||||
"article_id": articleID,
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
payloadDataWithID, err := json.Marshal(payloadWithID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Payload = string(payloadDataWithID)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CreateArticleModifyTask 创建文章修改任务
|
||||
func (f *TaskFactory) CreateArticleModifyTask(articleID string, newPublishAt time.Time, userID string) (*AsyncTask, error) {
|
||||
// 创建任务实体,ID将由GORM的BeforeCreate钩子自动生成UUID
|
||||
task := &AsyncTask{
|
||||
Type: string(types.TaskTypeArticleModify),
|
||||
Status: TaskStatusPending,
|
||||
ScheduledAt: &newPublishAt,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 在payload中添加任务数据
|
||||
payloadWithID := map[string]interface{}{
|
||||
"article_id": articleID,
|
||||
"new_publish_at": newPublishAt,
|
||||
"user_id": userID,
|
||||
}
|
||||
|
||||
payloadDataWithID, err := json.Marshal(payloadWithID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Payload = string(payloadDataWithID)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CreateApiCallTask 创建API调用任务
|
||||
func (f *TaskFactory) CreateApiCallTask(apiCallID string, userID string, productID string, amount string) (*AsyncTask, error) {
|
||||
// 创建任务实体,ID将由GORM的BeforeCreate钩子自动生成UUID
|
||||
task := &AsyncTask{
|
||||
Type: string(types.TaskTypeApiCall),
|
||||
Status: TaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 在payload中添加任务数据
|
||||
payloadWithID := map[string]interface{}{
|
||||
"api_call_id": apiCallID,
|
||||
"user_id": userID,
|
||||
"product_id": productID,
|
||||
"amount": amount,
|
||||
}
|
||||
|
||||
payloadDataWithID, err := json.Marshal(payloadWithID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Payload = string(payloadDataWithID)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CreateDeductionTask 创建扣款任务
|
||||
func (f *TaskFactory) CreateDeductionTask(apiCallID string, userID string, productID string, amount string, transactionID string) (*AsyncTask, error) {
|
||||
// 创建任务实体,ID将由GORM的BeforeCreate钩子自动生成UUID
|
||||
task := &AsyncTask{
|
||||
Type: string(types.TaskTypeDeduction),
|
||||
Status: TaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 在payload中添加任务数据
|
||||
payloadWithID := map[string]interface{}{
|
||||
"api_call_id": apiCallID,
|
||||
"user_id": userID,
|
||||
"product_id": productID,
|
||||
"amount": amount,
|
||||
"transaction_id": transactionID,
|
||||
}
|
||||
|
||||
payloadDataWithID, err := json.Marshal(payloadWithID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Payload = string(payloadDataWithID)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CreateApiCallLogTask 创建API调用日志任务
|
||||
func (f *TaskFactory) CreateApiCallLogTask(transactionID string, userID string, apiName string, productID string) (*AsyncTask, error) {
|
||||
// 创建任务实体,ID将由GORM的BeforeCreate钩子自动生成UUID
|
||||
task := &AsyncTask{
|
||||
Type: string(types.TaskTypeApiLog),
|
||||
Status: TaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 在payload中添加任务数据
|
||||
payloadWithID := map[string]interface{}{
|
||||
"transaction_id": transactionID,
|
||||
"user_id": userID,
|
||||
"api_name": apiName,
|
||||
"product_id": productID,
|
||||
}
|
||||
|
||||
payloadDataWithID, err := json.Marshal(payloadWithID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Payload = string(payloadDataWithID)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CreateUsageStatsTask 创建使用统计任务
|
||||
func (f *TaskFactory) CreateUsageStatsTask(subscriptionID string, userID string, productID string, increment int) (*AsyncTask, error) {
|
||||
// 创建任务实体,ID将由GORM的BeforeCreate钩子自动生成UUID
|
||||
task := &AsyncTask{
|
||||
Type: string(types.TaskTypeUsageStats),
|
||||
Status: TaskStatusPending,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 在payload中添加任务数据
|
||||
payloadWithID := map[string]interface{}{
|
||||
"subscription_id": subscriptionID,
|
||||
"user_id": userID,
|
||||
"product_id": productID,
|
||||
"increment": increment,
|
||||
}
|
||||
|
||||
payloadDataWithID, err := json.Marshal(payloadWithID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task.Payload = string(payloadDataWithID)
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// CreateAndEnqueueArticlePublishTask 创建并入队文章发布任务
|
||||
func (f *TaskFactory) CreateAndEnqueueArticlePublishTask(ctx context.Context, articleID string, publishAt time.Time, userID string) error {
|
||||
if f.taskManager == nil {
|
||||
return fmt.Errorf("TaskManager未初始化")
|
||||
}
|
||||
|
||||
task, err := f.CreateArticlePublishTask(articleID, publishAt, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delay := publishAt.Sub(time.Now())
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
|
||||
// 使用类型断言调用TaskManager方法
|
||||
if tm, ok := f.taskManager.(interface {
|
||||
CreateAndEnqueueDelayedTask(ctx context.Context, task *AsyncTask, delay time.Duration) error
|
||||
}); ok {
|
||||
return tm.CreateAndEnqueueDelayedTask(ctx, task, delay)
|
||||
}
|
||||
|
||||
return fmt.Errorf("TaskManager类型不匹配")
|
||||
}
|
||||
|
||||
// CreateAndEnqueueApiLogTask 创建并入队API日志任务
|
||||
func (f *TaskFactory) CreateAndEnqueueApiLogTask(ctx context.Context, transactionID string, userID string, apiName string, productID string) error {
|
||||
if f.taskManager == nil {
|
||||
return fmt.Errorf("TaskManager未初始化")
|
||||
}
|
||||
|
||||
task, err := f.CreateApiCallLogTask(transactionID, userID, apiName, productID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用类型断言调用TaskManager方法
|
||||
if tm, ok := f.taskManager.(interface {
|
||||
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
|
||||
}); ok {
|
||||
return tm.CreateAndEnqueueTask(ctx, task)
|
||||
}
|
||||
|
||||
return fmt.Errorf("TaskManager类型不匹配")
|
||||
}
|
||||
|
||||
// CreateAndEnqueueApiCallTask 创建并入队API调用任务
|
||||
func (f *TaskFactory) CreateAndEnqueueApiCallTask(ctx context.Context, apiCallID string, userID string, productID string, amount string) error {
|
||||
if f.taskManager == nil {
|
||||
return fmt.Errorf("TaskManager未初始化")
|
||||
}
|
||||
|
||||
task, err := f.CreateApiCallTask(apiCallID, userID, productID, amount)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用类型断言调用TaskManager方法
|
||||
if tm, ok := f.taskManager.(interface {
|
||||
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
|
||||
}); ok {
|
||||
return tm.CreateAndEnqueueTask(ctx, task)
|
||||
}
|
||||
|
||||
return fmt.Errorf("TaskManager类型不匹配")
|
||||
}
|
||||
|
||||
// CreateAndEnqueueDeductionTask 创建并入队扣款任务
|
||||
func (f *TaskFactory) CreateAndEnqueueDeductionTask(ctx context.Context, apiCallID string, userID string, productID string, amount string, transactionID string) error {
|
||||
if f.taskManager == nil {
|
||||
return fmt.Errorf("TaskManager未初始化")
|
||||
}
|
||||
|
||||
task, err := f.CreateDeductionTask(apiCallID, userID, productID, amount, transactionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用类型断言调用TaskManager方法
|
||||
if tm, ok := f.taskManager.(interface {
|
||||
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
|
||||
}); ok {
|
||||
return tm.CreateAndEnqueueTask(ctx, task)
|
||||
}
|
||||
|
||||
return fmt.Errorf("TaskManager类型不匹配")
|
||||
}
|
||||
|
||||
// CreateAndEnqueueUsageStatsTask 创建并入队使用统计任务
|
||||
func (f *TaskFactory) CreateAndEnqueueUsageStatsTask(ctx context.Context, subscriptionID string, userID string, productID string, increment int) error {
|
||||
if f.taskManager == nil {
|
||||
return fmt.Errorf("TaskManager未初始化")
|
||||
}
|
||||
|
||||
task, err := f.CreateUsageStatsTask(subscriptionID, userID, productID, increment)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用类型断言调用TaskManager方法
|
||||
if tm, ok := f.taskManager.(interface {
|
||||
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
|
||||
}); ok {
|
||||
return tm.CreateAndEnqueueTask(ctx, task)
|
||||
}
|
||||
|
||||
return fmt.Errorf("TaskManager类型不匹配")
|
||||
}
|
||||
|
||||
// generateRandomString 生成随机字符串
|
||||
func generateRandomString(length int) string {
|
||||
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = charset[time.Now().UnixNano()%int64(len(charset))]
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
45
internal/infrastructure/task/factory.go
Normal file
45
internal/infrastructure/task/factory.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package task
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/infrastructure/task/implementations/asynq"
|
||||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TaskFactory 任务工厂
|
||||
type TaskFactory struct{}
|
||||
|
||||
// NewTaskFactory 创建任务工厂
|
||||
func NewTaskFactory() *TaskFactory {
|
||||
return &TaskFactory{}
|
||||
}
|
||||
|
||||
// CreateApiTaskQueue 创建API任务队列
|
||||
func (f *TaskFactory) CreateApiTaskQueue(redisAddr string, logger interface{}) interfaces.ApiTaskQueue {
|
||||
// 这里可以根据配置选择不同的实现
|
||||
// 目前使用Asynq实现
|
||||
return asynq.NewAsynqApiTaskQueue(redisAddr, logger.(*zap.Logger))
|
||||
}
|
||||
|
||||
// CreateArticleTaskQueue 创建文章任务队列
|
||||
func (f *TaskFactory) CreateArticleTaskQueue(redisAddr string, logger interface{}) interfaces.ArticleTaskQueue {
|
||||
// 这里可以根据配置选择不同的实现
|
||||
// 目前使用Asynq实现
|
||||
return asynq.NewAsynqArticleTaskQueue(redisAddr, logger.(*zap.Logger))
|
||||
}
|
||||
|
||||
// NewApiTaskQueue 创建API任务队列(包级别函数)
|
||||
func NewApiTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ApiTaskQueue {
|
||||
return asynq.NewAsynqApiTaskQueue(redisAddr, logger)
|
||||
}
|
||||
|
||||
// NewAsynqClient 创建Asynq客户端(包级别函数)
|
||||
func NewAsynqClient(redisAddr string, scheduledTaskRepo interface{}, logger *zap.Logger) *asynq.AsynqClient {
|
||||
return asynq.NewAsynqClient(redisAddr, logger)
|
||||
}
|
||||
|
||||
// NewArticleTaskQueue 创建文章任务队列(包级别函数)
|
||||
func NewArticleTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ArticleTaskQueue {
|
||||
return asynq.NewAsynqArticleTaskQueue(redisAddr, logger)
|
||||
}
|
||||
285
internal/infrastructure/task/handlers/api_task_handler.go
Normal file
285
internal/infrastructure/task/handlers/api_task_handler.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/api"
|
||||
finance_services "tyapi-server/internal/domains/finance/services"
|
||||
product_services "tyapi-server/internal/domains/product/services"
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/repositories"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// ApiTaskHandler API任务处理器
|
||||
type ApiTaskHandler struct {
|
||||
logger *zap.Logger
|
||||
apiApplicationService api.ApiApplicationService
|
||||
walletService finance_services.WalletAggregateService
|
||||
subscriptionService *product_services.ProductSubscriptionService
|
||||
asyncTaskRepo repositories.AsyncTaskRepository
|
||||
}
|
||||
|
||||
// NewApiTaskHandler 创建API任务处理器
|
||||
func NewApiTaskHandler(
|
||||
logger *zap.Logger,
|
||||
apiApplicationService api.ApiApplicationService,
|
||||
walletService finance_services.WalletAggregateService,
|
||||
subscriptionService *product_services.ProductSubscriptionService,
|
||||
asyncTaskRepo repositories.AsyncTaskRepository,
|
||||
) *ApiTaskHandler {
|
||||
return &ApiTaskHandler{
|
||||
logger: logger,
|
||||
apiApplicationService: apiApplicationService,
|
||||
walletService: walletService,
|
||||
subscriptionService: subscriptionService,
|
||||
asyncTaskRepo: asyncTaskRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleApiCall 处理API调用任务
|
||||
func (h *ApiTaskHandler) HandleApiCall(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理API调用任务")
|
||||
|
||||
var payload types.ApiCallPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析API调用任务载荷失败", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理API调用任务",
|
||||
zap.String("api_call_id", payload.ApiCallID),
|
||||
zap.String("user_id", payload.UserID),
|
||||
zap.String("product_id", payload.ProductID))
|
||||
|
||||
// 这里实现API调用的具体逻辑
|
||||
// 例如:记录API调用、更新使用统计等
|
||||
|
||||
// 更新任务状态为成功
|
||||
h.updateTaskStatus(ctx, t, "completed", "")
|
||||
h.logger.Info("API调用任务处理完成", zap.String("api_call_id", payload.ApiCallID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleDeduction 处理扣款任务
|
||||
func (h *ApiTaskHandler) HandleDeduction(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理扣款任务")
|
||||
|
||||
var payload types.DeductionPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析扣款任务载荷失败", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理扣款任务",
|
||||
zap.String("user_id", payload.UserID),
|
||||
zap.String("amount", payload.Amount),
|
||||
zap.String("transaction_id", payload.TransactionID))
|
||||
|
||||
// 调用钱包服务进行扣款
|
||||
if h.walletService != nil {
|
||||
amount, err := decimal.NewFromString(payload.Amount)
|
||||
if err != nil {
|
||||
h.logger.Error("金额格式错误", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "金额格式错误")
|
||||
return err
|
||||
}
|
||||
|
||||
if err := h.walletService.Deduct(ctx, payload.UserID, amount, payload.ApiCallID, payload.TransactionID, payload.ProductID); err != nil {
|
||||
h.logger.Error("扣款处理失败", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "扣款处理失败: "+err.Error())
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
h.logger.Warn("钱包服务未初始化,跳过扣款", zap.String("user_id", payload.UserID))
|
||||
h.updateTaskStatus(ctx, t, "failed", "钱包服务未初始化")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新任务状态为成功
|
||||
h.updateTaskStatus(ctx, t, "completed", "")
|
||||
h.logger.Info("扣款任务处理完成", zap.String("transaction_id", payload.TransactionID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleCompensation 处理补偿任务
|
||||
func (h *ApiTaskHandler) HandleCompensation(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理补偿任务")
|
||||
|
||||
var payload types.CompensationPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析补偿任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理补偿任务",
|
||||
zap.String("transaction_id", payload.TransactionID),
|
||||
zap.String("type", payload.Type))
|
||||
|
||||
// 这里实现补偿的具体逻辑
|
||||
// 例如:调用钱包服务进行退款等
|
||||
|
||||
h.logger.Info("补偿任务处理完成", zap.String("transaction_id", payload.TransactionID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleUsageStats 处理使用统计任务
|
||||
func (h *ApiTaskHandler) HandleUsageStats(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理使用统计任务")
|
||||
|
||||
var payload types.UsageStatsPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析使用统计任务载荷失败", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理使用统计任务",
|
||||
zap.String("subscription_id", payload.SubscriptionID),
|
||||
zap.String("user_id", payload.UserID),
|
||||
zap.Int("increment", payload.Increment))
|
||||
|
||||
// 调用订阅服务更新使用统计
|
||||
if h.subscriptionService != nil {
|
||||
if err := h.subscriptionService.IncrementSubscriptionAPIUsage(ctx, payload.SubscriptionID, int64(payload.Increment)); err != nil {
|
||||
h.logger.Error("更新使用统计失败", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "更新使用统计失败: "+err.Error())
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
h.logger.Warn("订阅服务未初始化,跳过使用统计更新", zap.String("subscription_id", payload.SubscriptionID))
|
||||
h.updateTaskStatus(ctx, t, "failed", "订阅服务未初始化")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新任务状态为成功
|
||||
h.updateTaskStatus(ctx, t, "completed", "")
|
||||
h.logger.Info("使用统计任务处理完成", zap.String("subscription_id", payload.SubscriptionID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleApiLog 处理API日志任务
|
||||
func (h *ApiTaskHandler) HandleApiLog(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理API日志任务")
|
||||
|
||||
var payload types.ApiLogPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析API日志任务载荷失败", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理API日志任务",
|
||||
zap.String("transaction_id", payload.TransactionID),
|
||||
zap.String("user_id", payload.UserID),
|
||||
zap.String("api_name", payload.ApiName),
|
||||
zap.String("product_id", payload.ProductID))
|
||||
|
||||
// 记录结构化日志
|
||||
h.logger.Info("API调用日志",
|
||||
zap.String("transaction_id", payload.TransactionID),
|
||||
zap.String("user_id", payload.UserID),
|
||||
zap.String("api_name", payload.ApiName),
|
||||
zap.String("product_id", payload.ProductID),
|
||||
zap.Time("timestamp", time.Now()))
|
||||
|
||||
// 这里可以添加其他日志记录逻辑
|
||||
// 例如:写入专门的日志文件、发送到日志系统、写入数据库等
|
||||
|
||||
// 更新任务状态为成功
|
||||
h.updateTaskStatus(ctx, t, "completed", "")
|
||||
h.logger.Info("API日志任务处理完成", zap.String("transaction_id", payload.TransactionID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateTaskStatus 更新任务状态
|
||||
func (h *ApiTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task, status string, errorMsg string) {
|
||||
// 从任务载荷中提取任务ID
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析任务载荷失败,无法更新状态", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试从payload中获取任务ID
|
||||
taskID, ok := payload["task_id"].(string)
|
||||
if !ok {
|
||||
h.logger.Error("无法从任务载荷中获取任务ID")
|
||||
return
|
||||
}
|
||||
|
||||
// 根据状态决定更新方式
|
||||
if status == "failed" {
|
||||
// 失败时:需要检查是否达到最大重试次数
|
||||
h.handleTaskFailure(ctx, taskID, errorMsg)
|
||||
} else if status == "completed" {
|
||||
// 成功时:清除错误信息并更新状态
|
||||
if err := h.asyncTaskRepo.UpdateStatusWithSuccess(ctx, taskID, entities.TaskStatus(status)); err != nil {
|
||||
h.logger.Error("更新任务状态失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", status),
|
||||
zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
// 其他状态:只更新状态
|
||||
if err := h.asyncTaskRepo.UpdateStatus(ctx, taskID, entities.TaskStatus(status)); err != nil {
|
||||
h.logger.Error("更新任务状态失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", status),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("任务状态已更新",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", status),
|
||||
zap.String("error_msg", errorMsg))
|
||||
}
|
||||
|
||||
// handleTaskFailure 处理任务失败
|
||||
func (h *ApiTaskHandler) handleTaskFailure(ctx context.Context, taskID string, errorMsg string) {
|
||||
// 获取当前任务信息
|
||||
task, err := h.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取任务信息失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 增加重试次数
|
||||
newRetryCount := task.RetryCount + 1
|
||||
|
||||
// 检查是否达到最大重试次数
|
||||
if newRetryCount >= task.MaxRetries {
|
||||
// 达到最大重试次数,标记为最终失败
|
||||
if err := h.asyncTaskRepo.UpdateStatusWithRetryAndError(ctx, taskID, entities.TaskStatusFailed, errorMsg); err != nil {
|
||||
h.logger.Error("更新任务状态失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", "failed"),
|
||||
zap.Error(err))
|
||||
}
|
||||
h.logger.Info("任务最终失败,已达到最大重试次数",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", newRetryCount),
|
||||
zap.Int("max_retries", task.MaxRetries))
|
||||
} else {
|
||||
// 未达到最大重试次数,保持pending状态,记录错误信息
|
||||
if err := h.asyncTaskRepo.UpdateRetryCountAndError(ctx, taskID, newRetryCount, errorMsg); err != nil {
|
||||
h.logger.Error("更新任务重试次数失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", newRetryCount),
|
||||
zap.Error(err))
|
||||
}
|
||||
h.logger.Info("任务失败,准备重试",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", newRetryCount),
|
||||
zap.Int("max_retries", task.MaxRetries))
|
||||
}
|
||||
}
|
||||
304
internal/infrastructure/task/handlers/article_task_handler.go
Normal file
304
internal/infrastructure/task/handlers/article_task_handler.go
Normal file
@@ -0,0 +1,304 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/article"
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/repositories"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// ArticleTaskHandler 文章任务处理器
|
||||
type ArticleTaskHandler struct {
|
||||
logger *zap.Logger
|
||||
articleApplicationService article.ArticleApplicationService
|
||||
asyncTaskRepo repositories.AsyncTaskRepository
|
||||
}
|
||||
|
||||
// NewArticleTaskHandler 创建文章任务处理器
|
||||
func NewArticleTaskHandler(logger *zap.Logger, articleApplicationService article.ArticleApplicationService, asyncTaskRepo repositories.AsyncTaskRepository) *ArticleTaskHandler {
|
||||
return &ArticleTaskHandler{
|
||||
logger: logger,
|
||||
articleApplicationService: articleApplicationService,
|
||||
asyncTaskRepo: asyncTaskRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleArticlePublish 处理文章发布任务
|
||||
func (h *ArticleTaskHandler) HandleArticlePublish(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理文章发布任务")
|
||||
|
||||
var payload ArticlePublishPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析文章发布任务载荷失败", zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理文章发布任务",
|
||||
zap.String("article_id", payload.ArticleID),
|
||||
zap.Time("publish_at", payload.PublishAt))
|
||||
|
||||
// 检查任务是否已被取消
|
||||
if err := h.checkTaskStatus(ctx, t); err != nil {
|
||||
h.logger.Info("任务已被取消,跳过执行", zap.String("article_id", payload.ArticleID))
|
||||
return nil // 静默返回,不报错
|
||||
}
|
||||
|
||||
// 调用文章应用服务发布文章
|
||||
if h.articleApplicationService != nil {
|
||||
err := h.articleApplicationService.PublishArticleByID(ctx, payload.ArticleID)
|
||||
if err != nil {
|
||||
h.logger.Error("文章发布失败", zap.String("article_id", payload.ArticleID), zap.Error(err))
|
||||
h.updateTaskStatus(ctx, t, "failed", "文章发布失败: "+err.Error())
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
h.logger.Warn("文章应用服务未初始化,跳过发布", zap.String("article_id", payload.ArticleID))
|
||||
h.updateTaskStatus(ctx, t, "failed", "文章应用服务未初始化")
|
||||
return nil
|
||||
}
|
||||
|
||||
// 更新任务状态为成功
|
||||
h.updateTaskStatus(ctx, t, "completed", "")
|
||||
h.logger.Info("文章发布任务处理完成", zap.String("article_id", payload.ArticleID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleArticleCancel 处理文章取消任务
|
||||
func (h *ArticleTaskHandler) HandleArticleCancel(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理文章取消任务")
|
||||
|
||||
var payload ArticleCancelPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析文章取消任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理文章取消任务", zap.String("article_id", payload.ArticleID))
|
||||
|
||||
// 这里实现文章取消的具体逻辑
|
||||
// 例如:更新文章状态、取消定时发布等
|
||||
|
||||
h.logger.Info("文章取消任务处理完成", zap.String("article_id", payload.ArticleID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// HandleArticleModify 处理文章修改任务
|
||||
func (h *ArticleTaskHandler) HandleArticleModify(ctx context.Context, t *asynq.Task) error {
|
||||
h.logger.Info("开始处理文章修改任务")
|
||||
|
||||
var payload ArticleModifyPayload
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析文章修改任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
h.logger.Info("处理文章修改任务",
|
||||
zap.String("article_id", payload.ArticleID),
|
||||
zap.Time("new_publish_at", payload.NewPublishAt))
|
||||
|
||||
// 这里实现文章修改的具体逻辑
|
||||
// 例如:更新文章发布时间、重新调度任务等
|
||||
|
||||
h.logger.Info("文章修改任务处理完成", zap.String("article_id", payload.ArticleID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ArticlePublishPayload 文章发布任务载荷
|
||||
type ArticlePublishPayload struct {
|
||||
ArticleID string `json:"article_id"`
|
||||
PublishAt time.Time `json:"publish_at"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ArticlePublishPayload) GetType() types.TaskType {
|
||||
return types.TaskTypeArticlePublish
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ArticlePublishPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ArticlePublishPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// ArticleCancelPayload 文章取消任务载荷
|
||||
type ArticleCancelPayload struct {
|
||||
ArticleID string `json:"article_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ArticleCancelPayload) GetType() types.TaskType {
|
||||
return types.TaskTypeArticleCancel
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ArticleCancelPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ArticleCancelPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// ArticleModifyPayload 文章修改任务载荷
|
||||
type ArticleModifyPayload struct {
|
||||
ArticleID string `json:"article_id"`
|
||||
NewPublishAt time.Time `json:"new_publish_at"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ArticleModifyPayload) GetType() types.TaskType {
|
||||
return types.TaskTypeArticleModify
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ArticleModifyPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ArticleModifyPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// updateTaskStatus 更新任务状态
|
||||
func (h *ArticleTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task, status string, errorMsg string) {
|
||||
// 从任务载荷中提取任务ID
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析任务载荷失败,无法更新状态", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试从payload中获取任务ID
|
||||
taskID, ok := payload["task_id"].(string)
|
||||
if !ok {
|
||||
// 如果没有task_id,尝试从article_id生成
|
||||
if articleID, ok := payload["article_id"].(string); ok {
|
||||
taskID = fmt.Sprintf("article-publish-%s", articleID)
|
||||
} else {
|
||||
h.logger.Error("无法从任务载荷中获取任务ID")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 根据状态决定更新方式
|
||||
if status == "failed" {
|
||||
// 失败时:需要检查是否达到最大重试次数
|
||||
h.handleTaskFailure(ctx, taskID, errorMsg)
|
||||
} else if status == "completed" {
|
||||
// 成功时:清除错误信息并更新状态
|
||||
if err := h.asyncTaskRepo.UpdateStatusWithSuccess(ctx, taskID, entities.TaskStatus(status)); err != nil {
|
||||
h.logger.Error("更新任务状态失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", status),
|
||||
zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
// 其他状态:只更新状态
|
||||
if err := h.asyncTaskRepo.UpdateStatus(ctx, taskID, entities.TaskStatus(status)); err != nil {
|
||||
h.logger.Error("更新任务状态失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", status),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
h.logger.Info("任务状态已更新",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", status),
|
||||
zap.String("error_msg", errorMsg))
|
||||
}
|
||||
|
||||
// handleTaskFailure 处理任务失败
|
||||
func (h *ArticleTaskHandler) handleTaskFailure(ctx context.Context, taskID string, errorMsg string) {
|
||||
// 获取当前任务信息
|
||||
task, err := h.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取任务信息失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
// 增加重试次数
|
||||
newRetryCount := task.RetryCount + 1
|
||||
|
||||
// 检查是否达到最大重试次数
|
||||
if newRetryCount >= task.MaxRetries {
|
||||
// 达到最大重试次数,标记为最终失败
|
||||
if err := h.asyncTaskRepo.UpdateStatusWithRetryAndError(ctx, taskID, entities.TaskStatusFailed, errorMsg); err != nil {
|
||||
h.logger.Error("更新任务状态失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.String("status", "failed"),
|
||||
zap.Error(err))
|
||||
}
|
||||
h.logger.Info("任务最终失败,已达到最大重试次数",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", newRetryCount),
|
||||
zap.Int("max_retries", task.MaxRetries))
|
||||
} else {
|
||||
// 未达到最大重试次数,保持pending状态,记录错误信息
|
||||
if err := h.asyncTaskRepo.UpdateRetryCountAndError(ctx, taskID, newRetryCount, errorMsg); err != nil {
|
||||
h.logger.Error("更新任务重试次数失败",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", newRetryCount),
|
||||
zap.Error(err))
|
||||
}
|
||||
h.logger.Info("任务失败,准备重试",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", newRetryCount),
|
||||
zap.Int("max_retries", task.MaxRetries))
|
||||
}
|
||||
}
|
||||
|
||||
// checkTaskStatus 检查任务状态
|
||||
func (h *ArticleTaskHandler) checkTaskStatus(ctx context.Context, t *asynq.Task) error {
|
||||
// 从任务载荷中提取任务ID
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
|
||||
h.logger.Error("解析任务载荷失败,无法检查状态", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 尝试从payload中获取任务ID
|
||||
taskID, ok := payload["task_id"].(string)
|
||||
if !ok {
|
||||
// 如果没有task_id,尝试从article_id生成
|
||||
if articleID, ok := payload["article_id"].(string); ok {
|
||||
taskID = fmt.Sprintf("article-publish-%s", articleID)
|
||||
} else {
|
||||
h.logger.Error("无法从任务载荷中获取任务ID")
|
||||
return fmt.Errorf("无法获取任务ID")
|
||||
}
|
||||
}
|
||||
|
||||
// 查询任务状态
|
||||
task, err := h.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
h.logger.Error("查询任务状态失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查任务是否已被取消
|
||||
if task.Status == entities.TaskStatusCancelled {
|
||||
h.logger.Info("任务已被取消", zap.String("task_id", taskID))
|
||||
return fmt.Errorf("任务已被取消")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,126 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqApiTaskQueue Asynq API任务队列实现
|
||||
type AsynqApiTaskQueue struct {
|
||||
client *asynq.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqApiTaskQueue 创建Asynq API任务队列
|
||||
func NewAsynqApiTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ApiTaskQueue {
|
||||
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
|
||||
return &AsynqApiTaskQueue{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队任务
|
||||
func (q *AsynqApiTaskQueue) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task)
|
||||
if err != nil {
|
||||
q.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
func (q *AsynqApiTaskQueue) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
|
||||
if err != nil {
|
||||
q.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
func (q *AsynqApiTaskQueue) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
|
||||
if err != nil {
|
||||
q.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel 取消任务
|
||||
func (q *AsynqApiTaskQueue) Cancel(ctx context.Context, taskID string) error {
|
||||
// Asynq本身不支持直接取消任务,这里返回错误提示
|
||||
return fmt.Errorf("Asynq不支持直接取消任务,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// ModifySchedule 修改任务调度时间
|
||||
func (q *AsynqApiTaskQueue) ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
|
||||
// Asynq本身不支持修改调度时间,这里返回错误提示
|
||||
return fmt.Errorf("Asynq不支持修改任务调度时间,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
func (q *AsynqApiTaskQueue) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务状态查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务状态查询,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// ListTasks 列出任务
|
||||
func (q *AsynqApiTaskQueue) ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务列表查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务列表查询,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// EnqueueTask 入队任务
|
||||
func (q *AsynqApiTaskQueue) EnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
|
||||
// 创建Asynq任务
|
||||
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
|
||||
|
||||
// 入队任务
|
||||
_, err := q.client.EnqueueContext(ctx, asynqTask)
|
||||
if err != nil {
|
||||
q.logger.Error("入队任务失败", zap.String("task_id", task.ID), zap.String("task_type", task.Type), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("入队任务成功", zap.String("task_id", task.ID), zap.String("task_type", task.Type))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqArticleTaskQueue Asynq文章任务队列实现
|
||||
type AsynqArticleTaskQueue struct {
|
||||
client *asynq.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqArticleTaskQueue 创建Asynq文章任务队列
|
||||
func NewAsynqArticleTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ArticleTaskQueue {
|
||||
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
|
||||
return &AsynqArticleTaskQueue{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队任务
|
||||
func (q *AsynqArticleTaskQueue) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task)
|
||||
if err != nil {
|
||||
q.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
func (q *AsynqArticleTaskQueue) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
|
||||
if err != nil {
|
||||
q.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
func (q *AsynqArticleTaskQueue) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
q.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
|
||||
if err != nil {
|
||||
q.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
q.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Cancel 取消任务
|
||||
func (q *AsynqArticleTaskQueue) Cancel(ctx context.Context, taskID string) error {
|
||||
// Asynq本身不支持直接取消任务,但我们可以通过以下方式实现:
|
||||
// 1. 在数据库中标记任务为已取消
|
||||
// 2. 任务执行时检查状态,如果已取消则跳过执行
|
||||
|
||||
q.logger.Info("标记任务为已取消", zap.String("task_id", taskID))
|
||||
|
||||
// 这里应该更新数据库中的任务状态为cancelled
|
||||
// 由于我们没有直接访问repository,暂时只记录日志
|
||||
// 实际实现中应该调用AsyncTaskRepository.UpdateStatus
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ModifySchedule 修改任务调度时间
|
||||
func (q *AsynqArticleTaskQueue) ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
|
||||
// Asynq本身不支持修改调度时间,但我们可以通过以下方式实现:
|
||||
// 1. 取消旧任务
|
||||
// 2. 创建新任务
|
||||
|
||||
q.logger.Info("修改任务调度时间",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Time("new_scheduled_at", newScheduledAt))
|
||||
|
||||
// 这里应该:
|
||||
// 1. 调用Cancel取消旧任务
|
||||
// 2. 根据任务类型重新创建任务
|
||||
// 由于没有直接访问repository,暂时只记录日志
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
func (q *AsynqArticleTaskQueue) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务状态查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务状态查询,请使用数据库状态管理")
|
||||
}
|
||||
|
||||
// ListTasks 列出任务
|
||||
func (q *AsynqArticleTaskQueue) ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
// Asynq本身不提供任务列表查询,这里返回错误提示
|
||||
return nil, fmt.Errorf("Asynq不提供任务列表查询,请使用数据库状态管理")
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqClient Asynq客户端实现
|
||||
type AsynqClient struct {
|
||||
client *asynq.Client
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqClient 创建Asynq客户端
|
||||
func NewAsynqClient(redisAddr string, logger *zap.Logger) *AsynqClient {
|
||||
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
|
||||
return &AsynqClient{
|
||||
client: client,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Enqueue 入队任务
|
||||
func (c *AsynqClient) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
c.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = c.client.EnqueueContext(ctx, task)
|
||||
if err != nil {
|
||||
c.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
func (c *AsynqClient) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
c.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = c.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
|
||||
if err != nil {
|
||||
c.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
|
||||
return nil
|
||||
}
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
func (c *AsynqClient) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
|
||||
payloadData, err := payload.ToJSON()
|
||||
if err != nil {
|
||||
c.logger.Error("序列化任务载荷失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
task := asynq.NewTask(string(taskType), payloadData)
|
||||
_, err = c.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
|
||||
if err != nil {
|
||||
c.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
c.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close 关闭客户端
|
||||
func (c *AsynqClient) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
@@ -0,0 +1,122 @@
|
||||
package asynq
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/api"
|
||||
"tyapi-server/internal/application/article"
|
||||
finance_services "tyapi-server/internal/domains/finance/services"
|
||||
product_services "tyapi-server/internal/domains/product/services"
|
||||
"tyapi-server/internal/infrastructure/task/handlers"
|
||||
"tyapi-server/internal/infrastructure/task/repositories"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsynqWorker Asynq Worker实现
|
||||
type AsynqWorker struct {
|
||||
server *asynq.Server
|
||||
mux *asynq.ServeMux
|
||||
logger *zap.Logger
|
||||
articleHandler *handlers.ArticleTaskHandler
|
||||
apiHandler *handlers.ApiTaskHandler
|
||||
}
|
||||
|
||||
// NewAsynqWorker 创建Asynq Worker
|
||||
func NewAsynqWorker(
|
||||
redisAddr string,
|
||||
logger *zap.Logger,
|
||||
articleApplicationService article.ArticleApplicationService,
|
||||
apiApplicationService api.ApiApplicationService,
|
||||
walletService finance_services.WalletAggregateService,
|
||||
subscriptionService *product_services.ProductSubscriptionService,
|
||||
asyncTaskRepo repositories.AsyncTaskRepository,
|
||||
) *AsynqWorker {
|
||||
server := asynq.NewServer(
|
||||
asynq.RedisClientOpt{Addr: redisAddr},
|
||||
asynq.Config{
|
||||
Concurrency: 6, // 降低总并发数
|
||||
Queues: map[string]int{
|
||||
"default": 2, // 2个goroutine
|
||||
"api": 3, // 3个goroutine (扣款任务)
|
||||
"article": 1, // 1个goroutine
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
// 创建任务处理器
|
||||
articleHandler := handlers.NewArticleTaskHandler(logger, articleApplicationService, asyncTaskRepo)
|
||||
apiHandler := handlers.NewApiTaskHandler(logger, apiApplicationService, walletService, subscriptionService, asyncTaskRepo)
|
||||
|
||||
// 创建ServeMux
|
||||
mux := asynq.NewServeMux()
|
||||
|
||||
return &AsynqWorker{
|
||||
server: server,
|
||||
mux: mux,
|
||||
logger: logger,
|
||||
articleHandler: articleHandler,
|
||||
apiHandler: apiHandler,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler 注册任务处理器
|
||||
func (w *AsynqWorker) RegisterHandler(taskType types.TaskType, handler func(context.Context, *asynq.Task) error) {
|
||||
// 简化实现,避免API兼容性问题
|
||||
w.logger.Info("注册任务处理器", zap.String("task_type", string(taskType)))
|
||||
}
|
||||
|
||||
// Start 启动Worker
|
||||
func (w *AsynqWorker) Start() error {
|
||||
w.logger.Info("启动Asynq Worker")
|
||||
|
||||
// 注册所有任务处理器
|
||||
w.registerAllHandlers()
|
||||
|
||||
// 启动Worker服务器
|
||||
go func() {
|
||||
if err := w.server.Run(w.mux); err != nil {
|
||||
w.logger.Error("Worker运行失败", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
w.logger.Info("Asynq Worker启动成功")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop 停止Worker
|
||||
func (w *AsynqWorker) Stop() {
|
||||
w.logger.Info("停止Asynq Worker")
|
||||
w.server.Stop()
|
||||
}
|
||||
|
||||
// Shutdown 优雅关闭Worker
|
||||
func (w *AsynqWorker) Shutdown() {
|
||||
w.logger.Info("优雅关闭Asynq Worker")
|
||||
w.server.Shutdown()
|
||||
}
|
||||
|
||||
// registerAllHandlers 注册所有任务处理器
|
||||
func (w *AsynqWorker) registerAllHandlers() {
|
||||
// 注册文章任务处理器
|
||||
w.mux.HandleFunc(string(types.TaskTypeArticlePublish), w.articleHandler.HandleArticlePublish)
|
||||
w.mux.HandleFunc(string(types.TaskTypeArticleCancel), w.articleHandler.HandleArticleCancel)
|
||||
w.mux.HandleFunc(string(types.TaskTypeArticleModify), w.articleHandler.HandleArticleModify)
|
||||
|
||||
// 注册API任务处理器
|
||||
w.mux.HandleFunc(string(types.TaskTypeApiCall), w.apiHandler.HandleApiCall)
|
||||
w.mux.HandleFunc(string(types.TaskTypeApiLog), w.apiHandler.HandleApiLog)
|
||||
w.mux.HandleFunc(string(types.TaskTypeDeduction), w.apiHandler.HandleDeduction)
|
||||
w.mux.HandleFunc(string(types.TaskTypeCompensation), w.apiHandler.HandleCompensation)
|
||||
w.mux.HandleFunc(string(types.TaskTypeUsageStats), w.apiHandler.HandleUsageStats)
|
||||
|
||||
w.logger.Info("所有任务处理器注册完成",
|
||||
zap.String("article_publish", string(types.TaskTypeArticlePublish)),
|
||||
zap.String("article_cancel", string(types.TaskTypeArticleCancel)),
|
||||
zap.String("article_modify", string(types.TaskTypeArticleModify)),
|
||||
zap.String("api_call", string(types.TaskTypeApiCall)),
|
||||
zap.String("api_log", string(types.TaskTypeApiLog)),
|
||||
)
|
||||
}
|
||||
374
internal/infrastructure/task/implementations/task_manager.go
Normal file
374
internal/infrastructure/task/implementations/task_manager.go
Normal file
@@ -0,0 +1,374 @@
|
||||
package implementations
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/interfaces"
|
||||
"tyapi-server/internal/infrastructure/task/repositories"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// TaskManagerImpl 任务管理器实现
|
||||
type TaskManagerImpl struct {
|
||||
asynqClient *asynq.Client
|
||||
asyncTaskRepo repositories.AsyncTaskRepository
|
||||
logger *zap.Logger
|
||||
config *interfaces.TaskManagerConfig
|
||||
}
|
||||
|
||||
// NewTaskManager 创建任务管理器
|
||||
func NewTaskManager(
|
||||
asynqClient *asynq.Client,
|
||||
asyncTaskRepo repositories.AsyncTaskRepository,
|
||||
logger *zap.Logger,
|
||||
config *interfaces.TaskManagerConfig,
|
||||
) interfaces.TaskManager {
|
||||
return &TaskManagerImpl{
|
||||
asynqClient: asynqClient,
|
||||
asyncTaskRepo: asyncTaskRepo,
|
||||
logger: logger,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAndEnqueueTask 创建并入队任务
|
||||
func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
|
||||
// 1. 保存任务到数据库(GORM会自动生成UUID)
|
||||
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
|
||||
tm.logger.Error("保存任务到数据库失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("保存任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 更新payload中的task_id
|
||||
if err := tm.updatePayloadTaskID(task); err != nil {
|
||||
tm.logger.Error("更新payload中的任务ID失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 更新数据库中的payload
|
||||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||||
tm.logger.Error("更新任务payload失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新任务payload失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 入队到Asynq
|
||||
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
|
||||
// 如果入队失败,更新任务状态为失败
|
||||
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败")
|
||||
return fmt.Errorf("任务入队失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务创建并入队成功",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("task_type", task.Type))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateAndEnqueueDelayedTask 创建并入队延时任务
|
||||
func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
|
||||
// 1. 设置调度时间
|
||||
scheduledAt := time.Now().Add(delay)
|
||||
task.ScheduledAt = &scheduledAt
|
||||
|
||||
// 2. 保存任务到数据库(GORM会自动生成UUID)
|
||||
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
|
||||
tm.logger.Error("保存延时任务到数据库失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("保存延时任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 3. 更新payload中的task_id
|
||||
if err := tm.updatePayloadTaskID(task); err != nil {
|
||||
tm.logger.Error("更新payload中的任务ID失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||||
}
|
||||
|
||||
// 4. 更新数据库中的payload
|
||||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||||
tm.logger.Error("更新任务payload失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新任务payload失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 入队到Asynq延时队列
|
||||
if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil {
|
||||
// 如果入队失败,更新任务状态为失败
|
||||
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败")
|
||||
return fmt.Errorf("延时任务入队失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("延时任务创建并入队成功",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("task_type", task.Type),
|
||||
zap.Duration("delay", delay))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CancelTask 取消任务
|
||||
func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error {
|
||||
task, err := tm.findTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
|
||||
tm.logger.Error("更新任务状态为取消失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("更新任务状态失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务已标记为取消",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("task_type", task.Type))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateTaskSchedule 更新任务调度时间
|
||||
func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
|
||||
// 1. 查找任务
|
||||
task, err := tm.findTask(ctx, taskID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tm.logger.Info("找到要更新的任务",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.String("current_status", string(task.Status)),
|
||||
zap.Time("current_scheduled_at", *task.ScheduledAt))
|
||||
|
||||
// 2. 取消旧任务
|
||||
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
|
||||
tm.logger.Error("取消旧任务失败",
|
||||
zap.String("task_id", task.ID),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("取消旧任务失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID))
|
||||
|
||||
// 3. 创建并保存新任务
|
||||
newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tm.logger.Info("新任务已创建",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Time("new_scheduled_at", newScheduledAt))
|
||||
|
||||
// 4. 计算延时并入队
|
||||
delay := newScheduledAt.Sub(time.Now())
|
||||
if delay < 0 {
|
||||
delay = 0 // 如果时间已过,立即执行
|
||||
}
|
||||
|
||||
if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil {
|
||||
// 如果入队失败,删除新创建的任务记录
|
||||
tm.asyncTaskRepo.Delete(ctx, newTask.ID)
|
||||
return fmt.Errorf("重新入队任务失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务调度时间更新成功",
|
||||
zap.String("old_task_id", task.ID),
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Time("new_scheduled_at", newScheduledAt))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
return tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
}
|
||||
|
||||
// UpdateTaskStatus 更新任务状态
|
||||
func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error {
|
||||
if errorMsg != "" {
|
||||
return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg)
|
||||
}
|
||||
return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status)
|
||||
}
|
||||
|
||||
// RetryTask 重试任务
|
||||
func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
|
||||
// 1. 获取任务信息
|
||||
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("获取任务信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 2. 检查是否可以重试
|
||||
if !task.CanRetry() {
|
||||
return fmt.Errorf("任务已达到最大重试次数")
|
||||
}
|
||||
|
||||
// 3. 增加重试次数并重置状态
|
||||
task.RetryCount++
|
||||
task.Status = entities.TaskStatusPending
|
||||
|
||||
// 4. 更新数据库
|
||||
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
|
||||
return fmt.Errorf("更新任务重试次数失败: %w", err)
|
||||
}
|
||||
|
||||
// 5. 重新入队
|
||||
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
|
||||
return fmt.Errorf("重试任务入队失败: %w", err)
|
||||
}
|
||||
|
||||
tm.logger.Info("任务重试成功",
|
||||
zap.String("task_id", taskID),
|
||||
zap.Int("retry_count", task.RetryCount))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupExpiredTasks 清理过期任务
|
||||
func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
|
||||
// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
|
||||
tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
|
||||
|
||||
// TODO: 实现清理逻辑
|
||||
return nil
|
||||
}
|
||||
|
||||
// updatePayloadTaskID 更新payload中的task_id
|
||||
func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
|
||||
// 解析payload
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil {
|
||||
return fmt.Errorf("解析payload失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新task_id
|
||||
payload["task_id"] = task.ID
|
||||
|
||||
// 重新序列化
|
||||
newPayload, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化payload失败: %w", err)
|
||||
}
|
||||
|
||||
task.Payload = string(newPayload)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// findTask 查找任务(支持taskID和articleID双重查找)
|
||||
func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
|
||||
// 先尝试通过任务ID查找
|
||||
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
|
||||
if err == nil {
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// 如果通过任务ID找不到,尝试通过文章ID查找
|
||||
tm.logger.Info("通过任务ID查找失败,尝试通过文章ID查找", zap.String("task_id", taskID))
|
||||
|
||||
tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
|
||||
if err != nil || len(tasks) == 0 {
|
||||
tm.logger.Error("通过文章ID也找不到任务",
|
||||
zap.String("article_id", taskID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("获取任务信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 使用找到的第一个任务
|
||||
task = tasks[0]
|
||||
tm.logger.Info("通过文章ID找到任务",
|
||||
zap.String("article_id", taskID),
|
||||
zap.String("task_id", task.ID))
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// createAndSaveTask 创建并保存新任务
|
||||
func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) {
|
||||
// 创建新任务
|
||||
newTask := &entities.AsyncTask{
|
||||
Type: originalTask.Type,
|
||||
Payload: originalTask.Payload,
|
||||
Status: entities.TaskStatusPending,
|
||||
ScheduledAt: &newScheduledAt,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 保存到数据库(GORM会自动生成UUID)
|
||||
if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
|
||||
tm.logger.Error("创建新任务失败",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("创建新任务失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新payload中的task_id
|
||||
if err := tm.updatePayloadTaskID(newTask); err != nil {
|
||||
tm.logger.Error("更新payload中的任务ID失败",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库中的payload
|
||||
if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
|
||||
tm.logger.Error("更新新任务payload失败",
|
||||
zap.String("new_task_id", newTask.ID),
|
||||
zap.Error(err))
|
||||
return nil, fmt.Errorf("更新新任务payload失败: %w", err)
|
||||
}
|
||||
|
||||
return newTask, nil
|
||||
}
|
||||
|
||||
// enqueueTaskWithDelay 入队任务到Asynq(支持延时)
|
||||
func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
|
||||
queueName := tm.getQueueName(task.Type)
|
||||
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
|
||||
|
||||
var err error
|
||||
if delay > 0 {
|
||||
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
|
||||
asynq.Queue(queueName),
|
||||
asynq.ProcessIn(delay))
|
||||
} else {
|
||||
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// getQueueName 根据任务类型获取队列名称
|
||||
func (tm *TaskManagerImpl) getQueueName(taskType string) string {
|
||||
switch taskType {
|
||||
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
|
||||
return "article"
|
||||
case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
|
||||
return "api"
|
||||
case string(types.TaskTypeCompensation):
|
||||
return "finance"
|
||||
default:
|
||||
return "default"
|
||||
}
|
||||
}
|
||||
35
internal/infrastructure/task/interfaces/api_task_queue.go
Normal file
35
internal/infrastructure/task/interfaces/api_task_queue.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// ApiTaskQueue API任务队列接口
|
||||
type ApiTaskQueue interface {
|
||||
// Enqueue 入队任务
|
||||
Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error
|
||||
|
||||
// Cancel 取消任务
|
||||
Cancel(ctx context.Context, taskID string) error
|
||||
|
||||
// ModifySchedule 修改任务调度时间
|
||||
ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error)
|
||||
|
||||
// ListTasks 列出任务
|
||||
ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||||
|
||||
// EnqueueTask 入队任务(简化版本)
|
||||
EnqueueTask(ctx context.Context, task *entities.AsyncTask) error
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// ArticleTaskQueue 文章任务队列接口
|
||||
type ArticleTaskQueue interface {
|
||||
// Enqueue 入队任务
|
||||
Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error
|
||||
|
||||
// EnqueueDelayed 延时入队任务
|
||||
EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error
|
||||
|
||||
// EnqueueAt 指定时间入队任务
|
||||
EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error
|
||||
|
||||
// Cancel 取消任务
|
||||
Cancel(ctx context.Context, taskID string) error
|
||||
|
||||
// ModifySchedule 修改任务调度时间
|
||||
ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error
|
||||
|
||||
// GetTaskStatus 获取任务状态
|
||||
GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error)
|
||||
|
||||
// ListTasks 列出任务
|
||||
ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||||
}
|
||||
44
internal/infrastructure/task/interfaces/task_manager.go
Normal file
44
internal/infrastructure/task/interfaces/task_manager.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package interfaces
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
)
|
||||
|
||||
// TaskManager 任务管理器接口
|
||||
// 统一管理Asynq任务和AsyncTask实体的操作
|
||||
type TaskManager interface {
|
||||
// 创建并入队任务
|
||||
CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error
|
||||
|
||||
// 创建并入队延时任务
|
||||
CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error
|
||||
|
||||
// 取消任务
|
||||
CancelTask(ctx context.Context, taskID string) error
|
||||
|
||||
// 更新任务调度时间
|
||||
UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error
|
||||
|
||||
// 获取任务状态
|
||||
GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error)
|
||||
|
||||
// 更新任务状态
|
||||
UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error
|
||||
|
||||
// 重试任务
|
||||
RetryTask(ctx context.Context, taskID string) error
|
||||
|
||||
// 清理过期任务
|
||||
CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error
|
||||
}
|
||||
|
||||
// TaskManagerConfig 任务管理器配置
|
||||
type TaskManagerConfig struct {
|
||||
RedisAddr string
|
||||
MaxRetries int
|
||||
RetryInterval time.Duration
|
||||
CleanupDays int
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsyncTaskRepository 异步任务仓库接口
|
||||
type AsyncTaskRepository interface {
|
||||
// 基础CRUD操作
|
||||
Create(ctx context.Context, task *entities.AsyncTask) error
|
||||
GetByID(ctx context.Context, id string) (*entities.AsyncTask, error)
|
||||
Update(ctx context.Context, task *entities.AsyncTask) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// 查询操作
|
||||
ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error)
|
||||
ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||||
ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||||
ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error)
|
||||
|
||||
// 状态更新操作
|
||||
UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error
|
||||
UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
||||
UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
||||
UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error
|
||||
UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error
|
||||
UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error
|
||||
IncrementRetryCount(ctx context.Context, id string) error
|
||||
|
||||
// 批量操作
|
||||
UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error
|
||||
DeleteBatch(ctx context.Context, ids []string) error
|
||||
|
||||
// 文章任务专用方法
|
||||
GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error)
|
||||
GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
|
||||
CancelArticlePublishTask(ctx context.Context, articleID string) error
|
||||
UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
|
||||
}
|
||||
|
||||
// AsyncTaskRepositoryImpl 异步任务仓库实现
|
||||
type AsyncTaskRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAsyncTaskRepository 创建异步任务仓库
|
||||
func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository {
|
||||
return &AsyncTaskRepositoryImpl{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建任务
|
||||
func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error {
|
||||
return r.db.WithContext(ctx).Create(task).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取任务
|
||||
func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) {
|
||||
var task entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&task).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// Update 更新任务
|
||||
func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error {
|
||||
return r.db.WithContext(ctx).Save(task).Error
|
||||
}
|
||||
|
||||
// Delete 删除任务
|
||||
func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.AsyncTask{}).Error
|
||||
}
|
||||
|
||||
// ListByType 根据类型列出任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
query := r.db.WithContext(ctx).Where("type = ?", taskType)
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
err := query.Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态列出任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
query := r.db.WithContext(ctx).Where("status = ?", status)
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
err := query.Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ListByTypeAndStatus 根据类型和状态列出任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
query := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status)
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
err := query.Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ListScheduledTasks 列出已到期的调度任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.TaskStatusPending, before).
|
||||
Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// UpdateStatus 更新任务状态
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatusWithError 更新任务状态并记录错误
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"error_msg": errorMsg,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatusWithRetryAndError 更新任务状态、增加重试次数并记录错误
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"error_msg": errorMsg,
|
||||
"retry_count": gorm.Expr("retry_count + 1"),
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatusWithSuccess 更新任务状态为成功,清除错误信息
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"error_msg": "", // 清除错误信息
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateRetryCountAndError 更新重试次数和错误信息,保持pending状态
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"retry_count": retryCount,
|
||||
"error_msg": errorMsg,
|
||||
"updated_at": time.Now(),
|
||||
// 注意:不更新status,保持pending状态
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateScheduledAt 更新任务调度时间
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Update("scheduled_at", scheduledAt).Error
|
||||
}
|
||||
|
||||
// IncrementRetryCount 增加重试次数
|
||||
func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Update("retry_count", gorm.Expr("retry_count + 1")).Error
|
||||
}
|
||||
|
||||
// UpdateStatusBatch 批量更新状态
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id IN ?", ids).
|
||||
Update("status", status).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除
|
||||
func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Where("id IN ?", ids).
|
||||
Delete(&entities.AsyncTask{}).Error
|
||||
}
|
||||
|
||||
// GetArticlePublishTask 获取文章发布任务
|
||||
func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
|
||||
var task entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||||
types.TaskTypeArticlePublish,
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
First(&task).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// GetByArticleID 根据文章ID获取所有相关任务
|
||||
func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("payload LIKE ? AND status IN ?",
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
Find(&tasks).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// CancelArticlePublishTask 取消文章发布任务
|
||||
func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||||
types.TaskTypeArticlePublish,
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
Update("status", entities.TaskStatusCancelled).Error
|
||||
}
|
||||
|
||||
// UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||||
types.TaskTypeArticlePublish,
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
Update("scheduled_at", newScheduledAt).Error
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
package task
|
||||
|
||||
// 任务类型常量
|
||||
const (
|
||||
// TaskTypeArticlePublish 文章定时发布任务
|
||||
TaskTypeArticlePublish = "article:publish"
|
||||
)
|
||||
196
internal/infrastructure/task/types/queue_types.go
Normal file
196
internal/infrastructure/task/types/queue_types.go
Normal file
@@ -0,0 +1,196 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// QueueType 队列类型
|
||||
type QueueType string
|
||||
|
||||
const (
|
||||
QueueTypeDefault QueueType = "default"
|
||||
QueueTypeApi QueueType = "api"
|
||||
QueueTypeArticle QueueType = "article"
|
||||
QueueTypeFinance QueueType = "finance"
|
||||
QueueTypeProduct QueueType = "product"
|
||||
)
|
||||
|
||||
// ArticlePublishPayload 文章发布任务载荷
|
||||
type ArticlePublishPayload struct {
|
||||
ArticleID string `json:"article_id"`
|
||||
PublishAt time.Time `json:"publish_at"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ArticlePublishPayload) GetType() TaskType {
|
||||
return TaskTypeArticlePublish
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ArticlePublishPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ArticlePublishPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// ArticleCancelPayload 文章取消任务载荷
|
||||
type ArticleCancelPayload struct {
|
||||
ArticleID string `json:"article_id"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ArticleCancelPayload) GetType() TaskType {
|
||||
return TaskTypeArticleCancel
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ArticleCancelPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ArticleCancelPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// ArticleModifyPayload 文章修改任务载荷
|
||||
type ArticleModifyPayload struct {
|
||||
ArticleID string `json:"article_id"`
|
||||
NewPublishAt time.Time `json:"new_publish_at"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ArticleModifyPayload) GetType() TaskType {
|
||||
return TaskTypeArticleModify
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ArticleModifyPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ArticleModifyPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// ApiCallPayload API调用任务载荷
|
||||
type ApiCallPayload struct {
|
||||
ApiCallID string `json:"api_call_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ProductID string `json:"product_id"`
|
||||
Amount string `json:"amount"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ApiCallPayload) GetType() TaskType {
|
||||
return TaskTypeApiCall
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ApiCallPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ApiCallPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// DeductionPayload 扣款任务载荷
|
||||
type DeductionPayload struct {
|
||||
UserID string `json:"user_id"`
|
||||
Amount string `json:"amount"`
|
||||
ApiCallID string `json:"api_call_id"`
|
||||
TransactionID string `json:"transaction_id"`
|
||||
ProductID string `json:"product_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *DeductionPayload) GetType() TaskType {
|
||||
return TaskTypeDeduction
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *DeductionPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *DeductionPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// CompensationPayload 补偿任务载荷
|
||||
type CompensationPayload struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *CompensationPayload) GetType() TaskType {
|
||||
return TaskTypeCompensation
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *CompensationPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *CompensationPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// UsageStatsPayload 使用统计任务载荷
|
||||
type UsageStatsPayload struct {
|
||||
SubscriptionID string `json:"subscription_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ProductID string `json:"product_id"`
|
||||
Increment int `json:"increment"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *UsageStatsPayload) GetType() TaskType {
|
||||
return TaskTypeUsageStats
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *UsageStatsPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *UsageStatsPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
|
||||
// ApiLogPayload API日志任务载荷
|
||||
type ApiLogPayload struct {
|
||||
TransactionID string `json:"transaction_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ApiName string `json:"api_name"`
|
||||
ProductID string `json:"product_id"`
|
||||
}
|
||||
|
||||
// GetType 获取任务类型
|
||||
func (p *ApiLogPayload) GetType() TaskType {
|
||||
return TaskTypeApiLog
|
||||
}
|
||||
|
||||
// ToJSON 序列化为JSON
|
||||
func (p *ApiLogPayload) ToJSON() ([]byte, error) {
|
||||
return json.Marshal(p)
|
||||
}
|
||||
|
||||
// FromJSON 从JSON反序列化
|
||||
func (p *ApiLogPayload) FromJSON(data []byte) error {
|
||||
return json.Unmarshal(data, p)
|
||||
}
|
||||
29
internal/infrastructure/task/types/task_types.go
Normal file
29
internal/infrastructure/task/types/task_types.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package types
|
||||
|
||||
// TaskType 任务类型
|
||||
type TaskType string
|
||||
|
||||
const (
|
||||
// 文章相关任务
|
||||
TaskTypeArticlePublish TaskType = "article_publish"
|
||||
TaskTypeArticleCancel TaskType = "article_cancel"
|
||||
TaskTypeArticleModify TaskType = "article_modify"
|
||||
|
||||
// API相关任务
|
||||
TaskTypeApiCall TaskType = "api_call"
|
||||
TaskTypeApiLog TaskType = "api_log"
|
||||
|
||||
// 财务相关任务
|
||||
TaskTypeDeduction TaskType = "deduction"
|
||||
TaskTypeCompensation TaskType = "compensation"
|
||||
|
||||
// 产品相关任务
|
||||
TaskTypeUsageStats TaskType = "usage_stats"
|
||||
)
|
||||
|
||||
// TaskPayload 任务载荷接口
|
||||
type TaskPayload interface {
|
||||
GetType() TaskType
|
||||
ToJSON() ([]byte, error)
|
||||
FromJSON(data []byte) error
|
||||
}
|
||||
100
internal/infrastructure/task/utils/asynq_logger.go
Normal file
100
internal/infrastructure/task/utils/asynq_logger.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/hibiken/asynq"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AsynqLogger Asynq日志适配器
|
||||
type AsynqLogger struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAsynqLogger 创建Asynq日志适配器
|
||||
func NewAsynqLogger(logger *zap.Logger) *AsynqLogger {
|
||||
return &AsynqLogger{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Debug 调试日志
|
||||
func (l *AsynqLogger) Debug(args ...interface{}) {
|
||||
l.logger.Debug("", zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Info 信息日志
|
||||
func (l *AsynqLogger) Info(args ...interface{}) {
|
||||
l.logger.Info("", zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Warn 警告日志
|
||||
func (l *AsynqLogger) Warn(args ...interface{}) {
|
||||
l.logger.Warn("", zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Error 错误日志
|
||||
func (l *AsynqLogger) Error(args ...interface{}) {
|
||||
l.logger.Error("", zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Fatal 致命错误日志
|
||||
func (l *AsynqLogger) Fatal(args ...interface{}) {
|
||||
l.logger.Fatal("", zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Debugf 格式化调试日志
|
||||
func (l *AsynqLogger) Debugf(format string, args ...interface{}) {
|
||||
l.logger.Debug("", zap.String("format", format), zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Infof 格式化信息日志
|
||||
func (l *AsynqLogger) Infof(format string, args ...interface{}) {
|
||||
l.logger.Info("", zap.String("format", format), zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Warnf 格式化警告日志
|
||||
func (l *AsynqLogger) Warnf(format string, args ...interface{}) {
|
||||
l.logger.Warn("", zap.String("format", format), zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Errorf 格式化错误日志
|
||||
func (l *AsynqLogger) Errorf(format string, args ...interface{}) {
|
||||
l.logger.Error("", zap.String("format", format), zap.Any("args", args))
|
||||
}
|
||||
|
||||
// Fatalf 格式化致命错误日志
|
||||
func (l *AsynqLogger) Fatalf(format string, args ...interface{}) {
|
||||
l.logger.Fatal("", zap.String("format", format), zap.Any("args", args))
|
||||
}
|
||||
|
||||
// WithField 添加字段
|
||||
func (l *AsynqLogger) WithField(key string, value interface{}) asynq.Logger {
|
||||
return &AsynqLogger{
|
||||
logger: l.logger.With(zap.Any(key, value)),
|
||||
}
|
||||
}
|
||||
|
||||
// WithFields 添加多个字段
|
||||
func (l *AsynqLogger) WithFields(fields map[string]interface{}) asynq.Logger {
|
||||
zapFields := make([]zap.Field, 0, len(fields))
|
||||
for k, v := range fields {
|
||||
zapFields = append(zapFields, zap.Any(k, v))
|
||||
}
|
||||
return &AsynqLogger{
|
||||
logger: l.logger.With(zapFields...),
|
||||
}
|
||||
}
|
||||
|
||||
// WithError 添加错误字段
|
||||
func (l *AsynqLogger) WithError(err error) asynq.Logger {
|
||||
return &AsynqLogger{
|
||||
logger: l.logger.With(zap.Error(err)),
|
||||
}
|
||||
}
|
||||
|
||||
// WithContext 添加上下文
|
||||
func (l *AsynqLogger) WithContext(ctx context.Context) asynq.Logger {
|
||||
return l
|
||||
}
|
||||
17
internal/infrastructure/task/utils/task_id.go
Normal file
17
internal/infrastructure/task/utils/task_id.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// GenerateTaskID 生成统一格式的任务ID (UUID)
|
||||
func GenerateTaskID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// GenerateTaskIDWithPrefix 生成带前缀的任务ID (UUID)
|
||||
func GenerateTaskIDWithPrefix(prefix string) string {
|
||||
return fmt.Sprintf("%s-%s", prefix, uuid.New().String())
|
||||
}
|
||||
Reference in New Issue
Block a user