267 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			267 lines
		
	
	
		
			9.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package repositories
 | ||
| 
 | ||
import (
	"context"
	"strings"
	"time"

	"gorm.io/gorm"

	"tyapi-server/internal/infrastructure/task/entities"
	"tyapi-server/internal/infrastructure/task/types"
)
 | ||
| 
 | ||
| // AsyncTaskRepository 异步任务仓库接口
 | ||
| type AsyncTaskRepository interface {
 | ||
| 	// 基础CRUD操作
 | ||
| 	Create(ctx context.Context, task *entities.AsyncTask) error
 | ||
| 	GetByID(ctx context.Context, id string) (*entities.AsyncTask, error)
 | ||
| 	Update(ctx context.Context, task *entities.AsyncTask) error
 | ||
| 	Delete(ctx context.Context, id string) error
 | ||
| 
 | ||
| 	// 查询操作
 | ||
| 	ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error)
 | ||
| 	ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
 | ||
| 	ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
 | ||
| 	ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error)
 | ||
| 
 | ||
| 	// 状态更新操作
 | ||
| 	UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error
 | ||
| 	UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
 | ||
| 	UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
 | ||
| 	UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error
 | ||
| 	UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error
 | ||
| 	UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error
 | ||
| 	IncrementRetryCount(ctx context.Context, id string) error
 | ||
| 
 | ||
| 	// 批量操作
 | ||
| 	UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error
 | ||
| 	DeleteBatch(ctx context.Context, ids []string) error
 | ||
| 
 | ||
| 	// 文章任务专用方法
 | ||
| 	GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error)
 | ||
| 	GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
 | ||
| 	CancelArticlePublishTask(ctx context.Context, articleID string) error
 | ||
| 	UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
 | ||
| }
 | ||
| 
 | ||
| // AsyncTaskRepositoryImpl 异步任务仓库实现
 | ||
| type AsyncTaskRepositoryImpl struct {
 | ||
| 	db *gorm.DB
 | ||
| }
 | ||
| 
 | ||
| // NewAsyncTaskRepository 创建异步任务仓库
 | ||
| func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository {
 | ||
| 	return &AsyncTaskRepositoryImpl{
 | ||
| 		db: db,
 | ||
| 	}
 | ||
| }
 | ||
| 
 | ||
| // Create 创建任务
 | ||
| func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error {
 | ||
| 	return r.db.WithContext(ctx).Create(task).Error
 | ||
| }
 | ||
| 
 | ||
| // GetByID 根据ID获取任务
 | ||
| func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) {
 | ||
| 	var task entities.AsyncTask
 | ||
| 	err := r.db.WithContext(ctx).Where("id = ?", id).First(&task).Error
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 	return &task, nil
 | ||
| }
 | ||
| 
 | ||
| // Update 更新任务
 | ||
| func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error {
 | ||
| 	return r.db.WithContext(ctx).Save(task).Error
 | ||
| }
 | ||
| 
 | ||
| // Delete 删除任务
 | ||
| func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error {
 | ||
| 	return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.AsyncTask{}).Error
 | ||
| }
 | ||
| 
 | ||
| // ListByType 根据类型列出任务
 | ||
| func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) {
 | ||
| 	var tasks []*entities.AsyncTask
 | ||
| 	query := r.db.WithContext(ctx).Where("type = ?", taskType)
 | ||
| 	if limit > 0 {
 | ||
| 		query = query.Limit(limit)
 | ||
| 	}
 | ||
| 	err := query.Find(&tasks).Error
 | ||
| 	return tasks, err
 | ||
| }
 | ||
| 
 | ||
| // ListByStatus 根据状态列出任务
 | ||
| func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
 | ||
| 	var tasks []*entities.AsyncTask
 | ||
| 	query := r.db.WithContext(ctx).Where("status = ?", status)
 | ||
| 	if limit > 0 {
 | ||
| 		query = query.Limit(limit)
 | ||
| 	}
 | ||
| 	err := query.Find(&tasks).Error
 | ||
| 	return tasks, err
 | ||
| }
 | ||
| 
 | ||
| // ListByTypeAndStatus 根据类型和状态列出任务
 | ||
| func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
 | ||
| 	var tasks []*entities.AsyncTask
 | ||
| 	query := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status)
 | ||
| 	if limit > 0 {
 | ||
| 		query = query.Limit(limit)
 | ||
| 	}
 | ||
| 	err := query.Find(&tasks).Error
 | ||
| 	return tasks, err
 | ||
| }
 | ||
| 
 | ||
| // ListScheduledTasks 列出已到期的调度任务
 | ||
| func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) {
 | ||
| 	var tasks []*entities.AsyncTask
 | ||
| 	err := r.db.WithContext(ctx).
 | ||
| 		Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.TaskStatusPending, before).
 | ||
| 		Find(&tasks).Error
 | ||
| 	return tasks, err
 | ||
| }
 | ||
| 
 | ||
| // UpdateStatus 更新任务状态
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id = ?", id).
 | ||
| 		Updates(map[string]interface{}{
 | ||
| 			"status":     status,
 | ||
| 			"updated_at": time.Now(),
 | ||
| 		}).Error
 | ||
| }
 | ||
| 
 | ||
| // UpdateStatusWithError 更新任务状态并记录错误
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id = ?", id).
 | ||
| 		Updates(map[string]interface{}{
 | ||
| 			"status":     status,
 | ||
| 			"error_msg":  errorMsg,
 | ||
| 			"updated_at": time.Now(),
 | ||
| 		}).Error
 | ||
| }
 | ||
| 
 | ||
| // UpdateStatusWithRetryAndError 更新任务状态、增加重试次数并记录错误
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id = ?", id).
 | ||
| 		Updates(map[string]interface{}{
 | ||
| 			"status":      status,
 | ||
| 			"error_msg":   errorMsg,
 | ||
| 			"retry_count": gorm.Expr("retry_count + 1"),
 | ||
| 			"updated_at":  time.Now(),
 | ||
| 		}).Error
 | ||
| }
 | ||
| 
 | ||
| // UpdateStatusWithSuccess 更新任务状态为成功,清除错误信息
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id = ?", id).
 | ||
| 		Updates(map[string]interface{}{
 | ||
| 			"status":     status,
 | ||
| 			"error_msg":  "", // 清除错误信息
 | ||
| 			"updated_at": time.Now(),
 | ||
| 		}).Error
 | ||
| }
 | ||
| 
 | ||
| // UpdateRetryCountAndError 更新重试次数和错误信息,保持pending状态
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id = ?", id).
 | ||
| 		Updates(map[string]interface{}{
 | ||
| 			"retry_count": retryCount,
 | ||
| 			"error_msg":   errorMsg,
 | ||
| 			"updated_at":  time.Now(),
 | ||
| 			// 注意:不更新status,保持pending状态
 | ||
| 		}).Error
 | ||
| }
 | ||
| 
 | ||
| // UpdateScheduledAt 更新任务调度时间
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id = ?", id).
 | ||
| 		Update("scheduled_at", scheduledAt).Error
 | ||
| }
 | ||
| 
 | ||
| // IncrementRetryCount 增加重试次数
 | ||
| func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id = ?", id).
 | ||
| 		Update("retry_count", gorm.Expr("retry_count + 1")).Error
 | ||
| }
 | ||
| 
 | ||
| // UpdateStatusBatch 批量更新状态
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("id IN ?", ids).
 | ||
| 		Update("status", status).Error
 | ||
| }
 | ||
| 
 | ||
| // DeleteBatch 批量删除
 | ||
| func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Where("id IN ?", ids).
 | ||
| 		Delete(&entities.AsyncTask{}).Error
 | ||
| }
 | ||
| 
 | ||
| // GetArticlePublishTask 获取文章发布任务
 | ||
| func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
 | ||
| 	var task entities.AsyncTask
 | ||
| 	err := r.db.WithContext(ctx).
 | ||
| 		Where("type = ? AND payload LIKE ? AND status IN ?", 
 | ||
| 			types.TaskTypeArticlePublish, 
 | ||
| 			"%\"article_id\":\""+articleID+"\"%",
 | ||
| 			[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
 | ||
| 		First(&task).Error
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 	return &task, nil
 | ||
| }
 | ||
| 
 | ||
| // GetByArticleID 根据文章ID获取所有相关任务
 | ||
| func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
 | ||
| 	var tasks []*entities.AsyncTask
 | ||
| 	err := r.db.WithContext(ctx).
 | ||
| 		Where("payload LIKE ? AND status IN ?", 
 | ||
| 			"%\"article_id\":\""+articleID+"\"%",
 | ||
| 			[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
 | ||
| 		Find(&tasks).Error
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 	return tasks, nil
 | ||
| }
 | ||
| 
 | ||
| // CancelArticlePublishTask 取消文章发布任务
 | ||
| func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("type = ? AND payload LIKE ? AND status IN ?", 
 | ||
| 			types.TaskTypeArticlePublish, 
 | ||
| 			"%\"article_id\":\""+articleID+"\"%",
 | ||
| 			[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
 | ||
| 		Update("status", entities.TaskStatusCancelled).Error
 | ||
| }
 | ||
| 
 | ||
| // UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间
 | ||
| func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
 | ||
| 	return r.db.WithContext(ctx).
 | ||
| 		Model(&entities.AsyncTask{}).
 | ||
| 		Where("type = ? AND payload LIKE ? AND status IN ?", 
 | ||
| 			types.TaskTypeArticlePublish, 
 | ||
| 			"%\"article_id\":\""+articleID+"\"%",
 | ||
| 			[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
 | ||
| 		Update("scheduled_at", newScheduledAt).Error
 | ||
| } |