Files
tyapi-server/internal/infrastructure/task/repositories/async_task_repository.go
2025-09-12 01:15:09 +08:00

267 lines
9.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package repositories
import (
	"context"
	"strings"
	"time"

	"gorm.io/gorm"

	"tyapi-server/internal/infrastructure/task/entities"
	"tyapi-server/internal/infrastructure/task/types"
)
// AsyncTaskRepository is the persistence interface for asynchronous tasks.
// Every method takes a context.Context as its first argument.
type AsyncTaskRepository interface {
	// ---- Basic CRUD ----

	// Create inserts a new task.
	Create(ctx context.Context, task *entities.AsyncTask) error
	// GetByID loads a single task by its ID.
	GetByID(ctx context.Context, id string) (*entities.AsyncTask, error)
	// Update persists all fields of an existing task.
	Update(ctx context.Context, task *entities.AsyncTask) error
	// Delete removes the task with the given ID.
	Delete(ctx context.Context, id string) error

	// ---- Queries ----

	// ListByType lists tasks of one type; limit <= 0 means "no limit".
	ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error)
	// ListByStatus lists tasks in one status; limit <= 0 means "no limit".
	ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
	// ListByTypeAndStatus lists tasks matching both type and status; limit <= 0 means "no limit".
	ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
	// ListScheduledTasks lists pending tasks whose scheduled time is at or before `before`.
	ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error)

	// ---- Status / field updates ----

	// UpdateStatus sets the task status.
	UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error
	// UpdateStatusWithError sets the status and records an error message.
	UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
	// UpdateStatusWithRetryAndError sets the status, records an error and bumps the retry counter.
	UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
	// UpdateStatusWithSuccess sets the status and clears any stored error message.
	UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error
	// UpdateRetryCountAndError overwrites the retry count and error message without touching the status.
	UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error
	// UpdateScheduledAt changes when the task is scheduled to run.
	UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error
	// IncrementRetryCount adds one to the retry counter.
	IncrementRetryCount(ctx context.Context, id string) error

	// ---- Batch operations ----

	// UpdateStatusBatch sets the same status on every task in ids.
	UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error
	// DeleteBatch removes every task in ids.
	DeleteBatch(ctx context.Context, ids []string) error

	// ---- Article-task helpers ----

	// GetArticlePublishTask returns the active (pending/running) publish task for an article.
	GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error)
	// GetByArticleID returns all active (pending/running) tasks of any type for an article.
	GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
	// CancelArticlePublishTask marks the article's active publish task(s) as cancelled.
	CancelArticlePublishTask(ctx context.Context, articleID string) error
	// UpdateArticlePublishTaskSchedule moves the article's active publish task(s) to a new time.
	UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
}
// AsyncTaskRepositoryImpl is the GORM-backed implementation of
// AsyncTaskRepository.
type AsyncTaskRepositoryImpl struct {
	db *gorm.DB // database handle every query is issued through
}
// NewAsyncTaskRepository wraps db in an AsyncTaskRepositoryImpl and returns it
// through the AsyncTaskRepository interface.
func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository {
	repo := &AsyncTaskRepositoryImpl{db: db}
	return repo
}
// Create inserts task as a new row and returns any database error.
func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error {
	result := r.db.WithContext(ctx).Create(task)
	return result.Error
}
// GetByID loads the task whose primary key equals id. On any error (including
// GORM's record-not-found error from First) it returns a nil task.
func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) {
	found := new(entities.AsyncTask)
	if err := r.db.WithContext(ctx).Where("id = ?", id).First(found).Error; err != nil {
		return nil, err
	}
	return found, nil
}
// Update writes every field of task back to the database via GORM's Save.
func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error {
	tx := r.db.WithContext(ctx)
	return tx.Save(task).Error
}
// Delete removes the task whose primary key equals id.
func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error {
	scoped := r.db.WithContext(ctx).Where("id = ?", id)
	return scoped.Delete(&entities.AsyncTask{}).Error
}
// ListByType returns tasks of the given type. A positive limit caps the
// number of rows; zero or a negative value returns all matches.
func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) {
	var result []*entities.AsyncTask

	q := r.db.WithContext(ctx).Where("type = ?", taskType)
	if limit > 0 {
		q = q.Limit(limit)
	}

	if err := q.Find(&result).Error; err != nil {
		return result, err
	}
	return result, nil
}
// ListByStatus returns tasks in the given status. A positive limit caps the
// number of rows; zero or a negative value returns all matches.
func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
	var result []*entities.AsyncTask

	q := r.db.WithContext(ctx).Where("status = ?", status)
	if limit > 0 {
		q = q.Limit(limit)
	}

	if err := q.Find(&result).Error; err != nil {
		return result, err
	}
	return result, nil
}
// ListByTypeAndStatus returns tasks that match both taskType and status.
// A positive limit caps the number of rows; zero or a negative value returns
// all matches.
func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
	var result []*entities.AsyncTask

	q := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status)
	if limit > 0 {
		q = q.Limit(limit)
	}

	if err := q.Find(&result).Error; err != nil {
		return result, err
	}
	return result, nil
}
// ListScheduledTasks returns every pending task that has a scheduled time at
// or before `before`, i.e. tasks that are due to run. Tasks without a
// scheduled time are excluded.
func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) {
	const dueCondition = "status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?"

	var due []*entities.AsyncTask
	err := r.db.WithContext(ctx).
		Where(dueCondition, entities.TaskStatusPending, before).
		Find(&due).Error
	return due, err
}
// UpdateStatus sets the status of task id and stamps updated_at with the
// current time.
func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error {
	changes := map[string]interface{}{
		"status":     status,
		"updated_at": time.Now(),
	}
	target := r.db.WithContext(ctx).Model(&entities.AsyncTask{}).Where("id = ?", id)
	return target.Updates(changes).Error
}
// UpdateStatusWithError sets the status of task id, stores errorMsg in
// error_msg and stamps updated_at with the current time.
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
	changes := map[string]interface{}{
		"status":     status,
		"error_msg":  errorMsg,
		"updated_at": time.Now(),
	}
	target := r.db.WithContext(ctx).Model(&entities.AsyncTask{}).Where("id = ?", id)
	return target.Updates(changes).Error
}
// UpdateStatusWithRetryAndError sets the status of task id, stores errorMsg,
// increments retry_count by one and stamps updated_at, all in a single UPDATE.
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
	changes := map[string]interface{}{
		"status":    status,
		"error_msg": errorMsg,
		// Incremented in SQL rather than from an in-memory value, so the
		// current column value is used.
		"retry_count": gorm.Expr("retry_count + 1"),
		"updated_at":  time.Now(),
	}
	target := r.db.WithContext(ctx).Model(&entities.AsyncTask{}).Where("id = ?", id)
	return target.Updates(changes).Error
}
// UpdateStatusWithSuccess sets the status of task id, clears error_msg (sets
// it to the empty string) and stamps updated_at with the current time.
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error {
	changes := map[string]interface{}{
		"status":     status,
		"error_msg":  "", // wipe any message left by an earlier failure
		"updated_at": time.Now(),
	}
	target := r.db.WithContext(ctx).Model(&entities.AsyncTask{}).Where("id = ?", id)
	return target.Updates(changes).Error
}
// UpdateRetryCountAndError overwrites retry_count and error_msg of task id and
// stamps updated_at. The status column is deliberately not written, so a
// pending task stays pending.
func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error {
	changes := map[string]interface{}{
		"retry_count": retryCount,
		"error_msg":   errorMsg,
		"updated_at":  time.Now(),
	}
	target := r.db.WithContext(ctx).Model(&entities.AsyncTask{}).Where("id = ?", id)
	return target.Updates(changes).Error
}
// UpdateScheduledAt sets the scheduled_at column of task id to scheduledAt.
func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error {
	target := r.db.WithContext(ctx).Model(&entities.AsyncTask{}).Where("id = ?", id)
	return target.Update("scheduled_at", scheduledAt).Error
}
// IncrementRetryCount adds one to retry_count of task id. The increment is
// expressed in SQL so it is applied to the current column value.
func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error {
	bump := gorm.Expr("retry_count + 1")
	target := r.db.WithContext(ctx).Model(&entities.AsyncTask{}).Where("id = ?", id)
	return target.Update("retry_count", bump).Error
}
// UpdateStatusBatch sets status on every task whose ID is in ids and stamps
// updated_at with the current time, consistent with UpdateStatus.
//
// An empty (or nil) ids slice is a no-op: it returns nil without issuing a
// query, instead of sending an UPDATE that can match no rows.
func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error {
	if len(ids) == 0 {
		return nil
	}
	return r.db.WithContext(ctx).
		Model(&entities.AsyncTask{}).
		Where("id IN ?", ids).
		Updates(map[string]interface{}{
			"status":     status,
			"updated_at": time.Now(),
		}).Error
}
// DeleteBatch removes every task whose ID is in ids.
//
// An empty (or nil) ids slice is a no-op: it returns nil without issuing a
// DELETE, so the call never depends on how the driver renders an empty IN list.
func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error {
	if len(ids) == 0 {
		return nil
	}
	return r.db.WithContext(ctx).
		Where("id IN ?", ids).
		Delete(&entities.AsyncTask{}).Error
}
// GetArticlePublishTask 获取文章发布任务
func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
var task entities.AsyncTask
err := r.db.WithContext(ctx).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
First(&task).Error
if err != nil {
return nil, err
}
return &task, nil
}
// GetByArticleID 根据文章ID获取所有相关任务
func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
err := r.db.WithContext(ctx).
Where("payload LIKE ? AND status IN ?",
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Find(&tasks).Error
if err != nil {
return nil, err
}
return tasks, nil
}
// CancelArticlePublishTask 取消文章发布任务
func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Update("status", entities.TaskStatusCancelled).Error
}
// UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间
func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Update("scheduled_at", newScheduledAt).Error
}