267 lines
9.9 KiB
Go
267 lines
9.9 KiB
Go
|
|
package repositories
|
|||
|
|
|
|||
|
|
import (
	"context"
	"strings"
	"time"

	"gorm.io/gorm"

	"tyapi-server/internal/infrastructure/task/entities"
	"tyapi-server/internal/infrastructure/task/types"
)
|
|||
|
|
|
|||
|
|
// AsyncTaskRepository 异步任务仓库接口
|
|||
|
|
type AsyncTaskRepository interface {
|
|||
|
|
// 基础CRUD操作
|
|||
|
|
Create(ctx context.Context, task *entities.AsyncTask) error
|
|||
|
|
GetByID(ctx context.Context, id string) (*entities.AsyncTask, error)
|
|||
|
|
Update(ctx context.Context, task *entities.AsyncTask) error
|
|||
|
|
Delete(ctx context.Context, id string) error
|
|||
|
|
|
|||
|
|
// 查询操作
|
|||
|
|
ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error)
|
|||
|
|
ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
|||
|
|
ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
|||
|
|
ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error)
|
|||
|
|
|
|||
|
|
// 状态更新操作
|
|||
|
|
UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error
|
|||
|
|
UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
|||
|
|
UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
|||
|
|
UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error
|
|||
|
|
UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error
|
|||
|
|
UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error
|
|||
|
|
IncrementRetryCount(ctx context.Context, id string) error
|
|||
|
|
|
|||
|
|
// 批量操作
|
|||
|
|
UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error
|
|||
|
|
DeleteBatch(ctx context.Context, ids []string) error
|
|||
|
|
|
|||
|
|
// 文章任务专用方法
|
|||
|
|
GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error)
|
|||
|
|
GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
|
|||
|
|
CancelArticlePublishTask(ctx context.Context, articleID string) error
|
|||
|
|
UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AsyncTaskRepositoryImpl 异步任务仓库实现
|
|||
|
|
type AsyncTaskRepositoryImpl struct {
|
|||
|
|
db *gorm.DB
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// NewAsyncTaskRepository 创建异步任务仓库
|
|||
|
|
func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository {
|
|||
|
|
return &AsyncTaskRepositoryImpl{
|
|||
|
|
db: db,
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Create 创建任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error {
|
|||
|
|
return r.db.WithContext(ctx).Create(task).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByID 根据ID获取任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) {
|
|||
|
|
var task entities.AsyncTask
|
|||
|
|
err := r.db.WithContext(ctx).Where("id = ?", id).First(&task).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &task, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Update 更新任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error {
|
|||
|
|
return r.db.WithContext(ctx).Save(task).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// Delete 删除任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error {
|
|||
|
|
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.AsyncTask{}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListByType 根据类型列出任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) {
|
|||
|
|
var tasks []*entities.AsyncTask
|
|||
|
|
query := r.db.WithContext(ctx).Where("type = ?", taskType)
|
|||
|
|
if limit > 0 {
|
|||
|
|
query = query.Limit(limit)
|
|||
|
|
}
|
|||
|
|
err := query.Find(&tasks).Error
|
|||
|
|
return tasks, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListByStatus 根据状态列出任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
|||
|
|
var tasks []*entities.AsyncTask
|
|||
|
|
query := r.db.WithContext(ctx).Where("status = ?", status)
|
|||
|
|
if limit > 0 {
|
|||
|
|
query = query.Limit(limit)
|
|||
|
|
}
|
|||
|
|
err := query.Find(&tasks).Error
|
|||
|
|
return tasks, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListByTypeAndStatus 根据类型和状态列出任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
|||
|
|
var tasks []*entities.AsyncTask
|
|||
|
|
query := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status)
|
|||
|
|
if limit > 0 {
|
|||
|
|
query = query.Limit(limit)
|
|||
|
|
}
|
|||
|
|
err := query.Find(&tasks).Error
|
|||
|
|
return tasks, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ListScheduledTasks 列出已到期的调度任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) {
|
|||
|
|
var tasks []*entities.AsyncTask
|
|||
|
|
err := r.db.WithContext(ctx).
|
|||
|
|
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.TaskStatusPending, before).
|
|||
|
|
Find(&tasks).Error
|
|||
|
|
return tasks, err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateStatus 更新任务状态
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(map[string]interface{}{
|
|||
|
|
"status": status,
|
|||
|
|
"updated_at": time.Now(),
|
|||
|
|
}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateStatusWithError 更新任务状态并记录错误
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(map[string]interface{}{
|
|||
|
|
"status": status,
|
|||
|
|
"error_msg": errorMsg,
|
|||
|
|
"updated_at": time.Now(),
|
|||
|
|
}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateStatusWithRetryAndError 更新任务状态、增加重试次数并记录错误
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(map[string]interface{}{
|
|||
|
|
"status": status,
|
|||
|
|
"error_msg": errorMsg,
|
|||
|
|
"retry_count": gorm.Expr("retry_count + 1"),
|
|||
|
|
"updated_at": time.Now(),
|
|||
|
|
}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateStatusWithSuccess 更新任务状态为成功,清除错误信息
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(map[string]interface{}{
|
|||
|
|
"status": status,
|
|||
|
|
"error_msg": "", // 清除错误信息
|
|||
|
|
"updated_at": time.Now(),
|
|||
|
|
}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateRetryCountAndError 更新重试次数和错误信息,保持pending状态
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Updates(map[string]interface{}{
|
|||
|
|
"retry_count": retryCount,
|
|||
|
|
"error_msg": errorMsg,
|
|||
|
|
"updated_at": time.Now(),
|
|||
|
|
// 注意:不更新status,保持pending状态
|
|||
|
|
}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateScheduledAt 更新任务调度时间
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Update("scheduled_at", scheduledAt).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// IncrementRetryCount 增加重试次数
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id = ?", id).
|
|||
|
|
Update("retry_count", gorm.Expr("retry_count + 1")).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateStatusBatch 批量更新状态
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("id IN ?", ids).
|
|||
|
|
Update("status", status).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// DeleteBatch 批量删除
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Where("id IN ?", ids).
|
|||
|
|
Delete(&entities.AsyncTask{}).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetArticlePublishTask 获取文章发布任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
|
|||
|
|
var task entities.AsyncTask
|
|||
|
|
err := r.db.WithContext(ctx).
|
|||
|
|
Where("type = ? AND payload LIKE ? AND status IN ?",
|
|||
|
|
types.TaskTypeArticlePublish,
|
|||
|
|
"%\"article_id\":\""+articleID+"\"%",
|
|||
|
|
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
|||
|
|
First(&task).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return &task, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// GetByArticleID 根据文章ID获取所有相关任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
|
|||
|
|
var tasks []*entities.AsyncTask
|
|||
|
|
err := r.db.WithContext(ctx).
|
|||
|
|
Where("payload LIKE ? AND status IN ?",
|
|||
|
|
"%\"article_id\":\""+articleID+"\"%",
|
|||
|
|
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
|||
|
|
Find(&tasks).Error
|
|||
|
|
if err != nil {
|
|||
|
|
return nil, err
|
|||
|
|
}
|
|||
|
|
return tasks, nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// CancelArticlePublishTask 取消文章发布任务
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("type = ? AND payload LIKE ? AND status IN ?",
|
|||
|
|
types.TaskTypeArticlePublish,
|
|||
|
|
"%\"article_id\":\""+articleID+"\"%",
|
|||
|
|
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
|||
|
|
Update("status", entities.TaskStatusCancelled).Error
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间
|
|||
|
|
func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
|
|||
|
|
return r.db.WithContext(ctx).
|
|||
|
|
Model(&entities.AsyncTask{}).
|
|||
|
|
Where("type = ? AND payload LIKE ? AND status IN ?",
|
|||
|
|
types.TaskTypeArticlePublish,
|
|||
|
|
"%\"article_id\":\""+articleID+"\"%",
|
|||
|
|
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
|||
|
|
Update("scheduled_at", newScheduledAt).Error
|
|||
|
|
}
|