new
This commit is contained in:
@@ -0,0 +1,267 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/infrastructure/task/entities"
|
||||
"tyapi-server/internal/infrastructure/task/types"
|
||||
)
|
||||
|
||||
// AsyncTaskRepository 异步任务仓库接口
|
||||
type AsyncTaskRepository interface {
|
||||
// 基础CRUD操作
|
||||
Create(ctx context.Context, task *entities.AsyncTask) error
|
||||
GetByID(ctx context.Context, id string) (*entities.AsyncTask, error)
|
||||
Update(ctx context.Context, task *entities.AsyncTask) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
|
||||
// 查询操作
|
||||
ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error)
|
||||
ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||||
ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||||
ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error)
|
||||
|
||||
// 状态更新操作
|
||||
UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error
|
||||
UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
||||
UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
||||
UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error
|
||||
UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error
|
||||
UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error
|
||||
IncrementRetryCount(ctx context.Context, id string) error
|
||||
|
||||
// 批量操作
|
||||
UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error
|
||||
DeleteBatch(ctx context.Context, ids []string) error
|
||||
|
||||
// 文章任务专用方法
|
||||
GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error)
|
||||
GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
|
||||
CancelArticlePublishTask(ctx context.Context, articleID string) error
|
||||
UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
|
||||
}
|
||||
|
||||
// AsyncTaskRepositoryImpl 异步任务仓库实现
|
||||
type AsyncTaskRepositoryImpl struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewAsyncTaskRepository 创建异步任务仓库
|
||||
func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository {
|
||||
return &AsyncTaskRepositoryImpl{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建任务
|
||||
func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error {
|
||||
return r.db.WithContext(ctx).Create(task).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取任务
|
||||
func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) {
|
||||
var task entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&task).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// Update 更新任务
|
||||
func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error {
|
||||
return r.db.WithContext(ctx).Save(task).Error
|
||||
}
|
||||
|
||||
// Delete 删除任务
|
||||
func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.AsyncTask{}).Error
|
||||
}
|
||||
|
||||
// ListByType 根据类型列出任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
query := r.db.WithContext(ctx).Where("type = ?", taskType)
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
err := query.Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态列出任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
query := r.db.WithContext(ctx).Where("status = ?", status)
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
err := query.Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ListByTypeAndStatus 根据类型和状态列出任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
query := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status)
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
err := query.Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// ListScheduledTasks 列出已到期的调度任务
|
||||
func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.TaskStatusPending, before).
|
||||
Find(&tasks).Error
|
||||
return tasks, err
|
||||
}
|
||||
|
||||
// UpdateStatus 更新任务状态
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatusWithError 更新任务状态并记录错误
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"error_msg": errorMsg,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatusWithRetryAndError 更新任务状态、增加重试次数并记录错误
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"error_msg": errorMsg,
|
||||
"retry_count": gorm.Expr("retry_count + 1"),
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateStatusWithSuccess 更新任务状态为成功,清除错误信息
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"status": status,
|
||||
"error_msg": "", // 清除错误信息
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateRetryCountAndError 更新重试次数和错误信息,保持pending状态
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Updates(map[string]interface{}{
|
||||
"retry_count": retryCount,
|
||||
"error_msg": errorMsg,
|
||||
"updated_at": time.Now(),
|
||||
// 注意:不更新status,保持pending状态
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateScheduledAt 更新任务调度时间
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Update("scheduled_at", scheduledAt).Error
|
||||
}
|
||||
|
||||
// IncrementRetryCount 增加重试次数
|
||||
func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id = ?", id).
|
||||
Update("retry_count", gorm.Expr("retry_count + 1")).Error
|
||||
}
|
||||
|
||||
// UpdateStatusBatch 批量更新状态
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("id IN ?", ids).
|
||||
Update("status", status).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除
|
||||
func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Where("id IN ?", ids).
|
||||
Delete(&entities.AsyncTask{}).Error
|
||||
}
|
||||
|
||||
// GetArticlePublishTask 获取文章发布任务
|
||||
func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
|
||||
var task entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||||
types.TaskTypeArticlePublish,
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
First(&task).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &task, nil
|
||||
}
|
||||
|
||||
// GetByArticleID 根据文章ID获取所有相关任务
|
||||
func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
|
||||
var tasks []*entities.AsyncTask
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("payload LIKE ? AND status IN ?",
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
Find(&tasks).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// CancelArticlePublishTask 取消文章发布任务
|
||||
func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||||
types.TaskTypeArticlePublish,
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
Update("status", entities.TaskStatusCancelled).Error
|
||||
}
|
||||
|
||||
// UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间
|
||||
func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.AsyncTask{}).
|
||||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||||
types.TaskTypeArticlePublish,
|
||||
"%\"article_id\":\""+articleID+"\"%",
|
||||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||||
Update("scheduled_at", newScheduledAt).Error
|
||||
}
|
||||
Reference in New Issue
Block a user