package repositories import ( "context" "time" "gorm.io/gorm" "tyapi-server/internal/infrastructure/task/entities" "tyapi-server/internal/infrastructure/task/types" ) // AsyncTaskRepository 异步任务仓库接口 type AsyncTaskRepository interface { // 基础CRUD操作 Create(ctx context.Context, task *entities.AsyncTask) error GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) Update(ctx context.Context, task *entities.AsyncTask) error Delete(ctx context.Context, id string) error // 查询操作 ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) // 状态更新操作 UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error IncrementRetryCount(ctx context.Context, id string) error // 批量操作 UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error DeleteBatch(ctx context.Context, ids []string) error // 文章任务专用方法 GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) CancelArticlePublishTask(ctx context.Context, articleID string) error UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error } // AsyncTaskRepositoryImpl 异步任务仓库实现 type AsyncTaskRepositoryImpl struct { db *gorm.DB } // NewAsyncTaskRepository 创建异步任务仓库 func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository { return &AsyncTaskRepositoryImpl{ db: db, } } // Create 创建任务 func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error { return r.db.WithContext(ctx).Create(task).Error } // GetByID 根据ID获取任务 func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) { var task entities.AsyncTask err := r.db.WithContext(ctx).Where("id = ?", id).First(&task).Error if err != nil { return nil, err } return &task, nil } // Update 更新任务 func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error { return r.db.WithContext(ctx).Save(task).Error } // Delete 删除任务 func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error { return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.AsyncTask{}).Error } // ListByType 根据类型列出任务 func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) { var tasks []*entities.AsyncTask query := r.db.WithContext(ctx).Where("type = ?", taskType) if limit > 0 { query = query.Limit(limit) } err := query.Find(&tasks).Error return tasks, err } // ListByStatus 根据状态列出任务 func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) { var tasks []*entities.AsyncTask query := r.db.WithContext(ctx).Where("status = ?", status) if limit > 0 { query = query.Limit(limit) } err := query.Find(&tasks).Error return tasks, err } // ListByTypeAndStatus 根据类型和状态列出任务 func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) { var tasks []*entities.AsyncTask query := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status) if limit > 0 { query = query.Limit(limit) } err := query.Find(&tasks).Error return tasks, err } // ListScheduledTasks 列出已到期的调度任务 func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) { var tasks []*entities.AsyncTask err := r.db.WithContext(ctx). Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.TaskStatusPending, before). Find(&tasks).Error return tasks, err } // UpdateStatus 更新任务状态 func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id = ?", id). Updates(map[string]interface{}{ "status": status, "updated_at": time.Now(), }).Error } // UpdateStatusWithError 更新任务状态并记录错误 func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id = ?", id). Updates(map[string]interface{}{ "status": status, "error_msg": errorMsg, "updated_at": time.Now(), }).Error } // UpdateStatusWithRetryAndError 更新任务状态、增加重试次数并记录错误 func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id = ?", id). Updates(map[string]interface{}{ "status": status, "error_msg": errorMsg, "retry_count": gorm.Expr("retry_count + 1"), "updated_at": time.Now(), }).Error } // UpdateStatusWithSuccess 更新任务状态为成功,清除错误信息 func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id = ?", id). Updates(map[string]interface{}{ "status": status, "error_msg": "", // 清除错误信息 "updated_at": time.Now(), }).Error } // UpdateRetryCountAndError 更新重试次数和错误信息,保持pending状态 func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id = ?", id). Updates(map[string]interface{}{ "retry_count": retryCount, "error_msg": errorMsg, "updated_at": time.Now(), // 注意:不更新status,保持pending状态 }).Error } // UpdateScheduledAt 更新任务调度时间 func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id = ?", id). Update("scheduled_at", scheduledAt).Error } // IncrementRetryCount 增加重试次数 func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id = ?", id). Update("retry_count", gorm.Expr("retry_count + 1")).Error } // UpdateStatusBatch 批量更新状态 func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("id IN ?", ids). Update("status", status).Error } // DeleteBatch 批量删除 func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error { return r.db.WithContext(ctx). Where("id IN ?", ids). Delete(&entities.AsyncTask{}).Error } // GetArticlePublishTask 获取文章发布任务 func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) { var task entities.AsyncTask err := r.db.WithContext(ctx). Where("type = ? AND payload LIKE ? AND status IN ?", types.TaskTypeArticlePublish, "%\"article_id\":\""+articleID+"\"%", []entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}). First(&task).Error if err != nil { return nil, err } return &task, nil } // GetByArticleID 根据文章ID获取所有相关任务 func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) { var tasks []*entities.AsyncTask err := r.db.WithContext(ctx). Where("payload LIKE ? AND status IN ?", "%\"article_id\":\""+articleID+"\"%", []entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}). Find(&tasks).Error if err != nil { return nil, err } return tasks, nil } // CancelArticlePublishTask 取消文章发布任务 func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("type = ? AND payload LIKE ? AND status IN ?", types.TaskTypeArticlePublish, "%\"article_id\":\""+articleID+"\"%", []entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}). Update("status", entities.TaskStatusCancelled).Error } // UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间 func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error { return r.db.WithContext(ctx). Model(&entities.AsyncTask{}). Where("type = ? AND payload LIKE ? AND status IN ?", types.TaskTypeArticlePublish, "%\"article_id\":\""+articleID+"\"%", []entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}). Update("scheduled_at", newScheduledAt).Error }