267 lines
9.9 KiB
Go
267 lines
9.9 KiB
Go
package repositories
|
||
|
||
import (
|
||
"context"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
|
||
"tyapi-server/internal/infrastructure/task/entities"
|
||
"tyapi-server/internal/infrastructure/task/types"
|
||
)
|
||
|
||
// AsyncTaskRepository 异步任务仓库接口
|
||
type AsyncTaskRepository interface {
|
||
// 基础CRUD操作
|
||
Create(ctx context.Context, task *entities.AsyncTask) error
|
||
GetByID(ctx context.Context, id string) (*entities.AsyncTask, error)
|
||
Update(ctx context.Context, task *entities.AsyncTask) error
|
||
Delete(ctx context.Context, id string) error
|
||
|
||
// 查询操作
|
||
ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error)
|
||
ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||
ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
|
||
ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error)
|
||
|
||
// 状态更新操作
|
||
UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error
|
||
UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
||
UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
|
||
UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error
|
||
UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error
|
||
UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error
|
||
IncrementRetryCount(ctx context.Context, id string) error
|
||
|
||
// 批量操作
|
||
UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error
|
||
DeleteBatch(ctx context.Context, ids []string) error
|
||
|
||
// 文章任务专用方法
|
||
GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error)
|
||
GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
|
||
CancelArticlePublishTask(ctx context.Context, articleID string) error
|
||
UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
|
||
}
|
||
|
||
// AsyncTaskRepositoryImpl 异步任务仓库实现
|
||
type AsyncTaskRepositoryImpl struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewAsyncTaskRepository 创建异步任务仓库
|
||
func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository {
|
||
return &AsyncTaskRepositoryImpl{
|
||
db: db,
|
||
}
|
||
}
|
||
|
||
// Create 创建任务
|
||
func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error {
|
||
return r.db.WithContext(ctx).Create(task).Error
|
||
}
|
||
|
||
// GetByID 根据ID获取任务
|
||
func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) {
|
||
var task entities.AsyncTask
|
||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&task).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &task, nil
|
||
}
|
||
|
||
// Update 更新任务
|
||
func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error {
|
||
return r.db.WithContext(ctx).Save(task).Error
|
||
}
|
||
|
||
// Delete 删除任务
|
||
func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error {
|
||
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.AsyncTask{}).Error
|
||
}
|
||
|
||
// ListByType 根据类型列出任务
|
||
func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) {
|
||
var tasks []*entities.AsyncTask
|
||
query := r.db.WithContext(ctx).Where("type = ?", taskType)
|
||
if limit > 0 {
|
||
query = query.Limit(limit)
|
||
}
|
||
err := query.Find(&tasks).Error
|
||
return tasks, err
|
||
}
|
||
|
||
// ListByStatus 根据状态列出任务
|
||
func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||
var tasks []*entities.AsyncTask
|
||
query := r.db.WithContext(ctx).Where("status = ?", status)
|
||
if limit > 0 {
|
||
query = query.Limit(limit)
|
||
}
|
||
err := query.Find(&tasks).Error
|
||
return tasks, err
|
||
}
|
||
|
||
// ListByTypeAndStatus 根据类型和状态列出任务
|
||
func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
|
||
var tasks []*entities.AsyncTask
|
||
query := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status)
|
||
if limit > 0 {
|
||
query = query.Limit(limit)
|
||
}
|
||
err := query.Find(&tasks).Error
|
||
return tasks, err
|
||
}
|
||
|
||
// ListScheduledTasks 列出已到期的调度任务
|
||
func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) {
|
||
var tasks []*entities.AsyncTask
|
||
err := r.db.WithContext(ctx).
|
||
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.TaskStatusPending, before).
|
||
Find(&tasks).Error
|
||
return tasks, err
|
||
}
|
||
|
||
// UpdateStatus 更新任务状态
|
||
func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id = ?", id).
|
||
Updates(map[string]interface{}{
|
||
"status": status,
|
||
"updated_at": time.Now(),
|
||
}).Error
|
||
}
|
||
|
||
// UpdateStatusWithError 更新任务状态并记录错误
|
||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id = ?", id).
|
||
Updates(map[string]interface{}{
|
||
"status": status,
|
||
"error_msg": errorMsg,
|
||
"updated_at": time.Now(),
|
||
}).Error
|
||
}
|
||
|
||
// UpdateStatusWithRetryAndError 更新任务状态、增加重试次数并记录错误
|
||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id = ?", id).
|
||
Updates(map[string]interface{}{
|
||
"status": status,
|
||
"error_msg": errorMsg,
|
||
"retry_count": gorm.Expr("retry_count + 1"),
|
||
"updated_at": time.Now(),
|
||
}).Error
|
||
}
|
||
|
||
// UpdateStatusWithSuccess 更新任务状态为成功,清除错误信息
|
||
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id = ?", id).
|
||
Updates(map[string]interface{}{
|
||
"status": status,
|
||
"error_msg": "", // 清除错误信息
|
||
"updated_at": time.Now(),
|
||
}).Error
|
||
}
|
||
|
||
// UpdateRetryCountAndError 更新重试次数和错误信息,保持pending状态
|
||
func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id = ?", id).
|
||
Updates(map[string]interface{}{
|
||
"retry_count": retryCount,
|
||
"error_msg": errorMsg,
|
||
"updated_at": time.Now(),
|
||
// 注意:不更新status,保持pending状态
|
||
}).Error
|
||
}
|
||
|
||
// UpdateScheduledAt 更新任务调度时间
|
||
func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id = ?", id).
|
||
Update("scheduled_at", scheduledAt).Error
|
||
}
|
||
|
||
// IncrementRetryCount 增加重试次数
|
||
func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id = ?", id).
|
||
Update("retry_count", gorm.Expr("retry_count + 1")).Error
|
||
}
|
||
|
||
// UpdateStatusBatch 批量更新状态
|
||
func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("id IN ?", ids).
|
||
Update("status", status).Error
|
||
}
|
||
|
||
// DeleteBatch 批量删除
|
||
func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error {
|
||
return r.db.WithContext(ctx).
|
||
Where("id IN ?", ids).
|
||
Delete(&entities.AsyncTask{}).Error
|
||
}
|
||
|
||
// GetArticlePublishTask 获取文章发布任务
|
||
func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
|
||
var task entities.AsyncTask
|
||
err := r.db.WithContext(ctx).
|
||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||
types.TaskTypeArticlePublish,
|
||
"%\"article_id\":\""+articleID+"\"%",
|
||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||
First(&task).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &task, nil
|
||
}
|
||
|
||
// GetByArticleID 根据文章ID获取所有相关任务
|
||
func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
|
||
var tasks []*entities.AsyncTask
|
||
err := r.db.WithContext(ctx).
|
||
Where("payload LIKE ? AND status IN ?",
|
||
"%\"article_id\":\""+articleID+"\"%",
|
||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||
Find(&tasks).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return tasks, nil
|
||
}
|
||
|
||
// CancelArticlePublishTask 取消文章发布任务
|
||
func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||
types.TaskTypeArticlePublish,
|
||
"%\"article_id\":\""+articleID+"\"%",
|
||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||
Update("status", entities.TaskStatusCancelled).Error
|
||
}
|
||
|
||
// UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间
|
||
func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
|
||
return r.db.WithContext(ctx).
|
||
Model(&entities.AsyncTask{}).
|
||
Where("type = ? AND payload LIKE ? AND status IN ?",
|
||
types.TaskTypeArticlePublish,
|
||
"%\"article_id\":\""+articleID+"\"%",
|
||
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
|
||
Update("scheduled_at", newScheduledAt).Error
|
||
} |