This commit is contained in:
2025-09-12 01:15:09 +08:00
parent c563b2266b
commit e05ad9e223
103 changed files with 20034 additions and 1041 deletions

View File

@@ -2,6 +2,7 @@ package api
import (
"context"
"fmt"
"time"
"tyapi-server/internal/domains/api/entities"
"tyapi-server/internal/domains/api/repositories"
@@ -228,6 +229,61 @@ func (r *GormApiCallRepository) CountByUserId(ctx context.Context, userId string
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ?", userId)
}
// CountByUserIdAndDateRange 按用户ID和日期范围统计API调用次数
func (r *GormApiCallRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
}
// GetDailyStatsByUserId 获取用户每日API调用统计
func (r *GormApiCallRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
DATE(created_at) as date,
COUNT(*) as calls
FROM api_calls
WHERE user_id = $1
AND DATE(created_at) >= $2
AND DATE(created_at) <= $3
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetMonthlyStatsByUserId 获取用户每月API调用统计
func (r *GormApiCallRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COUNT(*) as calls
FROM api_calls
WHERE user_id = $1
AND created_at >= $2
AND created_at <= $3
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
func (r *GormApiCallRepository) FindByTransactionId(ctx context.Context, transactionId string) (*entities.ApiCall, error) {
var call entities.ApiCall
err := r.FindOne(ctx, &call, "transaction_id = ?", transactionId)
@@ -329,4 +385,135 @@ func (r *GormApiCallRepository) ListWithFiltersAndProductName(ctx context.Contex
}
return productNameMap, calls, total, nil
}
// GetSystemTotalCalls 获取系统总API调用次数
func (r *GormApiCallRepository) GetSystemTotalCalls(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ApiCall{}).Count(&count).Error
return count, err
}
// GetSystemCallsByDateRange 获取系统指定时间范围内的API调用次数
func (r *GormApiCallRepository) GetSystemCallsByDateRange(ctx context.Context, startDate, endDate time.Time) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ApiCall{}).
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
Count(&count).Error
return count, err
}
// GetSystemDailyStats 获取系统每日API调用统计
func (r *GormApiCallRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COUNT(*) as calls
FROM api_calls
WHERE DATE(created_at) >= $1
AND DATE(created_at) <= $2
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyStats 获取系统每月API调用统计
func (r *GormApiCallRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COUNT(*) as calls
FROM api_calls
WHERE created_at >= $1
AND created_at <= $2
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetApiPopularityRanking 获取API受欢迎程度排行榜
func (r *GormApiCallRepository) GetApiPopularityRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
p.id as product_id,
p.name as api_name,
p.description as api_description,
COUNT(ac.id) as call_count
FROM product p
LEFT JOIN api_calls ac ON p.id = ac.product_id
AND DATE(ac.created_at) = CURRENT_DATE
WHERE p.deleted_at IS NULL
GROUP BY p.id, p.name, p.description
HAVING COUNT(ac.id) > 0
ORDER BY call_count DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
p.id as product_id,
p.name as api_name,
p.description as api_description,
COUNT(ac.id) as call_count
FROM product p
LEFT JOIN api_calls ac ON p.id = ac.product_id
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE p.deleted_at IS NULL
GROUP BY p.id, p.name, p.description
HAVING COUNT(ac.id) > 0
ORDER BY call_count DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
p.id as product_id,
p.name as api_name,
p.description as api_description,
COUNT(ac.id) as call_count
FROM product p
LEFT JOIN api_calls ac ON p.id = ac.product_id
WHERE p.deleted_at IS NULL
GROUP BY p.id, p.name, p.description
HAVING COUNT(ac.id) > 0
ORDER BY call_count DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@@ -3,6 +3,8 @@ package repositories
import (
"context"
"errors"
"strings"
"time"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/shared/database"
@@ -110,7 +112,14 @@ func (r *GormRechargeRecordRepository) List(ctx context.Context, options interfa
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
// 特殊处理 user_ids 过滤器
if key == "user_ids" {
if userIds, ok := value.(string); ok && userIds != "" {
query = query.Where("user_id IN ?", strings.Split(userIds, ","))
}
} else {
query = query.Where(key+" = ?", value)
}
}
}
@@ -175,4 +184,144 @@ func (r *GormRechargeRecordRepository) SoftDelete(ctx context.Context, id string
func (r *GormRechargeRecordRepository) Restore(ctx context.Context, id string) error {
return r.RestoreEntity(ctx, id, &entities.RechargeRecord{})
}
}
// GetTotalAmountByUserId 获取用户总充值金额
func (r *GormRechargeRecordRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ? AND status = ?", userId, entities.RechargeStatusSuccess).
Scan(&total).Error
return total, err
}
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总充值金额
func (r *GormRechargeRecordRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ? AND status = ? AND created_at >= ? AND created_at < ?", userId, entities.RechargeStatusSuccess, startDate, endDate).
Scan(&total).Error
return total, err
}
// GetDailyStatsByUserId 获取用户每日充值统计
func (r *GormRechargeRecordRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE user_id = $1
AND status = $2
AND DATE(created_at) >= $3
AND DATE(created_at) <= $4
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetMonthlyStatsByUserId 获取用户每月充值统计
func (r *GormRechargeRecordRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE user_id = $1
AND status = $2
AND created_at >= $3
AND created_at <= $4
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemTotalAmount 获取系统总充值金额
func (r *GormRechargeRecordRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Where("status = ?", entities.RechargeStatusSuccess).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemAmountByDateRange 获取系统指定时间范围内的充值金额
func (r *GormRechargeRecordRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Where("status = ? AND created_at >= ? AND created_at <= ?", entities.RechargeStatusSuccess, startDate, endDate).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemDailyStats 获取系统每日充值统计
func (r *GormRechargeRecordRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE status = $1
AND DATE(created_at) >= $2
AND DATE(created_at) <= $3
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyStats 获取系统每月充值统计
func (r *GormRechargeRecordRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE status = $1
AND created_at >= $2
AND created_at <= $3
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"time"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/shared/database"
@@ -184,31 +185,102 @@ func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (
return &wallet, nil
}
// UpdateBalanceWithVersionRetry 乐观锁自动重试最大重试maxRetry次
func (r *GormWalletRepository) UpdateBalanceWithVersion(ctx context.Context, walletID string, newBalance string, oldVersion int64) (bool, error) {
// UpdateBalanceWithVersion 乐观锁自动重试最大重试maxRetry次
func (r *GormWalletRepository) UpdateBalanceWithVersion(ctx context.Context, walletID string, amount decimal.Decimal, operation string) (bool, error) {
maxRetry := 10
for i := 0; i < maxRetry; i++ {
result := r.GetDB(ctx).Model(&entities.Wallet{}).
Where("id = ? AND version = ?", walletID, oldVersion).
Updates(map[string]interface{}{
"balance": newBalance,
"version": oldVersion + 1,
})
if result.Error != nil {
return false, result.Error
}
if result.RowsAffected == 1 {
return true, nil
}
// 并发冲突重试前重新查version
// 每次重试都重新获取最新的钱包信息
var wallet entities.Wallet
err := r.GetDB(ctx).Where("id = ?", walletID).First(&wallet).Error
if err != nil {
return false, err
return false, fmt.Errorf("获取钱包信息失败: %w", err)
}
oldVersion = wallet.Version
// 重新计算新余额
var newBalance decimal.Decimal
switch operation {
case "add":
newBalance = wallet.Balance.Add(amount)
case "subtract":
newBalance = wallet.Balance.Sub(amount)
default:
return false, fmt.Errorf("不支持的操作类型: %s", operation)
}
// 乐观锁更新
result := r.GetDB(ctx).Model(&entities.Wallet{}).
Where("id = ? AND version = ?", walletID, wallet.Version).
Updates(map[string]interface{}{
"balance": newBalance.String(),
"version": wallet.Version + 1,
})
if result.Error != nil {
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
}
if result.RowsAffected == 1 {
return true, nil
}
// 乐观锁冲突,继续重试
// 注意这里可以添加日志记录但需要确保logger可用
}
return false, fmt.Errorf("高并发下余额变动失败,请重试")
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
}
// UpdateBalanceByUserID 乐观锁更新通过用户ID直接更新使用原生SQL
func (r *GormWalletRepository) UpdateBalanceByUserID(ctx context.Context, userID string, amount decimal.Decimal, operation string) (bool, error) {
maxRetry := 20 // 增加重试次数
baseDelay := 1 // 基础延迟毫秒
for i := 0; i < maxRetry; i++ {
// 每次重试都重新获取最新的钱包信息
var wallet entities.Wallet
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
return false, fmt.Errorf("获取钱包信息失败: %w", err)
}
// 重新计算新余额
var newBalance decimal.Decimal
switch operation {
case "add":
newBalance = wallet.Balance.Add(amount)
case "subtract":
newBalance = wallet.Balance.Sub(amount)
default:
return false, fmt.Errorf("不支持的操作类型: %s", operation)
}
// 使用原生SQL进行乐观锁更新
newVersion := wallet.Version + 1
result := r.GetDB(ctx).Exec(`
UPDATE wallets
SET balance = ?, version = ?, updated_at = NOW()
WHERE user_id = ? AND version = ?
`, newBalance.String(), newVersion, userID, wallet.Version)
if result.Error != nil {
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
}
if result.RowsAffected == 1 {
return true, nil
}
// 乐观锁冲突,添加指数退避延迟
if i < maxRetry-1 {
delay := baseDelay * (1 << i) // 指数退避: 1ms, 2ms, 4ms, 8ms...
if delay > 50 {
delay = 50 // 最大延迟50ms
}
time.Sleep(time.Duration(delay) * time.Millisecond)
}
}
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
}
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {

View File

@@ -2,6 +2,7 @@ package repositories
import (
"context"
"strings"
"time"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
@@ -150,6 +151,81 @@ func (r *GormWalletTransactionRepository) CountByUserId(ctx context.Context, use
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ?", userId)
}
// CountByUserIdAndDateRange 按用户ID和日期范围统计钱包交易次数
func (r *GormWalletTransactionRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
}
// GetTotalAmountByUserId 获取用户总消费金额
func (r *GormWalletTransactionRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ?", userId).
Scan(&total).Error
return total, err
}
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总消费金额
func (r *GormWalletTransactionRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate).
Scan(&total).Error
return total, err
}
// GetDailyStatsByUserId 获取用户每日消费统计
func (r *GormWalletTransactionRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE user_id = $1
AND DATE(created_at) >= $2
AND DATE(created_at) <= $3
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetMonthlyStatsByUserId 获取用户每月消费统计
func (r *GormWalletTransactionRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE user_id = $1
AND created_at >= $2
AND created_at <= $3
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// 实现interfaces.Repository接口的其他方法
func (r *GormWalletTransactionRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.WalletTransaction{})
@@ -391,4 +467,153 @@ func (r *GormWalletTransactionRepository) ListWithFiltersAndProductName(ctx cont
}
return productNameMap, transactions, total, nil
}
}
// ExportWithFiltersAndProductName 导出钱包交易记录(包含产品名称和企业信息)
func (r *GormWalletTransactionRepository) ExportWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}) ([]*entities.WalletTransaction, error) {
var transactionsWithProduct []WalletTransactionWithProduct
// 构建查询
query := r.GetDB(ctx).Table("wallet_transactions wt").
Select("wt.*, p.name as product_name").
Joins("LEFT JOIN product p ON wt.product_id = p.id")
// 构建WHERE条件
var whereConditions []string
var whereArgs []interface{}
// 用户ID筛选
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
whereConditions = append(whereConditions, "wt.user_id IN (?)")
whereArgs = append(whereArgs, strings.Split(userIds, ","))
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
whereConditions = append(whereConditions, "wt.user_id = ?")
whereArgs = append(whereArgs, userId)
}
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereConditions = append(whereConditions, "wt.created_at >= ?")
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereConditions = append(whereConditions, "wt.created_at <= ?")
whereArgs = append(whereArgs, endTime)
}
// 交易ID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereConditions = append(whereConditions, "wt.transaction_id LIKE ?")
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereConditions = append(whereConditions, "p.name LIKE ?")
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 产品ID列表筛选
if productIds, ok := filters["product_ids"].(string); ok && productIds != "" {
whereConditions = append(whereConditions, "wt.product_id IN (?)")
whereArgs = append(whereArgs, strings.Split(productIds, ","))
}
// 金额范围筛选
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
whereConditions = append(whereConditions, "wt.amount >= ?")
whereArgs = append(whereArgs, minAmount)
}
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
whereConditions = append(whereConditions, "wt.amount <= ?")
whereArgs = append(whereArgs, maxAmount)
}
// 应用WHERE条件
if len(whereConditions) > 0 {
query = query.Where(strings.Join(whereConditions, " AND "), whereArgs...)
}
// 排序
query = query.Order("wt.created_at DESC")
// 执行查询
err := query.Find(&transactionsWithProduct).Error
if err != nil {
return nil, err
}
// 转换为entities.WalletTransaction
var transactions []*entities.WalletTransaction
for _, t := range transactionsWithProduct {
transaction := t.WalletTransaction
transactions = append(transactions, &transaction)
}
return transactions, nil
}
// GetSystemTotalAmount 获取系统总消费金额
func (r *GormWalletTransactionRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemAmountByDateRange 获取系统指定时间范围内的消费金额
func (r *GormWalletTransactionRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemDailyStats 获取系统每日消费统计
func (r *GormWalletTransactionRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE DATE(created_at) >= $1
AND DATE(created_at) <= $2
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyStats 获取系统每月消费统计
func (r *GormWalletTransactionRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE created_at >= $1
AND created_at <= $2
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@@ -0,0 +1,461 @@
package statistics
import (
"context"
"fmt"
"gorm.io/gorm"
"tyapi-server/internal/domains/statistics/entities"
"tyapi-server/internal/domains/statistics/repositories"
)
// GormStatisticsDashboardRepository GORM统计仪表板仓储实现
type GormStatisticsDashboardRepository struct {
db *gorm.DB
}
// NewGormStatisticsDashboardRepository 创建GORM统计仪表板仓储
func NewGormStatisticsDashboardRepository(db *gorm.DB) repositories.StatisticsDashboardRepository {
return &GormStatisticsDashboardRepository{
db: db,
}
}
// Save 保存统计仪表板
func (r *GormStatisticsDashboardRepository) Save(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
if dashboard == nil {
return fmt.Errorf("统计仪表板不能为空")
}
// 验证仪表板
if err := dashboard.Validate(); err != nil {
return fmt.Errorf("统计仪表板验证失败: %w", err)
}
// 保存到数据库
result := r.db.WithContext(ctx).Save(dashboard)
if result.Error != nil {
return fmt.Errorf("保存统计仪表板失败: %w", result.Error)
}
return nil
}
// FindByID 根据ID查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsDashboard, error) {
if id == "" {
return nil, fmt.Errorf("仪表板ID不能为空")
}
var dashboard entities.StatisticsDashboard
result := r.db.WithContext(ctx).Where("id = ?", id).First(&dashboard)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("统计仪表板不存在")
}
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return &dashboard, nil
}
// FindByUser 根据用户查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userID == "" {
return nil, fmt.Errorf("用户ID不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("created_by = ?", userID)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindByUserRole 根据用户角色查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByUserRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// Update 更新统计仪表板
func (r *GormStatisticsDashboardRepository) Update(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
if dashboard == nil {
return fmt.Errorf("统计仪表板不能为空")
}
if dashboard.ID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
// 验证仪表板
if err := dashboard.Validate(); err != nil {
return fmt.Errorf("统计仪表板验证失败: %w", err)
}
// 更新数据库
result := r.db.WithContext(ctx).Save(dashboard)
if result.Error != nil {
return fmt.Errorf("更新统计仪表板失败: %w", result.Error)
}
return nil
}
// Delete 删除统计仪表板
func (r *GormStatisticsDashboardRepository) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("仪表板ID不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("删除统计仪表板失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("统计仪表板不存在")
}
return nil
}
// FindByRole 根据角色查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindDefaultByRole 根据角色查找默认统计仪表板
func (r *GormStatisticsDashboardRepository) FindDefaultByRole(ctx context.Context, userRole string) (*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboard entities.StatisticsDashboard
result := r.db.WithContext(ctx).
Where("user_role = ? AND is_default = ? AND is_active = ?", userRole, true, true).
First(&dashboard)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("默认统计仪表板不存在")
}
return nil, fmt.Errorf("查询默认统计仪表板失败: %w", result.Error)
}
return &dashboard, nil
}
// FindActiveByRole 根据角色查找激活的统计仪表板
func (r *GormStatisticsDashboardRepository) FindActiveByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).
Where("user_role = ? AND is_active = ?", userRole, true)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询激活统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindByStatus 根据状态查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByStatus(ctx context.Context, isActive bool, limit, offset int) ([]*entities.StatisticsDashboard, error) {
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("is_active = ?", isActive)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindByAccessLevel 根据访问级别查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByAccessLevel(ctx context.Context, accessLevel string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if accessLevel == "" {
return nil, fmt.Errorf("访问级别不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("access_level = ?", accessLevel)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// CountByUser 根据用户统计数量
func (r *GormStatisticsDashboardRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
if userID == "" {
return 0, fmt.Errorf("用户ID不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("created_by = ?", userID).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
}
return count, nil
}
// CountByRole 根据角色统计数量
func (r *GormStatisticsDashboardRepository) CountByRole(ctx context.Context, userRole string) (int64, error) {
if userRole == "" {
return 0, fmt.Errorf("用户角色不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("user_role = ?", userRole).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
}
return count, nil
}
// CountByStatus 根据状态统计数量
func (r *GormStatisticsDashboardRepository) CountByStatus(ctx context.Context, isActive bool) (int64, error) {
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("is_active = ?", isActive).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
}
return count, nil
}
// BatchSave 批量保存统计仪表板
func (r *GormStatisticsDashboardRepository) BatchSave(ctx context.Context, dashboards []*entities.StatisticsDashboard) error {
if len(dashboards) == 0 {
return fmt.Errorf("统计仪表板列表不能为空")
}
// 验证所有仪表板
for _, dashboard := range dashboards {
if err := dashboard.Validate(); err != nil {
return fmt.Errorf("统计仪表板验证失败: %w", err)
}
}
// 批量保存
result := r.db.WithContext(ctx).CreateInBatches(dashboards, 100)
if result.Error != nil {
return fmt.Errorf("批量保存统计仪表板失败: %w", result.Error)
}
return nil
}
// BatchDelete 批量删除统计仪表板
func (r *GormStatisticsDashboardRepository) BatchDelete(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return fmt.Errorf("仪表板ID列表不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id IN ?", ids)
if result.Error != nil {
return fmt.Errorf("批量删除统计仪表板失败: %w", result.Error)
}
return nil
}
// SetDefaultDashboard 设置默认仪表板
func (r *GormStatisticsDashboardRepository) SetDefaultDashboard(ctx context.Context, dashboardID string) error {
if dashboardID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
// 开始事务
tx := r.db.WithContext(ctx).Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 先取消同角色的所有默认状态
var dashboard entities.StatisticsDashboard
if err := tx.Where("id = ?", dashboardID).First(&dashboard).Error; err != nil {
tx.Rollback()
return fmt.Errorf("查询仪表板失败: %w", err)
}
// 取消同角色的所有默认状态
if err := tx.Model(&entities.StatisticsDashboard{}).
Where("user_role = ? AND is_default = ?", dashboard.UserRole, true).
Update("is_default", false).Error; err != nil {
tx.Rollback()
return fmt.Errorf("取消默认状态失败: %w", err)
}
// 设置新的默认状态
if err := tx.Model(&entities.StatisticsDashboard{}).
Where("id = ?", dashboardID).
Update("is_default", true).Error; err != nil {
tx.Rollback()
return fmt.Errorf("设置默认状态失败: %w", err)
}
// 提交事务
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("提交事务失败: %w", err)
}
return nil
}
// RemoveDefaultDashboard 移除默认仪表板
func (r *GormStatisticsDashboardRepository) RemoveDefaultDashboard(ctx context.Context, userRole string) error {
if userRole == "" {
return fmt.Errorf("用户角色不能为空")
}
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("user_role = ? AND is_default = ?", userRole, true).
Update("is_default", false)
if result.Error != nil {
return fmt.Errorf("移除默认仪表板失败: %w", result.Error)
}
return nil
}
// ActivateDashboard 激活仪表板
func (r *GormStatisticsDashboardRepository) ActivateDashboard(ctx context.Context, dashboardID string) error {
if dashboardID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("id = ?", dashboardID).
Update("is_active", true)
if result.Error != nil {
return fmt.Errorf("激活仪表板失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("仪表板不存在")
}
return nil
}
// DeactivateDashboard 停用仪表板
func (r *GormStatisticsDashboardRepository) DeactivateDashboard(ctx context.Context, dashboardID string) error {
if dashboardID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("id = ?", dashboardID).
Update("is_active", false)
if result.Error != nil {
return fmt.Errorf("停用仪表板失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("仪表板不存在")
}
return nil
}

View File

@@ -0,0 +1,377 @@
package statistics
import (
"context"
"fmt"
"time"
"gorm.io/gorm"
"tyapi-server/internal/domains/statistics/entities"
"tyapi-server/internal/domains/statistics/repositories"
)
// GormStatisticsReportRepository GORM统计报告仓储实现
type GormStatisticsReportRepository struct {
db *gorm.DB
}
// NewGormStatisticsReportRepository 创建GORM统计报告仓储
func NewGormStatisticsReportRepository(db *gorm.DB) repositories.StatisticsReportRepository {
return &GormStatisticsReportRepository{
db: db,
}
}
// Save 保存统计报告
func (r *GormStatisticsReportRepository) Save(ctx context.Context, report *entities.StatisticsReport) error {
if report == nil {
return fmt.Errorf("统计报告不能为空")
}
// 验证报告
if err := report.Validate(); err != nil {
return fmt.Errorf("统计报告验证失败: %w", err)
}
// 保存到数据库
result := r.db.WithContext(ctx).Save(report)
if result.Error != nil {
return fmt.Errorf("保存统计报告失败: %w", result.Error)
}
return nil
}
// FindByID 根据ID查找统计报告
func (r *GormStatisticsReportRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsReport, error) {
if id == "" {
return nil, fmt.Errorf("报告ID不能为空")
}
var report entities.StatisticsReport
result := r.db.WithContext(ctx).Where("id = ?", id).First(&report)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("统计报告不存在")
}
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return &report, nil
}
// FindByUser 根据用户查找统计报告
func (r *GormStatisticsReportRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsReport, error) {
if userID == "" {
return nil, fmt.Errorf("用户ID不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).Where("generated_by = ?", userID)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByStatus 根据状态查找统计报告
func (r *GormStatisticsReportRepository) FindByStatus(ctx context.Context, status string) ([]*entities.StatisticsReport, error) {
if status == "" {
return nil, fmt.Errorf("报告状态不能为空")
}
var reports []*entities.StatisticsReport
result := r.db.WithContext(ctx).
Where("status = ?", status).
Order("created_at DESC").
Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// Update 更新统计报告
func (r *GormStatisticsReportRepository) Update(ctx context.Context, report *entities.StatisticsReport) error {
if report == nil {
return fmt.Errorf("统计报告不能为空")
}
if report.ID == "" {
return fmt.Errorf("报告ID不能为空")
}
// 验证报告
if err := report.Validate(); err != nil {
return fmt.Errorf("统计报告验证失败: %w", err)
}
// 更新数据库
result := r.db.WithContext(ctx).Save(report)
if result.Error != nil {
return fmt.Errorf("更新统计报告失败: %w", result.Error)
}
return nil
}
// Delete 删除统计报告
func (r *GormStatisticsReportRepository) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("报告ID不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("删除统计报告失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("统计报告不存在")
}
return nil
}
// FindByType 根据类型查找统计报告
func (r *GormStatisticsReportRepository) FindByType(ctx context.Context, reportType string, limit, offset int) ([]*entities.StatisticsReport, error) {
if reportType == "" {
return nil, fmt.Errorf("报告类型不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).Where("report_type = ?", reportType)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByTypeAndPeriod 根据类型和周期查找统计报告
func (r *GormStatisticsReportRepository) FindByTypeAndPeriod(ctx context.Context, reportType, period string, limit, offset int) ([]*entities.StatisticsReport, error) {
if reportType == "" {
return nil, fmt.Errorf("报告类型不能为空")
}
if period == "" {
return nil, fmt.Errorf("统计周期不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).
Where("report_type = ? AND period = ?", reportType, period)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByDateRange 根据日期范围查找统计报告
func (r *GormStatisticsReportRepository) FindByDateRange(ctx context.Context, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).
Where("created_at >= ? AND created_at < ?", startDate, endDate)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByUserAndDateRange 根据用户和日期范围查找统计报告
func (r *GormStatisticsReportRepository) FindByUserAndDateRange(ctx context.Context, userID string, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
if userID == "" {
return nil, fmt.Errorf("用户ID不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).
Where("generated_by = ? AND created_at >= ? AND created_at < ?", userID, startDate, endDate)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// CountByUser 根据用户统计数量
func (r *GormStatisticsReportRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
if userID == "" {
return 0, fmt.Errorf("用户ID不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsReport{}).
Where("generated_by = ?", userID).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
}
return count, nil
}
// CountByType 根据类型统计数量
func (r *GormStatisticsReportRepository) CountByType(ctx context.Context, reportType string) (int64, error) {
if reportType == "" {
return 0, fmt.Errorf("报告类型不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsReport{}).
Where("report_type = ?", reportType).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
}
return count, nil
}
// CountByStatus 根据状态统计数量
func (r *GormStatisticsReportRepository) CountByStatus(ctx context.Context, status string) (int64, error) {
if status == "" {
return 0, fmt.Errorf("报告状态不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsReport{}).
Where("status = ?", status).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
}
return count, nil
}
// BatchSave 批量保存统计报告
func (r *GormStatisticsReportRepository) BatchSave(ctx context.Context, reports []*entities.StatisticsReport) error {
if len(reports) == 0 {
return fmt.Errorf("统计报告列表不能为空")
}
// 验证所有报告
for _, report := range reports {
if err := report.Validate(); err != nil {
return fmt.Errorf("统计报告验证失败: %w", err)
}
}
// 批量保存
result := r.db.WithContext(ctx).CreateInBatches(reports, 100)
if result.Error != nil {
return fmt.Errorf("批量保存统计报告失败: %w", result.Error)
}
return nil
}
// BatchDelete 批量删除统计报告
func (r *GormStatisticsReportRepository) BatchDelete(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return fmt.Errorf("报告ID列表不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id IN ?", ids)
if result.Error != nil {
return fmt.Errorf("批量删除统计报告失败: %w", result.Error)
}
return nil
}
// DeleteExpiredReports 删除过期报告
func (r *GormStatisticsReportRepository) DeleteExpiredReports(ctx context.Context, expiredBefore time.Time) error {
if expiredBefore.IsZero() {
return fmt.Errorf("过期时间不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsReport{}, "expires_at IS NOT NULL AND expires_at < ?", expiredBefore)
if result.Error != nil {
return fmt.Errorf("删除过期报告失败: %w", result.Error)
}
return nil
}
// DeleteByStatus 根据状态删除统计报告
func (r *GormStatisticsReportRepository) DeleteByStatus(ctx context.Context, status string) error {
if status == "" {
return fmt.Errorf("报告状态不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsReport{}, "status = ?", status)
if result.Error != nil {
return fmt.Errorf("根据状态删除统计报告失败: %w", result.Error)
}
return nil
}

View File

@@ -0,0 +1,381 @@
package statistics
import (
"context"
"fmt"
"time"
"gorm.io/gorm"
"tyapi-server/internal/domains/statistics/entities"
"tyapi-server/internal/domains/statistics/repositories"
)
// GormStatisticsRepository GORM统计指标仓储实现
type GormStatisticsRepository struct {
db *gorm.DB
}
// NewGormStatisticsRepository 创建GORM统计指标仓储
func NewGormStatisticsRepository(db *gorm.DB) repositories.StatisticsRepository {
return &GormStatisticsRepository{
db: db,
}
}
// Save 保存统计指标
func (r *GormStatisticsRepository) Save(ctx context.Context, metric *entities.StatisticsMetric) error {
if metric == nil {
return fmt.Errorf("统计指标不能为空")
}
// 验证指标
if err := metric.Validate(); err != nil {
return fmt.Errorf("统计指标验证失败: %w", err)
}
// 保存到数据库
result := r.db.WithContext(ctx).Create(metric)
if result.Error != nil {
return fmt.Errorf("保存统计指标失败: %w", result.Error)
}
return nil
}
// FindByID 根据ID查找统计指标
func (r *GormStatisticsRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsMetric, error) {
if id == "" {
return nil, fmt.Errorf("指标ID不能为空")
}
var metric entities.StatisticsMetric
result := r.db.WithContext(ctx).Where("id = ?", id).First(&metric)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("统计指标不存在")
}
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return &metric, nil
}
// FindByType 根据类型查找统计指标
func (r *GormStatisticsRepository) FindByType(ctx context.Context, metricType string, limit, offset int) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
var metrics []*entities.StatisticsMetric
query := r.db.WithContext(ctx).Where("metric_type = ?", metricType)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// Update 更新统计指标
func (r *GormStatisticsRepository) Update(ctx context.Context, metric *entities.StatisticsMetric) error {
if metric == nil {
return fmt.Errorf("统计指标不能为空")
}
if metric.ID == "" {
return fmt.Errorf("指标ID不能为空")
}
// 验证指标
if err := metric.Validate(); err != nil {
return fmt.Errorf("统计指标验证失败: %w", err)
}
// 更新数据库
result := r.db.WithContext(ctx).Save(metric)
if result.Error != nil {
return fmt.Errorf("更新统计指标失败: %w", result.Error)
}
return nil
}
// Delete 删除统计指标
func (r *GormStatisticsRepository) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("指标ID不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("删除统计指标失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("统计指标不存在")
}
return nil
}
// FindByTypeAndDateRange 根据类型和日期范围查找统计指标
func (r *GormStatisticsRepository) FindByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
result := r.db.WithContext(ctx).
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
Order("date ASC").
Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// FindByTypeDimensionAndDateRange 根据类型、维度和日期范围查找统计指标
func (r *GormStatisticsRepository) FindByTypeDimensionAndDateRange(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
query := r.db.WithContext(ctx).
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate)
if dimension != "" {
query = query.Where("dimension = ?", dimension)
}
result := query.Order("date ASC").Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// FindByTypeNameAndDateRange 根据类型、名称和日期范围查找统计指标
func (r *GormStatisticsRepository) FindByTypeNameAndDateRange(ctx context.Context, metricType, metricName string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if metricName == "" {
return nil, fmt.Errorf("指标名称不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
result := r.db.WithContext(ctx).
Where("metric_type = ? AND metric_name = ? AND date >= ? AND date < ?",
metricType, metricName, startDate, endDate).
Order("date ASC").
Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// GetAggregatedMetrics 获取聚合指标
func (r *GormStatisticsRepository) GetAggregatedMetrics(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) (map[string]float64, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
type AggregatedResult struct {
MetricName string `json:"metric_name"`
TotalValue float64 `json:"total_value"`
}
var results []AggregatedResult
query := r.db.WithContext(ctx).
Model(&entities.StatisticsMetric{}).
Select("metric_name, SUM(value) as total_value").
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
Group("metric_name")
if dimension != "" {
query = query.Where("dimension = ?", dimension)
}
result := query.Find(&results)
if result.Error != nil {
return nil, fmt.Errorf("查询聚合指标失败: %w", result.Error)
}
// 转换为map
aggregated := make(map[string]float64)
for _, res := range results {
aggregated[res.MetricName] = res.TotalValue
}
return aggregated, nil
}
// GetMetricsByDimension 根据维度获取指标
func (r *GormStatisticsRepository) GetMetricsByDimension(ctx context.Context, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if dimension == "" {
return nil, fmt.Errorf("统计维度不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
result := r.db.WithContext(ctx).
Where("dimension = ? AND date >= ? AND date < ?", dimension, startDate, endDate).
Order("date ASC").
Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// CountByType 根据类型统计数量
func (r *GormStatisticsRepository) CountByType(ctx context.Context, metricType string) (int64, error) {
if metricType == "" {
return 0, fmt.Errorf("指标类型不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsMetric{}).
Where("metric_type = ?", metricType).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
}
return count, nil
}
// CountByTypeAndDateRange 根据类型和日期范围统计数量
func (r *GormStatisticsRepository) CountByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) (int64, error) {
if metricType == "" {
return 0, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return 0, fmt.Errorf("开始日期和结束日期不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsMetric{}).
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
}
return count, nil
}
// BatchSave 批量保存统计指标
func (r *GormStatisticsRepository) BatchSave(ctx context.Context, metrics []*entities.StatisticsMetric) error {
if len(metrics) == 0 {
return fmt.Errorf("统计指标列表不能为空")
}
// 验证所有指标
for _, metric := range metrics {
if err := metric.Validate(); err != nil {
return fmt.Errorf("统计指标验证失败: %w", err)
}
}
// 批量保存
result := r.db.WithContext(ctx).CreateInBatches(metrics, 100)
if result.Error != nil {
return fmt.Errorf("批量保存统计指标失败: %w", result.Error)
}
return nil
}
// BatchDelete 批量删除统计指标
func (r *GormStatisticsRepository) BatchDelete(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return fmt.Errorf("指标ID列表不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id IN ?", ids)
if result.Error != nil {
return fmt.Errorf("批量删除统计指标失败: %w", result.Error)
}
return nil
}
// DeleteByDateRange 根据日期范围删除统计指标
func (r *GormStatisticsRepository) DeleteByDateRange(ctx context.Context, startDate, endDate time.Time) error {
if startDate.IsZero() || endDate.IsZero() {
return fmt.Errorf("开始日期和结束日期不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsMetric{}, "date >= ? AND date < ?", startDate, endDate)
if result.Error != nil {
return fmt.Errorf("根据日期范围删除统计指标失败: %w", result.Error)
}
return nil
}
// DeleteByTypeAndDateRange 根据类型和日期范围删除统计指标
func (r *GormStatisticsRepository) DeleteByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) error {
if metricType == "" {
return fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return fmt.Errorf("开始日期和结束日期不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsMetric{}, "metric_type = ? AND date >= ? AND date < ?",
metricType, startDate, endDate)
if result.Error != nil {
return fmt.Errorf("根据类型和日期范围删除统计指标失败: %w", result.Error)
}
return nil
}

View File

@@ -6,6 +6,7 @@ package repositories
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
@@ -71,6 +72,20 @@ func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id s
return user, nil
}
func (r *GormUserRepository) BatchGetByIDsWithEnterpriseInfo(ctx context.Context, ids []string) ([]*entities.User, error) {
if len(ids) == 0 {
return []*entities.User{}, nil
}
var users []*entities.User
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id IN ?", ids).Find(&users).Error; err != nil {
r.GetLogger().Error("批量查询用户失败", zap.Error(err), zap.Strings("ids", ids))
return nil, err
}
return users, nil
}
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.User{}).
@@ -337,3 +352,315 @@ func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate,
return &stats, nil
}
// GetSystemUserStats 获取系统用户统计信息
func (r *GormUserRepository) GetSystemUserStats(ctx context.Context) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.GetDB(ctx)
// 总用户数
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
return nil, err
}
// 活跃用户数最近30天有登录
thirtyDaysAgo := time.Now().AddDate(0, 0, -30)
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", thirtyDaysAgo).Count(&stats.ActiveUsers).Error; err != nil {
return nil, err
}
// 已认证用户数
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 今日登录数
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetSystemUserStatsByDateRange 获取系统指定时间范围内的用户统计信息
func (r *GormUserRepository) GetSystemUserStatsByDateRange(ctx context.Context, startDate, endDate time.Time) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.GetDB(ctx)
// 指定时间范围内的注册数
if err := db.Model(&entities.User{}).
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 指定时间范围内的登录数
if err := db.Model(&entities.User{}).
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetSystemDailyUserStats 获取系统每日用户统计
func (r *GormUserRepository) GetSystemDailyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COUNT(*) as count
FROM users
WHERE DATE(created_at) >= $1
AND DATE(created_at) <= $2
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyUserStats 获取系统每月用户统计
func (r *GormUserRepository) GetSystemMonthlyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COUNT(*) as count
FROM users
WHERE created_at >= $1
AND created_at <= $2
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetUserCallRankingByCalls 按调用次数获取用户排行
func (r *GormUserRepository) GetUserCallRankingByCalls(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COUNT(ac.id) as calls
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN api_calls ac ON u.id = ac.user_id
AND DATE(ac.created_at) = CURRENT_DATE
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COUNT(ac.id) > 0
ORDER BY calls DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COUNT(ac.id) as calls
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN api_calls ac ON u.id = ac.user_id
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COUNT(ac.id) > 0
ORDER BY calls DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COUNT(ac.id) as calls
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN api_calls ac ON u.id = ac.user_id
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COUNT(ac.id) > 0
ORDER BY calls DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetUserCallRankingByConsumption 按消费金额获取用户排行
func (r *GormUserRepository) GetUserCallRankingByConsumption(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(wt.amount), 0) as consumption
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
AND DATE(wt.created_at) = CURRENT_DATE
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(wt.amount), 0) > 0
ORDER BY consumption DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(wt.amount), 0) as consumption
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
AND DATE_TRUNC('month', wt.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(wt.amount), 0) > 0
ORDER BY consumption DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(wt.amount), 0) as consumption
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(wt.amount), 0) > 0
ORDER BY consumption DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetRechargeRanking 获取充值排行
func (r *GormUserRepository) GetRechargeRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(rr.amount), 0) as amount
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN recharge_records rr ON u.id = rr.user_id
AND DATE(rr.created_at) = CURRENT_DATE
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(rr.amount), 0) > 0
ORDER BY amount DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(rr.amount), 0) as amount
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN recharge_records rr ON u.id = rr.user_id
AND DATE_TRUNC('month', rr.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(rr.amount), 0) > 0
ORDER BY amount DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(rr.amount), 0) as amount
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN recharge_records rr ON u.id = rr.user_id
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(rr.amount), 0) > 0
ORDER BY amount DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@@ -302,6 +302,12 @@ func (s *BaiduOCRService) parseBusinessLicenseResult(result map[string]interface
registeredCapital = registeredCapitalObj["words"].(string)
}
// 提取企业地址
address := ""
if addressObj, ok := wordsResult["地址"].(map[string]interface{}); ok {
address = addressObj["words"].(string)
}
// 计算置信度这里简化处理实际应该从OCR结果中获取
confidence := 0.9 // 默认置信度
@@ -309,8 +315,11 @@ func (s *BaiduOCRService) parseBusinessLicenseResult(result map[string]interface
CompanyName: companyName,
UnifiedSocialCode: unifiedSocialCode,
LegalPersonName: legalPersonName,
LegalPersonID: "", // 营业执照上没有法人身份证号
RegisteredCapital: registeredCapital,
Address: address,
Confidence: confidence,
ProcessedAt: time.Now(),
}
}

View File

@@ -5,6 +5,7 @@ import (
"crypto/rand"
"fmt"
"math/big"
"time"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
"go.uber.org/zap"
@@ -64,6 +65,73 @@ func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string,
return nil
}
// SendBalanceAlert 发送余额预警短信
func (s *AliSMSService) SendBalanceAlert(ctx context.Context, phone string, balance float64, threshold float64, alertType string, enterpriseName ...string) error {
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.PhoneNumbers = phone
request.SignName = s.config.SignName
var templateCode string
var templateParam string
if alertType == "low_balance" {
// 低余额预警也使用欠费预警模板
templateCode = "SMS_494605047" // 阿里云欠费预警模板
// 使用传入的企业名称,如果没有则使用默认值
name := "天远数据用户"
if len(enterpriseName) > 0 && enterpriseName[0] != "" {
name = enterpriseName[0]
}
templateParam = fmt.Sprintf(`{"name":"%s","time":"%s","money":"%.2f"}`,
name, time.Now().Format("2006-01-02 15:04:05"), threshold)
} else if alertType == "arrears" {
// 欠费预警模板
templateCode = "SMS_494605047" // 阿里云欠费预警模板
// 使用传入的企业名称,如果没有则使用默认值
name := "天远数据用户"
if len(enterpriseName) > 0 && enterpriseName[0] != "" {
name = enterpriseName[0]
}
templateParam = fmt.Sprintf(`{"name":"%s","time":"%s","money":"%.2f"}`,
name, time.Now().Format("2006-01-02 15:04:05"), balance)
} else {
return fmt.Errorf("不支持的预警类型: %s", alertType)
}
request.TemplateCode = templateCode
request.TemplateParam = templateParam
response, err := s.client.SendSms(request)
if err != nil {
s.logger.Error("发送余额预警短信失败",
zap.String("phone", phone),
zap.String("alert_type", alertType),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
if response.Code != "OK" {
s.logger.Error("余额预警短信发送失败",
zap.String("phone", phone),
zap.String("alert_type", alertType),
zap.String("code", response.Code),
zap.String("message", response.Message))
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
}
s.logger.Info("余额预警短信发送成功",
zap.String("phone", phone),
zap.String("alert_type", alertType),
zap.String("bizId", response.BizId))
return nil
}
// GenerateCode 生成验证码
func (s *AliSMSService) GenerateCode(length int) string {
if length <= 0 {

View File

@@ -468,6 +468,77 @@ func (h *ApiHandler) GetAdminApiCalls(c *gin.Context) {
h.responseBuilder.Success(c, result, "获取API调用记录成功")
}
// ExportAdminApiCalls 导出管理端API调用记录
// @Summary 导出管理端API调用记录
// @Description 管理员导出API调用记录支持Excel和CSV格式
// @Tags API调用管理
// @Accept json
// @Produce application/vnd.openxmlformats-officedocument.spreadsheetml.sheet,text/csv
// @Security Bearer
// @Param user_ids query string false "用户ID列表逗号分隔"
// @Param product_ids query string false "产品ID列表逗号分隔"
// @Param start_time query string false "开始时间" format(date-time)
// @Param end_time query string false "结束时间" format(date-time)
// @Param format query string false "导出格式" Enums(excel, csv) default(excel)
// @Success 200 {file} file "导出文件"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/api-calls/export [get]
func (h *ApiHandler) ExportAdminApiCalls(c *gin.Context) {
// 解析查询参数
filters := make(map[string]interface{})
// 用户ID筛选
if userIds := c.Query("user_ids"); userIds != "" {
filters["user_ids"] = userIds
}
// 产品ID筛选
if productIds := c.Query("product_ids"); productIds != "" {
filters["product_ids"] = productIds
}
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 获取导出格式默认为excel
format := c.DefaultQuery("format", "excel")
if format != "excel" && format != "csv" {
h.responseBuilder.BadRequest(c, "不支持的导出格式")
return
}
// 调用应用服务导出数据
fileData, err := h.appService.ExportAdminApiCalls(c.Request.Context(), filters, format)
if err != nil {
h.logger.Error("导出API调用记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "导出API调用记录失败")
return
}
// 设置响应头
contentType := "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
filename := "API调用记录.xlsx"
if format == "csv" {
contentType = "text/csv;charset=utf-8"
filename = "API调用记录.csv"
}
c.Header("Content-Type", contentType)
c.Header("Content-Disposition", "attachment; filename="+filename)
c.Data(200, contentType, fileData)
}
// getIntQuery 获取整数查询参数
func (h *ApiHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
if value := c.Query(key); value != "" {
@@ -477,3 +548,116 @@ func (h *ApiHandler) getIntQuery(c *gin.Context, key string, defaultValue int) i
}
return defaultValue
}
// GetUserBalanceAlertSettings 获取用户余额预警设置
// @Summary 获取用户余额预警设置
// @Description 获取当前用户的余额预警配置
// @Tags 用户设置
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "获取成功"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/user/balance-alert/settings [get]
func (h *ApiHandler) GetUserBalanceAlertSettings(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
settings, err := h.appService.GetUserBalanceAlertSettings(c.Request.Context(), userID)
if err != nil {
h.logger.Error("获取用户余额预警设置失败",
zap.String("user_id", userID),
zap.Error(err))
h.responseBuilder.InternalError(c, "获取预警设置失败")
return
}
h.responseBuilder.Success(c, settings, "获取成功")
}
// UpdateUserBalanceAlertSettings 更新用户余额预警设置
// @Summary 更新用户余额预警设置
// @Description 更新当前用户的余额预警配置
// @Tags 用户设置
// @Accept json
// @Produce json
// @Param request body map[string]interface{} true "预警设置"
// @Success 200 {object} map[string]interface{} "更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/user/balance-alert/settings [put]
func (h *ApiHandler) UpdateUserBalanceAlertSettings(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
var request struct {
Enabled bool `json:"enabled" binding:"required"`
Threshold float64 `json:"threshold" binding:"required,min=0"`
AlertPhone string `json:"alert_phone" binding:"required"`
}
if err := c.ShouldBindJSON(&request); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误: "+err.Error())
return
}
err := h.appService.UpdateUserBalanceAlertSettings(c.Request.Context(), userID, request.Enabled, request.Threshold, request.AlertPhone)
if err != nil {
h.logger.Error("更新用户余额预警设置失败",
zap.String("user_id", userID),
zap.Error(err))
h.responseBuilder.InternalError(c, "更新预警设置失败")
return
}
h.responseBuilder.Success(c, gin.H{}, "更新成功")
}
// TestBalanceAlertSms 测试余额预警短信
// @Summary 测试余额预警短信
// @Description 发送测试预警短信到指定手机号
// @Tags 用户设置
// @Accept json
// @Produce json
// @Param request body map[string]interface{} true "测试参数"
// @Success 200 {object} map[string]interface{} "发送成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/user/balance-alert/test-sms [post]
func (h *ApiHandler) TestBalanceAlertSms(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
var request struct {
Phone string `json:"phone" binding:"required,len=11"`
Balance float64 `json:"balance" binding:"required"`
AlertType string `json:"alert_type" binding:"required,oneof=low_balance arrears"`
}
if err := c.ShouldBindJSON(&request); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误: "+err.Error())
return
}
err := h.appService.TestBalanceAlertSms(c.Request.Context(), userID, request.Phone, request.Balance, request.AlertType)
if err != nil {
h.logger.Error("发送测试预警短信失败",
zap.String("user_id", userID),
zap.Error(err))
h.responseBuilder.InternalError(c, "发送测试短信失败")
return
}
h.responseBuilder.Success(c, gin.H{}, "测试短信发送成功")
}

View File

@@ -215,6 +215,86 @@ func (h *CertificationHandler) ApplyContract(c *gin.Context) {
h.response.Success(c, result, "合同申请成功")
}
// RecognizeBusinessLicense OCR识别营业执照
// @Summary OCR识别营业执照
// @Description 上传营业执照图片进行OCR识别自动填充企业信息
// @Tags 认证管理
// @Accept multipart/form-data
// @Produce json
// @Security Bearer
// @Param image formData file true "营业执照图片文件"
// @Success 200 {object} responses.BusinessLicenseResult "营业执照识别成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/ocr/business-license [post]
func (h *CertificationHandler) RecognizeBusinessLicense(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
// 获取上传的文件
file, err := c.FormFile("image")
if err != nil {
h.logger.Error("获取上传文件失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "请选择要上传的营业执照图片")
return
}
// 验证文件类型
allowedTypes := map[string]bool{
"image/jpeg": true,
"image/jpg": true,
"image/png": true,
"image/webp": true,
}
if !allowedTypes[file.Header.Get("Content-Type")] {
h.response.BadRequest(c, "只支持JPG、PNG、WEBP格式的图片")
return
}
// 验证文件大小限制为5MB
if file.Size > 5*1024*1024 {
h.response.BadRequest(c, "图片大小不能超过5MB")
return
}
// 打开文件
src, err := file.Open()
if err != nil {
h.logger.Error("打开上传文件失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "文件读取失败")
return
}
defer src.Close()
// 读取文件内容
imageBytes, err := io.ReadAll(src)
if err != nil {
h.logger.Error("读取文件内容失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "文件读取失败")
return
}
// 调用OCR服务识别营业执照
result, err := h.appService.RecognizeBusinessLicense(c.Request.Context(), imageBytes)
if err != nil {
h.logger.Error("营业执照OCR识别失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "营业执照识别失败:"+err.Error())
return
}
h.logger.Info("营业执照OCR识别成功",
zap.String("user_id", userID),
zap.String("company_name", result.CompanyName),
zap.Float64("confidence", result.Confidence),
)
h.response.Success(c, result, "营业执照识别成功")
}
// ListCertifications 获取认证列表(管理员)
// @Summary 获取认证列表
// @Description 管理员获取认证申请列表

View File

@@ -1199,6 +1199,102 @@ func (h *ProductAdminHandler) GetAdminWalletTransactions(c *gin.Context) {
h.responseBuilder.Success(c, result, "获取消费记录成功")
}
// ExportAdminWalletTransactions 导出管理端消费记录
// @Summary 导出管理端消费记录
// @Description 管理员导出消费记录支持Excel和CSV格式
// @Tags 财务管理
// @Accept json
// @Produce application/vnd.openxmlformats-officedocument.spreadsheetml.sheet,text/csv
// @Security Bearer
// @Param user_ids query string false "用户ID列表逗号分隔"
// @Param user_id query string false "单个用户ID"
// @Param transaction_id query string false "交易ID"
// @Param product_name query string false "产品名称"
// @Param product_ids query string false "产品ID列表逗号分隔"
// @Param min_amount query string false "最小金额"
// @Param max_amount query string false "最大金额"
// @Param start_time query string false "开始时间" format(date-time)
// @Param end_time query string false "结束时间" format(date-time)
// @Param format query string false "导出格式" Enums(excel, csv) default(excel)
// @Success 200 {file} file "导出文件"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/wallet-transactions/export [get]
func (h *ProductAdminHandler) ExportAdminWalletTransactions(c *gin.Context) {
// 构建筛选条件
filters := make(map[string]interface{})
// 用户ID筛选
if userIds := c.Query("user_ids"); userIds != "" {
filters["user_ids"] = userIds
} else if userId := c.Query("user_id"); userId != "" {
filters["user_id"] = userId
}
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 交易ID筛选
if transactionId := c.Query("transaction_id"); transactionId != "" {
filters["transaction_id"] = transactionId
}
// 产品名称筛选
if productName := c.Query("product_name"); productName != "" {
filters["product_name"] = productName
}
// 产品ID列表筛选
if productIds := c.Query("product_ids"); productIds != "" {
filters["product_ids"] = productIds
}
// 金额范围筛选
if minAmount := c.Query("min_amount"); minAmount != "" {
filters["min_amount"] = minAmount
}
if maxAmount := c.Query("max_amount"); maxAmount != "" {
filters["max_amount"] = maxAmount
}
// 获取导出格式
format := c.DefaultQuery("format", "excel")
if format != "excel" && format != "csv" {
h.responseBuilder.BadRequest(c, "不支持的导出格式")
return
}
// 调用导出服务
fileData, err := h.financeAppService.ExportAdminWalletTransactions(c.Request.Context(), filters, format)
if err != nil {
h.logger.Error("导出消费记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "导出消费记录失败")
return
}
// 设置响应头
contentType := "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
filename := "消费记录.xlsx"
if format == "csv" {
contentType = "text/csv;charset=utf-8"
filename = "消费记录.csv"
}
c.Header("Content-Type", contentType)
c.Header("Content-Disposition", "attachment; filename="+filename)
c.Data(200, contentType, fileData)
}
// GetAdminRechargeRecords 获取管理端充值记录
// @Summary 获取管理端充值记录
// @Description 管理员获取充值记录,支持筛选和分页
@@ -1282,3 +1378,184 @@ func (h *ProductAdminHandler) GetAdminRechargeRecords(c *gin.Context) {
h.responseBuilder.Success(c, result, "获取充值记录成功")
}
// ExportAdminRechargeRecords 导出管理端充值记录
// @Summary 导出管理端充值记录
// @Description 管理员导出充值记录支持Excel和CSV格式
// @Tags 财务管理
// @Accept json
// @Produce application/vnd.openxmlformats-officedocument.spreadsheetml.sheet,text/csv
// @Security Bearer
// @Param user_ids query string false "用户ID列表逗号分隔"
// @Param recharge_type query string false "充值类型" Enums(alipay, transfer, gift)
// @Param status query string false "状态" Enums(pending, success, failed)
// @Param start_time query string false "开始时间" format(date-time)
// @Param end_time query string false "结束时间" format(date-time)
// @Param format query string false "导出格式" Enums(excel, csv) default(excel)
// @Success 200 {file} file "导出文件"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/recharge-records/export [get]
func (h *ProductAdminHandler) ExportAdminRechargeRecords(c *gin.Context) {
// 解析查询参数
filters := make(map[string]interface{})
// 用户ID筛选
if userIds := c.Query("user_ids"); userIds != "" {
filters["user_ids"] = userIds
}
// 充值类型筛选
if rechargeType := c.Query("recharge_type"); rechargeType != "" {
filters["recharge_type"] = rechargeType
}
// 状态筛选
if status := c.Query("status"); status != "" {
filters["status"] = status
}
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 获取导出格式默认为excel
format := c.DefaultQuery("format", "excel")
if format != "excel" && format != "csv" {
h.responseBuilder.BadRequest(c, "不支持的导出格式")
return
}
// 调用应用服务导出数据
fileData, err := h.financeAppService.ExportAdminRechargeRecords(c.Request.Context(), filters, format)
if err != nil {
h.logger.Error("导出充值记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "导出充值记录失败")
return
}
// 设置响应头
contentType := "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
filename := "充值记录.xlsx"
if format == "csv" {
contentType = "text/csv;charset=utf-8"
filename = "充值记录.csv"
}
c.Header("Content-Type", contentType)
c.Header("Content-Disposition", "attachment; filename="+filename)
c.Data(200, contentType, fileData)
}
// GetAdminApiCalls 获取管理端API调用记录
func (h *ProductAdminHandler) GetAdminApiCalls(c *gin.Context) {
// 解析查询参数
page := h.getIntQuery(c, "page", 1)
pageSize := h.getIntQuery(c, "page_size", 10)
// 构建筛选条件
filters := make(map[string]interface{})
// 用户ID筛选
if userIds := c.Query("user_ids"); userIds != "" {
filters["user_ids"] = userIds
}
// 产品ID筛选
if productIds := c.Query("product_ids"); productIds != "" {
filters["product_ids"] = productIds
}
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 构建分页选项
options := interfaces.ListOptions{
Page: page,
PageSize: pageSize,
Sort: "created_at",
Order: "desc",
}
result, err := h.apiAppService.GetAdminApiCalls(c.Request.Context(), filters, options)
if err != nil {
h.logger.Error("获取管理端API调用记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "获取API调用记录失败")
return
}
h.responseBuilder.Success(c, result, "获取API调用记录成功")
}
// ExportAdminApiCalls 导出管理端API调用记录
func (h *ProductAdminHandler) ExportAdminApiCalls(c *gin.Context) {
// 解析查询参数
filters := make(map[string]interface{})
// 用户ID筛选
if userIds := c.Query("user_ids"); userIds != "" {
filters["user_ids"] = userIds
}
// 产品ID筛选
if productIds := c.Query("product_ids"); productIds != "" {
filters["product_ids"] = productIds
}
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 获取导出格式默认为excel
format := c.DefaultQuery("format", "excel")
if format != "excel" && format != "csv" {
h.responseBuilder.BadRequest(c, "不支持的导出格式")
return
}
// 调用应用服务导出数据
fileData, err := h.apiAppService.ExportAdminApiCalls(c.Request.Context(), filters, format)
if err != nil {
h.logger.Error("导出API调用记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "导出API调用记录失败")
return
}
// 设置响应头
contentType := "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
filename := "API调用记录.xlsx"
if format == "csv" {
contentType = "text/csv;charset=utf-8"
filename = "API调用记录.csv"
}
c.Header("Content-Type", contentType)
c.Header("Content-Disposition", "attachment; filename="+filename)
c.Data(200, contentType, fileData)
}

File diff suppressed because it is too large Load Diff

View File

@@ -274,7 +274,6 @@ func (h *UserHandler) ListUsers(c *gin.Context) {
return
}
// 构建查询参数
query := &queries.ListUsersQuery{
Page: 1,
@@ -289,7 +288,7 @@ func (h *UserHandler) ListUsers(c *gin.Context) {
}
if pageSize := c.Query("page_size"); pageSize != "" {
if size, err := strconv.Atoi(pageSize); err == nil && size > 0 && size <= 100 {
if size, err := strconv.Atoi(pageSize); err == nil && size > 0 && size <= 1000 {
query.PageSize = size
}
}

View File

@@ -38,6 +38,7 @@ func (r *ApiRoutes) Register(router *sharedhttp.GinRouter) {
apiGroup := engine.Group("/api/v1")
{
// API调用接口 - 不受频率限制(业务核心接口)
apiGroup.POST("/:api_name", r.domainAuthMiddleware.Handle(""), r.apiHandler.HandleApiCall)
// Console专用接口 - 使用JWT认证不需要域名认证
@@ -62,6 +63,11 @@ func (r *ApiRoutes) Register(router *sharedhttp.GinRouter) {
// API调用记录接口
apiGroup.GET("/my/api-calls", r.authMiddleware.Handle(), r.apiHandler.GetUserApiCalls)
// 余额预警设置接口
apiGroup.GET("/user/balance-alert/settings", r.authMiddleware.Handle(), r.apiHandler.GetUserBalanceAlertSettings)
apiGroup.PUT("/user/balance-alert/settings", r.authMiddleware.Handle(), r.apiHandler.UpdateUserBalanceAlertSettings)
apiGroup.POST("/user/balance-alert/test-sms", r.authMiddleware.Handle(), r.apiHandler.TestBalanceAlertSms)
}
r.logger.Info("API路由注册完成")

View File

@@ -54,6 +54,9 @@ func (r *CertificationRoutes) Register(router *http.GinRouter) {
// 2. 提交企业信息(应用每日限流)
authGroup.POST("/enterprise-info", r.dailyRateLimit.Handle(), r.handler.SubmitEnterpriseInfo)
// OCR营业执照识别接口
authGroup.POST("/ocr/business-license", r.handler.RecognizeBusinessLicense)
// 3. 申请合同签署
authGroup.POST("/apply-contract", r.handler.ApplyContract)
@@ -84,6 +87,7 @@ func (r *CertificationRoutes) GetRoutes() []RouteInfo {
{Method: "GET", Path: "/api/v1/certifications", Handler: "ListCertifications", Auth: true},
{Method: "GET", Path: "/api/v1/certifications/statistics", Handler: "GetCertificationStatistics", Auth: true},
{Method: "POST", Path: "/api/v1/certifications/:id/enterprise-info", Handler: "SubmitEnterpriseInfo", Auth: true},
{Method: "POST", Path: "/api/v1/certifications/ocr/business-license", Handler: "RecognizeBusinessLicense", Auth: true},
{Method: "POST", Path: "/api/v1/certifications/apply-contract", Handler: "ApplyContract", Auth: true},
{Method: "POST", Path: "/api/v1/certifications/retry", Handler: "RetryOperation", Auth: true},
{Method: "POST", Path: "/api/v1/certifications/force-transition", Handler: "ForceTransitionStatus", Auth: true},

View File

@@ -87,12 +87,21 @@ func (r *ProductAdminRoutes) Register(router *sharedhttp.GinRouter) {
walletTransactions := adminGroup.Group("/wallet-transactions")
{
walletTransactions.GET("", r.handler.GetAdminWalletTransactions)
walletTransactions.GET("/export", r.handler.ExportAdminWalletTransactions)
}
// API调用记录管理
apiCalls := adminGroup.Group("/api-calls")
{
apiCalls.GET("", r.handler.GetAdminApiCalls)
apiCalls.GET("/export", r.handler.ExportAdminApiCalls)
}
// 充值记录管理
rechargeRecords := adminGroup.Group("/recharge-records")
{
rechargeRecords.GET("", r.handler.GetAdminRechargeRecords)
rechargeRecords.GET("/export", r.handler.ExportAdminRechargeRecords)
}
}
}

View File

@@ -0,0 +1,165 @@
package routes
import (
"tyapi-server/internal/infrastructure/http/handlers"
sharedhttp "tyapi-server/internal/shared/http"
"tyapi-server/internal/shared/middleware"
"go.uber.org/zap"
)
// StatisticsRoutes 统计路由
type StatisticsRoutes struct {
statisticsHandler *handlers.StatisticsHandler
auth *middleware.JWTAuthMiddleware
optionalAuth *middleware.OptionalAuthMiddleware
admin *middleware.AdminAuthMiddleware
logger *zap.Logger
}
// NewStatisticsRoutes 创建统计路由
func NewStatisticsRoutes(
statisticsHandler *handlers.StatisticsHandler,
auth *middleware.JWTAuthMiddleware,
optionalAuth *middleware.OptionalAuthMiddleware,
admin *middleware.AdminAuthMiddleware,
logger *zap.Logger,
) *StatisticsRoutes {
return &StatisticsRoutes{
statisticsHandler: statisticsHandler,
auth: auth,
optionalAuth: optionalAuth,
admin: admin,
logger: logger,
}
}
// Register 注册统计相关路由
func (r *StatisticsRoutes) Register(router *sharedhttp.GinRouter) {
engine := router.GetEngine()
// ================ 用户端统计路由 ================
// 统计公开接口
statistics := engine.Group("/api/v1/statistics")
{
// 获取公开统计信息
statistics.GET("/public", r.statisticsHandler.GetPublicStatistics)
}
// 用户统计接口 - 需要认证
userStats := engine.Group("/api/v1/statistics", r.auth.Handle())
{
// 获取用户统计信息
userStats.GET("/user", r.statisticsHandler.GetUserStatistics)
// 独立统计接口(用户只能查询自己的数据)
userStats.GET("/api-calls", r.statisticsHandler.GetApiCallsStatistics)
userStats.GET("/consumption", r.statisticsHandler.GetConsumptionStatistics)
userStats.GET("/recharge", r.statisticsHandler.GetRechargeStatistics)
// 获取最新产品推荐
userStats.GET("/latest-products", r.statisticsHandler.GetLatestProducts)
// 获取指标列表
userStats.GET("/metrics", r.statisticsHandler.GetMetrics)
// 获取指标详情
userStats.GET("/metrics/:id", r.statisticsHandler.GetMetricDetail)
// 获取仪表板列表
userStats.GET("/dashboards", r.statisticsHandler.GetDashboards)
// 获取仪表板详情
userStats.GET("/dashboards/:id", r.statisticsHandler.GetDashboardDetail)
// 获取仪表板数据
userStats.GET("/dashboards/:id/data", r.statisticsHandler.GetDashboardData)
// 获取报告列表
userStats.GET("/reports", r.statisticsHandler.GetReports)
// 获取报告详情
userStats.GET("/reports/:id", r.statisticsHandler.GetReportDetail)
// 创建报告
userStats.POST("/reports", r.statisticsHandler.CreateReport)
}
// ================ 管理员统计路由 ================
// 管理员路由组
adminGroup := engine.Group("/api/v1/admin")
adminGroup.Use(r.admin.Handle()) // 管理员权限验证
{
// 统计指标管理
metrics := adminGroup.Group("/statistics/metrics")
{
metrics.GET("", r.statisticsHandler.AdminGetMetrics)
metrics.POST("", r.statisticsHandler.AdminCreateMetric)
metrics.PUT("/:id", r.statisticsHandler.AdminUpdateMetric)
metrics.DELETE("/:id", r.statisticsHandler.AdminDeleteMetric)
}
// 仪表板管理
dashboards := adminGroup.Group("/statistics/dashboards")
{
dashboards.GET("", r.statisticsHandler.AdminGetDashboards)
dashboards.POST("", r.statisticsHandler.AdminCreateDashboard)
dashboards.PUT("/:id", r.statisticsHandler.AdminUpdateDashboard)
dashboards.DELETE("/:id", r.statisticsHandler.AdminDeleteDashboard)
}
// 报告管理
reports := adminGroup.Group("/statistics/reports")
{
reports.GET("", r.statisticsHandler.AdminGetReports)
}
// 系统统计
system := adminGroup.Group("/statistics/system")
{
system.GET("", r.statisticsHandler.AdminGetSystemStatistics)
}
// 独立域统计接口
domainStats := adminGroup.Group("/statistics")
{
domainStats.GET("/user-domain", r.statisticsHandler.AdminGetUserDomainStatistics)
domainStats.GET("/api-domain", r.statisticsHandler.AdminGetApiDomainStatistics)
domainStats.GET("/consumption-domain", r.statisticsHandler.AdminGetConsumptionDomainStatistics)
domainStats.GET("/recharge-domain", r.statisticsHandler.AdminGetRechargeDomainStatistics)
}
// 排行榜接口
rankings := adminGroup.Group("/statistics")
{
rankings.GET("/user-call-ranking", r.statisticsHandler.AdminGetUserCallRanking)
rankings.GET("/recharge-ranking", r.statisticsHandler.AdminGetRechargeRanking)
rankings.GET("/api-popularity-ranking", r.statisticsHandler.AdminGetApiPopularityRanking)
rankings.GET("/today-certified-enterprises", r.statisticsHandler.AdminGetTodayCertifiedEnterprises)
}
// 用户统计
userStats := adminGroup.Group("/statistics/users")
{
userStats.GET("/:user_id", r.statisticsHandler.AdminGetUserStatistics)
}
// 独立统计接口(管理员可查询任意用户)
independentStats := adminGroup.Group("/statistics")
{
independentStats.GET("/api-calls", r.statisticsHandler.GetApiCallsStatistics)
independentStats.GET("/consumption", r.statisticsHandler.GetConsumptionStatistics)
independentStats.GET("/recharge", r.statisticsHandler.GetRechargeStatistics)
}
// 数据聚合
aggregation := adminGroup.Group("/statistics/aggregation")
{
aggregation.POST("/trigger", r.statisticsHandler.AdminTriggerAggregation)
}
}
r.logger.Info("统计路由注册完成")
}

View File

@@ -0,0 +1,584 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"tyapi-server/internal/domains/statistics/entities"
)
// RedisStatisticsCache Redis统计缓存实现
type RedisStatisticsCache struct {
client *redis.Client
prefix string
}
// NewRedisStatisticsCache 创建Redis统计缓存
func NewRedisStatisticsCache(client *redis.Client) *RedisStatisticsCache {
return &RedisStatisticsCache{
client: client,
prefix: "statistics:",
}
}
// ================ 指标缓存 ================
// SetMetric 设置指标缓存
func (c *RedisStatisticsCache) SetMetric(ctx context.Context, metric *entities.StatisticsMetric, expiration time.Duration) error {
if metric == nil {
return fmt.Errorf("统计指标不能为空")
}
key := c.getMetricKey(metric.ID)
data, err := json.Marshal(metric)
if err != nil {
return fmt.Errorf("序列化指标失败: %w", err)
}
err = c.client.Set(ctx, key, data, expiration).Err()
if err != nil {
return fmt.Errorf("设置指标缓存失败: %w", err)
}
return nil
}
// GetMetric 获取指标缓存
func (c *RedisStatisticsCache) GetMetric(ctx context.Context, metricID string) (*entities.StatisticsMetric, error) {
if metricID == "" {
return nil, fmt.Errorf("指标ID不能为空")
}
key := c.getMetricKey(metricID)
data, err := c.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, nil // 缓存未命中
}
return nil, fmt.Errorf("获取指标缓存失败: %w", err)
}
var metric entities.StatisticsMetric
err = json.Unmarshal([]byte(data), &metric)
if err != nil {
return nil, fmt.Errorf("反序列化指标失败: %w", err)
}
return &metric, nil
}
// DeleteMetric 删除指标缓存
func (c *RedisStatisticsCache) DeleteMetric(ctx context.Context, metricID string) error {
if metricID == "" {
return fmt.Errorf("指标ID不能为空")
}
key := c.getMetricKey(metricID)
err := c.client.Del(ctx, key).Err()
if err != nil {
return fmt.Errorf("删除指标缓存失败: %w", err)
}
return nil
}
// SetMetricsByType 设置按类型分组的指标缓存
func (c *RedisStatisticsCache) SetMetricsByType(ctx context.Context, metricType string, metrics []*entities.StatisticsMetric, expiration time.Duration) error {
if metricType == "" {
return fmt.Errorf("指标类型不能为空")
}
key := c.getMetricsByTypeKey(metricType)
data, err := json.Marshal(metrics)
if err != nil {
return fmt.Errorf("序列化指标列表失败: %w", err)
}
err = c.client.Set(ctx, key, data, expiration).Err()
if err != nil {
return fmt.Errorf("设置指标列表缓存失败: %w", err)
}
return nil
}
// GetMetricsByType 获取按类型分组的指标缓存
func (c *RedisStatisticsCache) GetMetricsByType(ctx context.Context, metricType string) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
key := c.getMetricsByTypeKey(metricType)
data, err := c.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, nil // 缓存未命中
}
return nil, fmt.Errorf("获取指标列表缓存失败: %w", err)
}
var metrics []*entities.StatisticsMetric
err = json.Unmarshal([]byte(data), &metrics)
if err != nil {
return nil, fmt.Errorf("反序列化指标列表失败: %w", err)
}
return metrics, nil
}
// DeleteMetricsByType 删除按类型分组的指标缓存
func (c *RedisStatisticsCache) DeleteMetricsByType(ctx context.Context, metricType string) error {
if metricType == "" {
return fmt.Errorf("指标类型不能为空")
}
key := c.getMetricsByTypeKey(metricType)
err := c.client.Del(ctx, key).Err()
if err != nil {
return fmt.Errorf("删除指标列表缓存失败: %w", err)
}
return nil
}
// ================ 实时指标缓存 ================
// SetRealtimeMetrics 设置实时指标缓存
func (c *RedisStatisticsCache) SetRealtimeMetrics(ctx context.Context, metricType string, metrics map[string]float64, expiration time.Duration) error {
if metricType == "" {
return fmt.Errorf("指标类型不能为空")
}
key := c.getRealtimeMetricsKey(metricType)
data, err := json.Marshal(metrics)
if err != nil {
return fmt.Errorf("序列化实时指标失败: %w", err)
}
err = c.client.Set(ctx, key, data, expiration).Err()
if err != nil {
return fmt.Errorf("设置实时指标缓存失败: %w", err)
}
return nil
}
// GetRealtimeMetrics 获取实时指标缓存
func (c *RedisStatisticsCache) GetRealtimeMetrics(ctx context.Context, metricType string) (map[string]float64, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
key := c.getRealtimeMetricsKey(metricType)
data, err := c.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, nil // 缓存未命中
}
return nil, fmt.Errorf("获取实时指标缓存失败: %w", err)
}
var metrics map[string]float64
err = json.Unmarshal([]byte(data), &metrics)
if err != nil {
return nil, fmt.Errorf("反序列化实时指标失败: %w", err)
}
return metrics, nil
}
// UpdateRealtimeMetric 更新实时指标
func (c *RedisStatisticsCache) UpdateRealtimeMetric(ctx context.Context, metricType, metricName string, value float64, expiration time.Duration) error {
if metricType == "" || metricName == "" {
return fmt.Errorf("指标类型和名称不能为空")
}
// 获取现有指标
metrics, err := c.GetRealtimeMetrics(ctx, metricType)
if err != nil {
return fmt.Errorf("获取实时指标失败: %w", err)
}
if metrics == nil {
metrics = make(map[string]float64)
}
// 更新指标值
metrics[metricName] = value
// 保存更新后的指标
err = c.SetRealtimeMetrics(ctx, metricType, metrics, expiration)
if err != nil {
return fmt.Errorf("更新实时指标失败: %w", err)
}
return nil
}
// ================ 报告缓存 ================
// SetReport 设置报告缓存
func (c *RedisStatisticsCache) SetReport(ctx context.Context, report *entities.StatisticsReport, expiration time.Duration) error {
if report == nil {
return fmt.Errorf("统计报告不能为空")
}
key := c.getReportKey(report.ID)
data, err := json.Marshal(report)
if err != nil {
return fmt.Errorf("序列化报告失败: %w", err)
}
err = c.client.Set(ctx, key, data, expiration).Err()
if err != nil {
return fmt.Errorf("设置报告缓存失败: %w", err)
}
return nil
}
// GetReport 获取报告缓存
func (c *RedisStatisticsCache) GetReport(ctx context.Context, reportID string) (*entities.StatisticsReport, error) {
if reportID == "" {
return nil, fmt.Errorf("报告ID不能为空")
}
key := c.getReportKey(reportID)
data, err := c.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, nil // 缓存未命中
}
return nil, fmt.Errorf("获取报告缓存失败: %w", err)
}
var report entities.StatisticsReport
err = json.Unmarshal([]byte(data), &report)
if err != nil {
return nil, fmt.Errorf("反序列化报告失败: %w", err)
}
return &report, nil
}
// DeleteReport 删除报告缓存
func (c *RedisStatisticsCache) DeleteReport(ctx context.Context, reportID string) error {
if reportID == "" {
return fmt.Errorf("报告ID不能为空")
}
key := c.getReportKey(reportID)
err := c.client.Del(ctx, key).Err()
if err != nil {
return fmt.Errorf("删除报告缓存失败: %w", err)
}
return nil
}
// ================ 仪表板缓存 ================
// SetDashboard 设置仪表板缓存
func (c *RedisStatisticsCache) SetDashboard(ctx context.Context, dashboard *entities.StatisticsDashboard, expiration time.Duration) error {
if dashboard == nil {
return fmt.Errorf("统计仪表板不能为空")
}
key := c.getDashboardKey(dashboard.ID)
data, err := json.Marshal(dashboard)
if err != nil {
return fmt.Errorf("序列化仪表板失败: %w", err)
}
err = c.client.Set(ctx, key, data, expiration).Err()
if err != nil {
return fmt.Errorf("设置仪表板缓存失败: %w", err)
}
return nil
}
// GetDashboard 获取仪表板缓存
func (c *RedisStatisticsCache) GetDashboard(ctx context.Context, dashboardID string) (*entities.StatisticsDashboard, error) {
if dashboardID == "" {
return nil, fmt.Errorf("仪表板ID不能为空")
}
key := c.getDashboardKey(dashboardID)
data, err := c.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, nil // 缓存未命中
}
return nil, fmt.Errorf("获取仪表板缓存失败: %w", err)
}
var dashboard entities.StatisticsDashboard
err = json.Unmarshal([]byte(data), &dashboard)
if err != nil {
return nil, fmt.Errorf("反序列化仪表板失败: %w", err)
}
return &dashboard, nil
}
// DeleteDashboard 删除仪表板缓存
func (c *RedisStatisticsCache) DeleteDashboard(ctx context.Context, dashboardID string) error {
if dashboardID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
key := c.getDashboardKey(dashboardID)
err := c.client.Del(ctx, key).Err()
if err != nil {
return fmt.Errorf("删除仪表板缓存失败: %w", err)
}
return nil
}
// SetDashboardData 设置仪表板数据缓存
func (c *RedisStatisticsCache) SetDashboardData(ctx context.Context, userRole string, data interface{}, expiration time.Duration) error {
if userRole == "" {
return fmt.Errorf("用户角色不能为空")
}
key := c.getDashboardDataKey(userRole)
jsonData, err := json.Marshal(data)
if err != nil {
return fmt.Errorf("序列化仪表板数据失败: %w", err)
}
err = c.client.Set(ctx, key, jsonData, expiration).Err()
if err != nil {
return fmt.Errorf("设置仪表板数据缓存失败: %w", err)
}
return nil
}
// GetDashboardData 获取仪表板数据缓存
func (c *RedisStatisticsCache) GetDashboardData(ctx context.Context, userRole string) (interface{}, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
key := c.getDashboardDataKey(userRole)
data, err := c.client.Get(ctx, key).Result()
if err != nil {
if err == redis.Nil {
return nil, nil // 缓存未命中
}
return nil, fmt.Errorf("获取仪表板数据缓存失败: %w", err)
}
var result interface{}
err = json.Unmarshal([]byte(data), &result)
if err != nil {
return nil, fmt.Errorf("反序列化仪表板数据失败: %w", err)
}
return result, nil
}
// DeleteDashboardData 删除仪表板数据缓存
func (c *RedisStatisticsCache) DeleteDashboardData(ctx context.Context, userRole string) error {
if userRole == "" {
return fmt.Errorf("用户角色不能为空")
}
key := c.getDashboardDataKey(userRole)
err := c.client.Del(ctx, key).Err()
if err != nil {
return fmt.Errorf("删除仪表板数据缓存失败: %w", err)
}
return nil
}
// ================ 缓存键生成 ================
// getMetricKey 获取指标缓存键
func (c *RedisStatisticsCache) getMetricKey(metricID string) string {
return c.prefix + "metric:" + metricID
}
// getMetricsByTypeKey 获取按类型分组的指标缓存键
func (c *RedisStatisticsCache) getMetricsByTypeKey(metricType string) string {
return c.prefix + "metrics:type:" + metricType
}
// getRealtimeMetricsKey 获取实时指标缓存键
func (c *RedisStatisticsCache) getRealtimeMetricsKey(metricType string) string {
return c.prefix + "realtime:" + metricType
}
// getReportKey 获取报告缓存键
func (c *RedisStatisticsCache) getReportKey(reportID string) string {
return c.prefix + "report:" + reportID
}
// getDashboardKey 获取仪表板缓存键
func (c *RedisStatisticsCache) getDashboardKey(dashboardID string) string {
return c.prefix + "dashboard:" + dashboardID
}
// getDashboardDataKey 获取仪表板数据缓存键
func (c *RedisStatisticsCache) getDashboardDataKey(userRole string) string {
return c.prefix + "dashboard:data:" + userRole
}
// ================ 批量操作 ================
// BatchDeleteMetrics 批量删除指标缓存
func (c *RedisStatisticsCache) BatchDeleteMetrics(ctx context.Context, metricIDs []string) error {
if len(metricIDs) == 0 {
return nil
}
keys := make([]string, len(metricIDs))
for i, id := range metricIDs {
keys[i] = c.getMetricKey(id)
}
err := c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("批量删除指标缓存失败: %w", err)
}
return nil
}
// BatchDeleteReports 批量删除报告缓存
func (c *RedisStatisticsCache) BatchDeleteReports(ctx context.Context, reportIDs []string) error {
if len(reportIDs) == 0 {
return nil
}
keys := make([]string, len(reportIDs))
for i, id := range reportIDs {
keys[i] = c.getReportKey(id)
}
err := c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("批量删除报告缓存失败: %w", err)
}
return nil
}
// BatchDeleteDashboards 批量删除仪表板缓存
func (c *RedisStatisticsCache) BatchDeleteDashboards(ctx context.Context, dashboardIDs []string) error {
if len(dashboardIDs) == 0 {
return nil
}
keys := make([]string, len(dashboardIDs))
for i, id := range dashboardIDs {
keys[i] = c.getDashboardKey(id)
}
err := c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("批量删除仪表板缓存失败: %w", err)
}
return nil
}
// ================ 缓存清理 ================
// ClearAllStatisticsCache 清理所有统计缓存
func (c *RedisStatisticsCache) ClearAllStatisticsCache(ctx context.Context) error {
pattern := c.prefix + "*"
keys, err := c.client.Keys(ctx, pattern).Result()
if err != nil {
return fmt.Errorf("获取缓存键失败: %w", err)
}
if len(keys) > 0 {
err = c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("清理统计缓存失败: %w", err)
}
}
return nil
}
// ClearMetricsCache 清理指标缓存
func (c *RedisStatisticsCache) ClearMetricsCache(ctx context.Context) error {
pattern := c.prefix + "metric:*"
keys, err := c.client.Keys(ctx, pattern).Result()
if err != nil {
return fmt.Errorf("获取指标缓存键失败: %w", err)
}
if len(keys) > 0 {
err = c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("清理指标缓存失败: %w", err)
}
}
return nil
}
// ClearRealtimeCache 清理实时缓存
func (c *RedisStatisticsCache) ClearRealtimeCache(ctx context.Context) error {
pattern := c.prefix + "realtime:*"
keys, err := c.client.Keys(ctx, pattern).Result()
if err != nil {
return fmt.Errorf("获取实时缓存键失败: %w", err)
}
if len(keys) > 0 {
err = c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("清理实时缓存失败: %w", err)
}
}
return nil
}
// ClearReportsCache 清理报告缓存
func (c *RedisStatisticsCache) ClearReportsCache(ctx context.Context) error {
pattern := c.prefix + "report:*"
keys, err := c.client.Keys(ctx, pattern).Result()
if err != nil {
return fmt.Errorf("获取报告缓存键失败: %w", err)
}
if len(keys) > 0 {
err = c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("清理报告缓存失败: %w", err)
}
}
return nil
}
// ClearDashboardsCache 清理仪表板缓存
func (c *RedisStatisticsCache) ClearDashboardsCache(ctx context.Context) error {
pattern := c.prefix + "dashboard:*"
keys, err := c.client.Keys(ctx, pattern).Result()
if err != nil {
return fmt.Errorf("获取仪表板缓存键失败: %w", err)
}
if len(keys) > 0 {
err = c.client.Del(ctx, keys...).Err()
if err != nil {
return fmt.Errorf("清理仪表板缓存失败: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,404 @@
package cron
import (
"context"
"fmt"
"time"
"github.com/robfig/cron/v3"
"go.uber.org/zap"
"tyapi-server/internal/application/statistics"
)
// StatisticsCronJob 统计定时任务
type StatisticsCronJob struct {
appService statistics.StatisticsApplicationService
logger *zap.Logger
cron *cron.Cron
}
// NewStatisticsCronJob 创建统计定时任务
func NewStatisticsCronJob(
appService statistics.StatisticsApplicationService,
logger *zap.Logger,
) *StatisticsCronJob {
return &StatisticsCronJob{
appService: appService,
logger: logger,
cron: cron.New(cron.WithLocation(time.UTC)),
}
}
// Start 启动定时任务
func (j *StatisticsCronJob) Start() error {
j.logger.Info("启动统计定时任务")
// 每小时聚合任务 - 每小时的第5分钟执行
_, err := j.cron.AddFunc("5 * * * *", j.hourlyAggregationJob)
if err != nil {
return fmt.Errorf("添加小时聚合任务失败: %w", err)
}
// 每日聚合任务 - 每天凌晨1点执行
_, err = j.cron.AddFunc("0 1 * * *", j.dailyAggregationJob)
if err != nil {
return fmt.Errorf("添加日聚合任务失败: %w", err)
}
// 每周聚合任务 - 每周一凌晨2点执行
_, err = j.cron.AddFunc("0 2 * * 1", j.weeklyAggregationJob)
if err != nil {
return fmt.Errorf("添加周聚合任务失败: %w", err)
}
// 每月聚合任务 - 每月1号凌晨3点执行
_, err = j.cron.AddFunc("0 3 1 * *", j.monthlyAggregationJob)
if err != nil {
return fmt.Errorf("添加月聚合任务失败: %w", err)
}
// 数据清理任务 - 每天凌晨4点执行
_, err = j.cron.AddFunc("0 4 * * *", j.dataCleanupJob)
if err != nil {
return fmt.Errorf("添加数据清理任务失败: %w", err)
}
// 缓存预热任务 - 每天早上6点执行
_, err = j.cron.AddFunc("0 6 * * *", j.cacheWarmupJob)
if err != nil {
return fmt.Errorf("添加缓存预热任务失败: %w", err)
}
// 启动定时器
j.cron.Start()
j.logger.Info("统计定时任务启动成功")
return nil
}
// Stop 停止定时任务
func (j *StatisticsCronJob) Stop() {
j.logger.Info("停止统计定时任务")
j.cron.Stop()
j.logger.Info("统计定时任务已停止")
}
// ================ 定时任务实现 ================
// hourlyAggregationJob 小时聚合任务
func (j *StatisticsCronJob) hourlyAggregationJob() {
ctx := context.Background()
now := time.Now()
// 聚合上一小时的数据
lastHour := now.Add(-1 * time.Hour).Truncate(time.Hour)
j.logger.Info("开始执行小时聚合任务", zap.Time("target_hour", lastHour))
err := j.appService.ProcessHourlyAggregation(ctx, lastHour)
if err != nil {
j.logger.Error("小时聚合任务执行失败",
zap.Time("target_hour", lastHour),
zap.Error(err))
return
}
j.logger.Info("小时聚合任务执行成功", zap.Time("target_hour", lastHour))
}
// dailyAggregationJob 日聚合任务
func (j *StatisticsCronJob) dailyAggregationJob() {
ctx := context.Background()
now := time.Now()
// 聚合昨天的数据
yesterday := now.AddDate(0, 0, -1).Truncate(24 * time.Hour)
j.logger.Info("开始执行日聚合任务", zap.Time("target_date", yesterday))
err := j.appService.ProcessDailyAggregation(ctx, yesterday)
if err != nil {
j.logger.Error("日聚合任务执行失败",
zap.Time("target_date", yesterday),
zap.Error(err))
return
}
j.logger.Info("日聚合任务执行成功", zap.Time("target_date", yesterday))
}
// weeklyAggregationJob 周聚合任务
func (j *StatisticsCronJob) weeklyAggregationJob() {
ctx := context.Background()
now := time.Now()
// 聚合上一周的数据
lastWeek := now.AddDate(0, 0, -7).Truncate(24 * time.Hour)
j.logger.Info("开始执行周聚合任务", zap.Time("target_week", lastWeek))
err := j.appService.ProcessWeeklyAggregation(ctx, lastWeek)
if err != nil {
j.logger.Error("周聚合任务执行失败",
zap.Time("target_week", lastWeek),
zap.Error(err))
return
}
j.logger.Info("周聚合任务执行成功", zap.Time("target_week", lastWeek))
}
// monthlyAggregationJob 月聚合任务
func (j *StatisticsCronJob) monthlyAggregationJob() {
ctx := context.Background()
now := time.Now()
// 聚合上个月的数据
lastMonth := now.AddDate(0, -1, 0).Truncate(24 * time.Hour)
j.logger.Info("开始执行月聚合任务", zap.Time("target_month", lastMonth))
err := j.appService.ProcessMonthlyAggregation(ctx, lastMonth)
if err != nil {
j.logger.Error("月聚合任务执行失败",
zap.Time("target_month", lastMonth),
zap.Error(err))
return
}
j.logger.Info("月聚合任务执行成功", zap.Time("target_month", lastMonth))
}
// dataCleanupJob 数据清理任务
func (j *StatisticsCronJob) dataCleanupJob() {
ctx := context.Background()
j.logger.Info("开始执行数据清理任务")
err := j.appService.CleanupExpiredData(ctx)
if err != nil {
j.logger.Error("数据清理任务执行失败", zap.Error(err))
return
}
j.logger.Info("数据清理任务执行成功")
}
// cacheWarmupJob 缓存预热任务
func (j *StatisticsCronJob) cacheWarmupJob() {
ctx := context.Background()
j.logger.Info("开始执行缓存预热任务")
// 预热仪表板数据
err := j.warmupDashboardCache(ctx)
if err != nil {
j.logger.Error("仪表板缓存预热失败", zap.Error(err))
}
// 预热实时指标
err = j.warmupRealtimeMetricsCache(ctx)
if err != nil {
j.logger.Error("实时指标缓存预热失败", zap.Error(err))
}
j.logger.Info("缓存预热任务执行完成")
}
// ================ 缓存预热辅助方法 ================
// warmupDashboardCache 预热仪表板缓存
func (j *StatisticsCronJob) warmupDashboardCache(ctx context.Context) error {
// 获取所有用户角色
userRoles := []string{"admin", "user", "manager", "analyst"}
for _, role := range userRoles {
// 获取仪表板数据
query := &statistics.GetDashboardDataQuery{
UserRole: role,
Period: "today",
StartDate: time.Now().Truncate(24 * time.Hour),
EndDate: time.Now(),
}
_, err := j.appService.GetDashboardData(ctx, query)
if err != nil {
j.logger.Error("预热仪表板缓存失败",
zap.String("user_role", role),
zap.Error(err))
continue
}
j.logger.Info("仪表板缓存预热成功", zap.String("user_role", role))
}
return nil
}
// warmupRealtimeMetricsCache 预热实时指标缓存
func (j *StatisticsCronJob) warmupRealtimeMetricsCache(ctx context.Context) error {
// 获取所有指标类型
metricTypes := []string{"api_calls", "users", "finance", "products", "certification"}
for _, metricType := range metricTypes {
// 获取实时指标
query := &statistics.GetRealtimeMetricsQuery{
MetricType: metricType,
TimeRange: "last_hour",
}
_, err := j.appService.GetRealtimeMetrics(ctx, query)
if err != nil {
j.logger.Error("预热实时指标缓存失败",
zap.String("metric_type", metricType),
zap.Error(err))
continue
}
j.logger.Info("实时指标缓存预热成功", zap.String("metric_type", metricType))
}
return nil
}
// ================ 手动触发任务 ================
// TriggerHourlyAggregation 手动触发小时聚合
func (j *StatisticsCronJob) TriggerHourlyAggregation(targetHour time.Time) error {
ctx := context.Background()
j.logger.Info("手动触发小时聚合任务", zap.Time("target_hour", targetHour))
err := j.appService.ProcessHourlyAggregation(ctx, targetHour)
if err != nil {
j.logger.Error("手动小时聚合任务执行失败",
zap.Time("target_hour", targetHour),
zap.Error(err))
return err
}
j.logger.Info("手动小时聚合任务执行成功", zap.Time("target_hour", targetHour))
return nil
}
// TriggerDailyAggregation 手动触发日聚合
func (j *StatisticsCronJob) TriggerDailyAggregation(targetDate time.Time) error {
ctx := context.Background()
j.logger.Info("手动触发日聚合任务", zap.Time("target_date", targetDate))
err := j.appService.ProcessDailyAggregation(ctx, targetDate)
if err != nil {
j.logger.Error("手动日聚合任务执行失败",
zap.Time("target_date", targetDate),
zap.Error(err))
return err
}
j.logger.Info("手动日聚合任务执行成功", zap.Time("target_date", targetDate))
return nil
}
// TriggerWeeklyAggregation 手动触发周聚合
func (j *StatisticsCronJob) TriggerWeeklyAggregation(targetWeek time.Time) error {
ctx := context.Background()
j.logger.Info("手动触发周聚合任务", zap.Time("target_week", targetWeek))
err := j.appService.ProcessWeeklyAggregation(ctx, targetWeek)
if err != nil {
j.logger.Error("手动周聚合任务执行失败",
zap.Time("target_week", targetWeek),
zap.Error(err))
return err
}
j.logger.Info("手动周聚合任务执行成功", zap.Time("target_week", targetWeek))
return nil
}
// TriggerMonthlyAggregation 手动触发月聚合
func (j *StatisticsCronJob) TriggerMonthlyAggregation(targetMonth time.Time) error {
ctx := context.Background()
j.logger.Info("手动触发月聚合任务", zap.Time("target_month", targetMonth))
err := j.appService.ProcessMonthlyAggregation(ctx, targetMonth)
if err != nil {
j.logger.Error("手动月聚合任务执行失败",
zap.Time("target_month", targetMonth),
zap.Error(err))
return err
}
j.logger.Info("手动月聚合任务执行成功", zap.Time("target_month", targetMonth))
return nil
}
// TriggerDataCleanup 手动触发数据清理
func (j *StatisticsCronJob) TriggerDataCleanup() error {
ctx := context.Background()
j.logger.Info("手动触发数据清理任务")
err := j.appService.CleanupExpiredData(ctx)
if err != nil {
j.logger.Error("手动数据清理任务执行失败", zap.Error(err))
return err
}
j.logger.Info("手动数据清理任务执行成功")
return nil
}
// TriggerCacheWarmup 手动触发缓存预热
func (j *StatisticsCronJob) TriggerCacheWarmup() error {
j.logger.Info("手动触发缓存预热任务")
// 预热仪表板缓存
err := j.warmupDashboardCache(context.Background())
if err != nil {
j.logger.Error("手动仪表板缓存预热失败", zap.Error(err))
}
// 预热实时指标缓存
err = j.warmupRealtimeMetricsCache(context.Background())
if err != nil {
j.logger.Error("手动实时指标缓存预热失败", zap.Error(err))
}
j.logger.Info("手动缓存预热任务执行完成")
return nil
}
// ================ 任务状态查询 ================
// GetCronEntries 获取定时任务条目
func (j *StatisticsCronJob) GetCronEntries() []cron.Entry {
return j.cron.Entries()
}
// GetNextRunTime 获取下次运行时间
func (j *StatisticsCronJob) GetNextRunTime() time.Time {
entries := j.cron.Entries()
if len(entries) == 0 {
return time.Time{}
}
// 返回最近的运行时间
nextRun := entries[0].Next
for _, entry := range entries[1:] {
if entry.Next.Before(nextRun) {
nextRun = entry.Next
}
}
return nextRun
}
// IsRunning 检查任务是否正在运行
func (j *StatisticsCronJob) IsRunning() bool {
return j.cron != nil
}

View File

@@ -0,0 +1,498 @@
package events
import (
"context"
"encoding/json"
"fmt"
"time"
"go.uber.org/zap"
"tyapi-server/internal/domains/statistics/events"
"tyapi-server/internal/domains/statistics/repositories"
"tyapi-server/internal/infrastructure/statistics/cache"
)
// StatisticsEventHandler 统计事件处理器
type StatisticsEventHandler struct {
metricRepo repositories.StatisticsRepository
reportRepo repositories.StatisticsReportRepository
dashboardRepo repositories.StatisticsDashboardRepository
cache *cache.RedisStatisticsCache
logger *zap.Logger
}
// NewStatisticsEventHandler 创建统计事件处理器
func NewStatisticsEventHandler(
metricRepo repositories.StatisticsRepository,
reportRepo repositories.StatisticsReportRepository,
dashboardRepo repositories.StatisticsDashboardRepository,
cache *cache.RedisStatisticsCache,
logger *zap.Logger,
) *StatisticsEventHandler {
return &StatisticsEventHandler{
metricRepo: metricRepo,
reportRepo: reportRepo,
dashboardRepo: dashboardRepo,
cache: cache,
logger: logger,
}
}
// HandleMetricCreatedEvent 处理指标创建事件
func (h *StatisticsEventHandler) HandleMetricCreatedEvent(ctx context.Context, event *events.MetricCreatedEvent) error {
h.logger.Info("处理指标创建事件",
zap.String("metric_id", event.MetricID),
zap.String("metric_type", event.MetricType),
zap.String("metric_name", event.MetricName),
zap.Float64("value", event.Value))
// 更新实时指标缓存
err := h.cache.UpdateRealtimeMetric(ctx, event.MetricType, event.MetricName, event.Value, 1*time.Hour)
if err != nil {
h.logger.Error("更新实时指标缓存失败", zap.Error(err))
// 不返回错误,避免影响主流程
}
// 清理相关缓存
err = h.cache.DeleteMetricsByType(ctx, event.MetricType)
if err != nil {
h.logger.Error("清理指标类型缓存失败", zap.Error(err))
}
return nil
}
// HandleMetricUpdatedEvent 处理指标更新事件
func (h *StatisticsEventHandler) HandleMetricUpdatedEvent(ctx context.Context, event *events.MetricUpdatedEvent) error {
h.logger.Info("处理指标更新事件",
zap.String("metric_id", event.MetricID),
zap.Float64("old_value", event.OldValue),
zap.Float64("new_value", event.NewValue))
// 获取指标信息
metric, err := h.metricRepo.FindByID(ctx, event.MetricID)
if err != nil {
h.logger.Error("查询指标失败", zap.Error(err))
return err
}
// 更新实时指标缓存
err = h.cache.UpdateRealtimeMetric(ctx, metric.MetricType, metric.MetricName, event.NewValue, 1*time.Hour)
if err != nil {
h.logger.Error("更新实时指标缓存失败", zap.Error(err))
}
// 清理相关缓存
err = h.cache.DeleteMetric(ctx, event.MetricID)
if err != nil {
h.logger.Error("清理指标缓存失败", zap.Error(err))
}
err = h.cache.DeleteMetricsByType(ctx, metric.MetricType)
if err != nil {
h.logger.Error("清理指标类型缓存失败", zap.Error(err))
}
return nil
}
// HandleMetricAggregatedEvent 处理指标聚合事件
func (h *StatisticsEventHandler) HandleMetricAggregatedEvent(ctx context.Context, event *events.MetricAggregatedEvent) error {
h.logger.Info("处理指标聚合事件",
zap.String("metric_type", event.MetricType),
zap.String("dimension", event.Dimension),
zap.Int("record_count", event.RecordCount),
zap.Float64("total_value", event.TotalValue))
// 清理相关缓存
err := h.cache.ClearRealtimeCache(ctx)
if err != nil {
h.logger.Error("清理实时缓存失败", zap.Error(err))
}
err = h.cache.ClearMetricsCache(ctx)
if err != nil {
h.logger.Error("清理指标缓存失败", zap.Error(err))
}
return nil
}
// HandleReportCreatedEvent 处理报告创建事件
func (h *StatisticsEventHandler) HandleReportCreatedEvent(ctx context.Context, event *events.ReportCreatedEvent) error {
h.logger.Info("处理报告创建事件",
zap.String("report_id", event.ReportID),
zap.String("report_type", event.ReportType),
zap.String("title", event.Title))
// 获取报告信息
report, err := h.reportRepo.FindByID(ctx, event.ReportID)
if err != nil {
h.logger.Error("查询报告失败", zap.Error(err))
return err
}
// 设置报告缓存
err = h.cache.SetReport(ctx, report, 24*time.Hour)
if err != nil {
h.logger.Error("设置报告缓存失败", zap.Error(err))
}
return nil
}
// HandleReportGenerationStartedEvent 处理报告生成开始事件
func (h *StatisticsEventHandler) HandleReportGenerationStartedEvent(ctx context.Context, event *events.ReportGenerationStartedEvent) error {
h.logger.Info("处理报告生成开始事件",
zap.String("report_id", event.ReportID),
zap.String("generated_by", event.GeneratedBy))
// 获取报告信息
report, err := h.reportRepo.FindByID(ctx, event.ReportID)
if err != nil {
h.logger.Error("查询报告失败", zap.Error(err))
return err
}
// 更新报告缓存
err = h.cache.SetReport(ctx, report, 24*time.Hour)
if err != nil {
h.logger.Error("更新报告缓存失败", zap.Error(err))
}
return nil
}
// HandleReportCompletedEvent 处理报告完成事件
func (h *StatisticsEventHandler) HandleReportCompletedEvent(ctx context.Context, event *events.ReportCompletedEvent) error {
h.logger.Info("处理报告完成事件",
zap.String("report_id", event.ReportID),
zap.Int("content_size", event.ContentSize))
// 获取报告信息
report, err := h.reportRepo.FindByID(ctx, event.ReportID)
if err != nil {
h.logger.Error("查询报告失败", zap.Error(err))
return err
}
// 更新报告缓存
err = h.cache.SetReport(ctx, report, 7*24*time.Hour) // 报告完成后缓存7天
if err != nil {
h.logger.Error("更新报告缓存失败", zap.Error(err))
}
return nil
}
// HandleReportFailedEvent 处理报告失败事件
func (h *StatisticsEventHandler) HandleReportFailedEvent(ctx context.Context, event *events.ReportFailedEvent) error {
h.logger.Info("处理报告失败事件",
zap.String("report_id", event.ReportID),
zap.String("reason", event.Reason))
// 获取报告信息
report, err := h.reportRepo.FindByID(ctx, event.ReportID)
if err != nil {
h.logger.Error("查询报告失败", zap.Error(err))
return err
}
// 更新报告缓存
err = h.cache.SetReport(ctx, report, 1*time.Hour) // 失败报告只缓存1小时
if err != nil {
h.logger.Error("更新报告缓存失败", zap.Error(err))
}
return nil
}
// HandleDashboardCreatedEvent 处理仪表板创建事件
func (h *StatisticsEventHandler) HandleDashboardCreatedEvent(ctx context.Context, event *events.DashboardCreatedEvent) error {
h.logger.Info("处理仪表板创建事件",
zap.String("dashboard_id", event.DashboardID),
zap.String("name", event.Name),
zap.String("user_role", event.UserRole))
// 获取仪表板信息
dashboard, err := h.dashboardRepo.FindByID(ctx, event.DashboardID)
if err != nil {
h.logger.Error("查询仪表板失败", zap.Error(err))
return err
}
// 设置仪表板缓存
err = h.cache.SetDashboard(ctx, dashboard, 24*time.Hour)
if err != nil {
h.logger.Error("设置仪表板缓存失败", zap.Error(err))
}
// 清理仪表板数据缓存
err = h.cache.DeleteDashboardData(ctx, event.UserRole)
if err != nil {
h.logger.Error("清理仪表板数据缓存失败", zap.Error(err))
}
return nil
}
// HandleDashboardUpdatedEvent 处理仪表板更新事件
func (h *StatisticsEventHandler) HandleDashboardUpdatedEvent(ctx context.Context, event *events.DashboardUpdatedEvent) error {
h.logger.Info("处理仪表板更新事件",
zap.String("dashboard_id", event.DashboardID),
zap.String("updated_by", event.UpdatedBy))
// 获取仪表板信息
dashboard, err := h.dashboardRepo.FindByID(ctx, event.DashboardID)
if err != nil {
h.logger.Error("查询仪表板失败", zap.Error(err))
return err
}
// 更新仪表板缓存
err = h.cache.SetDashboard(ctx, dashboard, 24*time.Hour)
if err != nil {
h.logger.Error("更新仪表板缓存失败", zap.Error(err))
}
// 清理仪表板数据缓存
err = h.cache.DeleteDashboardData(ctx, dashboard.UserRole)
if err != nil {
h.logger.Error("清理仪表板数据缓存失败", zap.Error(err))
}
return nil
}
// HandleDashboardActivatedEvent 处理仪表板激活事件
func (h *StatisticsEventHandler) HandleDashboardActivatedEvent(ctx context.Context, event *events.DashboardActivatedEvent) error {
h.logger.Info("处理仪表板激活事件",
zap.String("dashboard_id", event.DashboardID),
zap.String("activated_by", event.ActivatedBy))
// 获取仪表板信息
dashboard, err := h.dashboardRepo.FindByID(ctx, event.DashboardID)
if err != nil {
h.logger.Error("查询仪表板失败", zap.Error(err))
return err
}
// 更新仪表板缓存
err = h.cache.SetDashboard(ctx, dashboard, 24*time.Hour)
if err != nil {
h.logger.Error("更新仪表板缓存失败", zap.Error(err))
}
// 清理仪表板数据缓存
err = h.cache.DeleteDashboardData(ctx, dashboard.UserRole)
if err != nil {
h.logger.Error("清理仪表板数据缓存失败", zap.Error(err))
}
return nil
}
// HandleDashboardDeactivatedEvent 处理仪表板停用事件
func (h *StatisticsEventHandler) HandleDashboardDeactivatedEvent(ctx context.Context, event *events.DashboardDeactivatedEvent) error {
h.logger.Info("处理仪表板停用事件",
zap.String("dashboard_id", event.DashboardID),
zap.String("deactivated_by", event.DeactivatedBy))
// 获取仪表板信息
dashboard, err := h.dashboardRepo.FindByID(ctx, event.DashboardID)
if err != nil {
h.logger.Error("查询仪表板失败", zap.Error(err))
return err
}
// 更新仪表板缓存
err = h.cache.SetDashboard(ctx, dashboard, 24*time.Hour)
if err != nil {
h.logger.Error("更新仪表板缓存失败", zap.Error(err))
}
// 清理仪表板数据缓存
err = h.cache.DeleteDashboardData(ctx, dashboard.UserRole)
if err != nil {
h.logger.Error("清理仪表板数据缓存失败", zap.Error(err))
}
return nil
}
// ================ 事件分发器 ================
// EventDispatcher 事件分发器
type EventDispatcher struct {
handlers map[string][]func(context.Context, interface{}) error
logger *zap.Logger
}
// NewEventDispatcher 创建事件分发器
func NewEventDispatcher(logger *zap.Logger) *EventDispatcher {
return &EventDispatcher{
handlers: make(map[string][]func(context.Context, interface{}) error),
logger: logger,
}
}
// RegisterHandler 注册事件处理器
func (d *EventDispatcher) RegisterHandler(eventType string, handler func(context.Context, interface{}) error) {
if d.handlers[eventType] == nil {
d.handlers[eventType] = make([]func(context.Context, interface{}) error, 0)
}
d.handlers[eventType] = append(d.handlers[eventType], handler)
}
// Dispatch 分发事件
func (d *EventDispatcher) Dispatch(ctx context.Context, event interface{}) error {
// 获取事件类型
eventType := d.getEventType(event)
if eventType == "" {
return fmt.Errorf("无法确定事件类型")
}
// 获取处理器
handlers := d.handlers[eventType]
if len(handlers) == 0 {
d.logger.Warn("没有找到事件处理器", zap.String("event_type", eventType))
return nil
}
// 执行所有处理器
for _, handler := range handlers {
err := handler(ctx, event)
if err != nil {
d.logger.Error("事件处理器执行失败",
zap.String("event_type", eventType),
zap.Error(err))
// 继续执行其他处理器
}
}
return nil
}
// getEventType 获取事件类型
func (d *EventDispatcher) getEventType(event interface{}) string {
switch event.(type) {
case *events.MetricCreatedEvent:
return string(events.MetricCreatedEventType)
case *events.MetricUpdatedEvent:
return string(events.MetricUpdatedEventType)
case *events.MetricAggregatedEvent:
return string(events.MetricAggregatedEventType)
case *events.ReportCreatedEvent:
return string(events.ReportCreatedEventType)
case *events.ReportGenerationStartedEvent:
return string(events.ReportGenerationStartedEventType)
case *events.ReportCompletedEvent:
return string(events.ReportCompletedEventType)
case *events.ReportFailedEvent:
return string(events.ReportFailedEventType)
case *events.DashboardCreatedEvent:
return string(events.DashboardCreatedEventType)
case *events.DashboardUpdatedEvent:
return string(events.DashboardUpdatedEventType)
case *events.DashboardActivatedEvent:
return string(events.DashboardActivatedEventType)
case *events.DashboardDeactivatedEvent:
return string(events.DashboardDeactivatedEventType)
default:
return ""
}
}
// ================ 事件监听器 ================
// EventListener 事件监听器
type EventListener struct {
dispatcher *EventDispatcher
logger *zap.Logger
}
// NewEventListener 创建事件监听器
func NewEventListener(dispatcher *EventDispatcher, logger *zap.Logger) *EventListener {
return &EventListener{
dispatcher: dispatcher,
logger: logger,
}
}
// Listen 监听事件
func (l *EventListener) Listen(ctx context.Context, eventData []byte) error {
// 解析事件数据
var baseEvent events.BaseStatisticsEvent
err := json.Unmarshal(eventData, &baseEvent)
if err != nil {
return fmt.Errorf("解析事件数据失败: %w", err)
}
// 根据事件类型创建具体事件
event, err := l.createEventByType(baseEvent.Type, eventData)
if err != nil {
return fmt.Errorf("创建事件失败: %w", err)
}
// 分发事件
err = l.dispatcher.Dispatch(ctx, event)
if err != nil {
return fmt.Errorf("分发事件失败: %w", err)
}
return nil
}
// createEventByType 根据事件类型创建具体事件
func (l *EventListener) createEventByType(eventType string, eventData []byte) (interface{}, error) {
switch eventType {
case string(events.MetricCreatedEventType):
var event events.MetricCreatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.MetricUpdatedEventType):
var event events.MetricUpdatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.MetricAggregatedEventType):
var event events.MetricAggregatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.ReportCreatedEventType):
var event events.ReportCreatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.ReportGenerationStartedEventType):
var event events.ReportGenerationStartedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.ReportCompletedEventType):
var event events.ReportCompletedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.ReportFailedEventType):
var event events.ReportFailedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.DashboardCreatedEventType):
var event events.DashboardCreatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.DashboardUpdatedEventType):
var event events.DashboardUpdatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.DashboardActivatedEventType):
var event events.DashboardActivatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
case string(events.DashboardDeactivatedEventType):
var event events.DashboardDeactivatedEvent
err := json.Unmarshal(eventData, &event)
return &event, err
default:
return nil, fmt.Errorf("未知的事件类型: %s", eventType)
}
}

View File

@@ -0,0 +1,557 @@
package migrations
import (
"fmt"
"time"
"gorm.io/gorm"
"tyapi-server/internal/domains/statistics/entities"
)
// StatisticsMigration 统计模块数据迁移
type StatisticsMigration struct {
db *gorm.DB
}
// NewStatisticsMigration 创建统计模块数据迁移
func NewStatisticsMigration(db *gorm.DB) *StatisticsMigration {
return &StatisticsMigration{
db: db,
}
}
// Migrate 执行数据迁移
func (m *StatisticsMigration) Migrate() error {
fmt.Println("开始执行统计模块数据迁移...")
// 迁移统计指标表
err := m.migrateStatisticsMetrics()
if err != nil {
return fmt.Errorf("迁移统计指标表失败: %w", err)
}
// 迁移统计报告表
err = m.migrateStatisticsReports()
if err != nil {
return fmt.Errorf("迁移统计报告表失败: %w", err)
}
// 迁移统计仪表板表
err = m.migrateStatisticsDashboards()
if err != nil {
return fmt.Errorf("迁移统计仪表板表失败: %w", err)
}
// 创建索引
err = m.createIndexes()
if err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
// 插入初始数据
err = m.insertInitialData()
if err != nil {
return fmt.Errorf("插入初始数据失败: %w", err)
}
fmt.Println("统计模块数据迁移完成")
return nil
}
// migrateStatisticsMetrics 迁移统计指标表
func (m *StatisticsMigration) migrateStatisticsMetrics() error {
fmt.Println("迁移统计指标表...")
// 自动迁移表结构
err := m.db.AutoMigrate(&entities.StatisticsMetric{})
if err != nil {
return fmt.Errorf("自动迁移统计指标表失败: %w", err)
}
fmt.Println("统计指标表迁移完成")
return nil
}
// migrateStatisticsReports 迁移统计报告表
func (m *StatisticsMigration) migrateStatisticsReports() error {
fmt.Println("迁移统计报告表...")
// 自动迁移表结构
err := m.db.AutoMigrate(&entities.StatisticsReport{})
if err != nil {
return fmt.Errorf("自动迁移统计报告表失败: %w", err)
}
fmt.Println("统计报告表迁移完成")
return nil
}
// migrateStatisticsDashboards 迁移统计仪表板表
func (m *StatisticsMigration) migrateStatisticsDashboards() error {
fmt.Println("迁移统计仪表板表...")
// 自动迁移表结构
err := m.db.AutoMigrate(&entities.StatisticsDashboard{})
if err != nil {
return fmt.Errorf("自动迁移统计仪表板表失败: %w", err)
}
fmt.Println("统计仪表板表迁移完成")
return nil
}
// createIndexes 创建索引
func (m *StatisticsMigration) createIndexes() error {
fmt.Println("创建统计模块索引...")
// 统计指标表索引
err := m.createStatisticsMetricsIndexes()
if err != nil {
return fmt.Errorf("创建统计指标表索引失败: %w", err)
}
// 统计报告表索引
err = m.createStatisticsReportsIndexes()
if err != nil {
return fmt.Errorf("创建统计报告表索引失败: %w", err)
}
// 统计仪表板表索引
err = m.createStatisticsDashboardsIndexes()
if err != nil {
return fmt.Errorf("创建统计仪表板表索引失败: %w", err)
}
fmt.Println("统计模块索引创建完成")
return nil
}
// createStatisticsMetricsIndexes 创建统计指标表索引
func (m *StatisticsMigration) createStatisticsMetricsIndexes() error {
// 复合索引metric_type + date
err := m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_metrics_type_date
ON statistics_metrics (metric_type, date)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 复合索引metric_type + dimension + date
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_metrics_type_dimension_date
ON statistics_metrics (metric_type, dimension, date)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 复合索引metric_type + metric_name + date
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_metrics_type_name_date
ON statistics_metrics (metric_type, metric_name, date)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 单列索引dimension
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_metrics_dimension
ON statistics_metrics (dimension)
`).Error
if err != nil {
return fmt.Errorf("创建维度索引失败: %w", err)
}
return nil
}
// createStatisticsReportsIndexes 创建统计报告表索引
func (m *StatisticsMigration) createStatisticsReportsIndexes() error {
// 复合索引report_type + created_at
err := m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_reports_type_created
ON statistics_reports (report_type, created_at)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 复合索引user_role + created_at
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_reports_role_created
ON statistics_reports (user_role, created_at)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 复合索引status + created_at
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_reports_status_created
ON statistics_reports (status, created_at)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 单列索引generated_by
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_reports_generated_by
ON statistics_reports (generated_by)
`).Error
if err != nil {
return fmt.Errorf("创建生成者索引失败: %w", err)
}
// 单列索引expires_at
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_reports_expires_at
ON statistics_reports (expires_at)
`).Error
if err != nil {
return fmt.Errorf("创建过期时间索引失败: %w", err)
}
return nil
}
// createStatisticsDashboardsIndexes 创建统计仪表板表索引
func (m *StatisticsMigration) createStatisticsDashboardsIndexes() error {
// 复合索引user_role + is_active
err := m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_role_active
ON statistics_dashboards (user_role, is_active)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 复合索引user_role + is_default
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_role_default
ON statistics_dashboards (user_role, is_default)
`).Error
if err != nil {
return fmt.Errorf("创建复合索引失败: %w", err)
}
// 单列索引created_by
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_created_by
ON statistics_dashboards (created_by)
`).Error
if err != nil {
return fmt.Errorf("创建创建者索引失败: %w", err)
}
// 单列索引access_level
err = m.db.Exec(`
CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_access_level
ON statistics_dashboards (access_level)
`).Error
if err != nil {
return fmt.Errorf("创建访问级别索引失败: %w", err)
}
return nil
}
// insertInitialData 插入初始数据
func (m *StatisticsMigration) insertInitialData() error {
fmt.Println("插入统计模块初始数据...")
// 插入默认仪表板
err := m.insertDefaultDashboards()
if err != nil {
return fmt.Errorf("插入默认仪表板失败: %w", err)
}
// 插入初始指标数据
err = m.insertInitialMetrics()
if err != nil {
return fmt.Errorf("插入初始指标数据失败: %w", err)
}
fmt.Println("统计模块初始数据插入完成")
return nil
}
// insertDefaultDashboards 插入默认仪表板
func (m *StatisticsMigration) insertDefaultDashboards() error {
// 管理员默认仪表板
adminDashboard := &entities.StatisticsDashboard{
Name: "管理员仪表板",
Description: "系统管理员专用仪表板,包含所有统计信息",
UserRole: "admin",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 300,
CreatedBy: "system",
Layout: `{"columns": 3, "rows": 4}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}, {"type": "finance", "position": {"x": 2, "y": 0}}]`,
Settings: `{"theme": "dark", "auto_refresh": true}`,
}
err := m.db.Create(adminDashboard).Error
if err != nil {
return fmt.Errorf("创建管理员仪表板失败: %w", err)
}
// 用户默认仪表板
userDashboard := &entities.StatisticsDashboard{
Name: "用户仪表板",
Description: "普通用户专用仪表板,包含基础统计信息",
UserRole: "user",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 600,
CreatedBy: "system",
Layout: `{"columns": 2, "rows": 3}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}]`,
Settings: `{"theme": "light", "auto_refresh": false}`,
}
err = m.db.Create(userDashboard).Error
if err != nil {
return fmt.Errorf("创建用户仪表板失败: %w", err)
}
// 经理默认仪表板
managerDashboard := &entities.StatisticsDashboard{
Name: "经理仪表板",
Description: "经理专用仪表板,包含管理相关统计信息",
UserRole: "manager",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 300,
CreatedBy: "system",
Layout: `{"columns": 3, "rows": 3}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}, {"type": "finance", "position": {"x": 2, "y": 0}}]`,
Settings: `{"theme": "dark", "auto_refresh": true}`,
}
err = m.db.Create(managerDashboard).Error
if err != nil {
return fmt.Errorf("创建经理仪表板失败: %w", err)
}
// 分析师默认仪表板
analystDashboard := &entities.StatisticsDashboard{
Name: "分析师仪表板",
Description: "数据分析师专用仪表板,包含详细分析信息",
UserRole: "analyst",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 180,
CreatedBy: "system",
Layout: `{"columns": 4, "rows": 4}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}, {"type": "finance", "position": {"x": 2, "y": 0}}, {"type": "products", "position": {"x": 3, "y": 0}}]`,
Settings: `{"theme": "dark", "auto_refresh": true, "show_trends": true}`,
}
err = m.db.Create(analystDashboard).Error
if err != nil {
return fmt.Errorf("创建分析师仪表板失败: %w", err)
}
fmt.Println("默认仪表板创建完成")
return nil
}
// insertInitialMetrics 插入初始指标数据
func (m *StatisticsMigration) insertInitialMetrics() error {
now := time.Now()
today := now.Truncate(24 * time.Hour)
// 插入初始API调用指标
apiMetrics := []*entities.StatisticsMetric{
{
MetricType: "api_calls",
MetricName: "total_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "api_calls",
MetricName: "success_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "api_calls",
MetricName: "failed_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "api_calls",
MetricName: "response_time",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始用户指标
userMetrics := []*entities.StatisticsMetric{
{
MetricType: "users",
MetricName: "total_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "users",
MetricName: "certified_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "users",
MetricName: "active_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始财务指标
financeMetrics := []*entities.StatisticsMetric{
{
MetricType: "finance",
MetricName: "total_amount",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "finance",
MetricName: "recharge_amount",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "finance",
MetricName: "deduct_amount",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始产品指标
productMetrics := []*entities.StatisticsMetric{
{
MetricType: "products",
MetricName: "total_products",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "products",
MetricName: "active_products",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "products",
MetricName: "total_subscriptions",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "products",
MetricName: "active_subscriptions",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始认证指标
certificationMetrics := []*entities.StatisticsMetric{
{
MetricType: "certification",
MetricName: "total_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "certification",
MetricName: "completed_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "certification",
MetricName: "pending_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "certification",
MetricName: "failed_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 批量插入所有指标
allMetrics := append(apiMetrics, userMetrics...)
allMetrics = append(allMetrics, financeMetrics...)
allMetrics = append(allMetrics, productMetrics...)
allMetrics = append(allMetrics, certificationMetrics...)
err := m.db.CreateInBatches(allMetrics, 100).Error
if err != nil {
return fmt.Errorf("批量插入初始指标失败: %w", err)
}
fmt.Println("初始指标数据创建完成")
return nil
}
// Rollback 回滚迁移
func (m *StatisticsMigration) Rollback() error {
fmt.Println("开始回滚统计模块数据迁移...")
// 删除表
err := m.db.Migrator().DropTable(&entities.StatisticsDashboard{})
if err != nil {
return fmt.Errorf("删除统计仪表板表失败: %w", err)
}
err = m.db.Migrator().DropTable(&entities.StatisticsReport{})
if err != nil {
return fmt.Errorf("删除统计报告表失败: %w", err)
}
err = m.db.Migrator().DropTable(&entities.StatisticsMetric{})
if err != nil {
return fmt.Errorf("删除统计指标表失败: %w", err)
}
fmt.Println("统计模块数据迁移回滚完成")
return nil
}

View File

@@ -0,0 +1,590 @@
package migrations
import (
"fmt"
"time"
"gorm.io/gorm"
"tyapi-server/internal/domains/statistics/entities"
)
// StatisticsMigrationComplete 统计模块完整数据迁移
type StatisticsMigrationComplete struct {
db *gorm.DB
}
// NewStatisticsMigrationComplete 创建统计模块完整数据迁移
func NewStatisticsMigrationComplete(db *gorm.DB) *StatisticsMigrationComplete {
return &StatisticsMigrationComplete{
db: db,
}
}
// Migrate 执行完整的数据迁移
func (m *StatisticsMigrationComplete) Migrate() error {
fmt.Println("开始执行统计模块完整数据迁移...")
// 1. 迁移表结构
err := m.migrateTables()
if err != nil {
return fmt.Errorf("迁移表结构失败: %w", err)
}
// 2. 创建索引
err = m.createIndexes()
if err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
// 3. 插入初始数据
err = m.insertInitialData()
if err != nil {
return fmt.Errorf("插入初始数据失败: %w", err)
}
fmt.Println("统计模块完整数据迁移完成")
return nil
}
// migrateTables 迁移表结构
func (m *StatisticsMigrationComplete) migrateTables() error {
fmt.Println("迁移统计模块表结构...")
// 迁移统计指标表
err := m.db.AutoMigrate(&entities.StatisticsMetric{})
if err != nil {
return fmt.Errorf("迁移统计指标表失败: %w", err)
}
// 迁移统计报告表
err = m.db.AutoMigrate(&entities.StatisticsReport{})
if err != nil {
return fmt.Errorf("迁移统计报告表失败: %w", err)
}
// 迁移统计仪表板表
err = m.db.AutoMigrate(&entities.StatisticsDashboard{})
if err != nil {
return fmt.Errorf("迁移统计仪表板表失败: %w", err)
}
fmt.Println("统计模块表结构迁移完成")
return nil
}
// createIndexes 创建索引
func (m *StatisticsMigrationComplete) createIndexes() error {
fmt.Println("创建统计模块索引...")
// 统计指标表索引
err := m.createStatisticsMetricsIndexes()
if err != nil {
return fmt.Errorf("创建统计指标表索引失败: %w", err)
}
// 统计报告表索引
err = m.createStatisticsReportsIndexes()
if err != nil {
return fmt.Errorf("创建统计报告表索引失败: %w", err)
}
// 统计仪表板表索引
err = m.createStatisticsDashboardsIndexes()
if err != nil {
return fmt.Errorf("创建统计仪表板表索引失败: %w", err)
}
fmt.Println("统计模块索引创建完成")
return nil
}
// createStatisticsMetricsIndexes 创建统计指标表索引
func (m *StatisticsMigrationComplete) createStatisticsMetricsIndexes() error {
indexes := []string{
// 复合索引metric_type + date
`CREATE INDEX IF NOT EXISTS idx_statistics_metrics_type_date
ON statistics_metrics (metric_type, date)`,
// 复合索引metric_type + dimension + date
`CREATE INDEX IF NOT EXISTS idx_statistics_metrics_type_dimension_date
ON statistics_metrics (metric_type, dimension, date)`,
// 复合索引metric_type + metric_name + date
`CREATE INDEX IF NOT EXISTS idx_statistics_metrics_type_name_date
ON statistics_metrics (metric_type, metric_name, date)`,
// 单列索引dimension
`CREATE INDEX IF NOT EXISTS idx_statistics_metrics_dimension
ON statistics_metrics (dimension)`,
// 单列索引metric_name
`CREATE INDEX IF NOT EXISTS idx_statistics_metrics_name
ON statistics_metrics (metric_name)`,
// 单列索引value用于范围查询
`CREATE INDEX IF NOT EXISTS idx_statistics_metrics_value
ON statistics_metrics (value)`,
}
for _, indexSQL := range indexes {
err := m.db.Exec(indexSQL).Error
if err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
}
return nil
}
// createStatisticsReportsIndexes 创建统计报告表索引
func (m *StatisticsMigrationComplete) createStatisticsReportsIndexes() error {
indexes := []string{
// 复合索引report_type + created_at
`CREATE INDEX IF NOT EXISTS idx_statistics_reports_type_created
ON statistics_reports (report_type, created_at)`,
// 复合索引user_role + created_at
`CREATE INDEX IF NOT EXISTS idx_statistics_reports_role_created
ON statistics_reports (user_role, created_at)`,
// 复合索引status + created_at
`CREATE INDEX IF NOT EXISTS idx_statistics_reports_status_created
ON statistics_reports (status, created_at)`,
// 单列索引generated_by
`CREATE INDEX IF NOT EXISTS idx_statistics_reports_generated_by
ON statistics_reports (generated_by)`,
// 单列索引expires_at
`CREATE INDEX IF NOT EXISTS idx_statistics_reports_expires_at
ON statistics_reports (expires_at)`,
// 单列索引period
`CREATE INDEX IF NOT EXISTS idx_statistics_reports_period
ON statistics_reports (period)`,
}
for _, indexSQL := range indexes {
err := m.db.Exec(indexSQL).Error
if err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
}
return nil
}
// createStatisticsDashboardsIndexes 创建统计仪表板表索引
func (m *StatisticsMigrationComplete) createStatisticsDashboardsIndexes() error {
indexes := []string{
// 复合索引user_role + is_active
`CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_role_active
ON statistics_dashboards (user_role, is_active)`,
// 复合索引user_role + is_default
`CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_role_default
ON statistics_dashboards (user_role, is_default)`,
// 单列索引created_by
`CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_created_by
ON statistics_dashboards (created_by)`,
// 单列索引access_level
`CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_access_level
ON statistics_dashboards (access_level)`,
// 单列索引name用于搜索
`CREATE INDEX IF NOT EXISTS idx_statistics_dashboards_name
ON statistics_dashboards (name)`,
}
for _, indexSQL := range indexes {
err := m.db.Exec(indexSQL).Error
if err != nil {
return fmt.Errorf("创建索引失败: %w", err)
}
}
return nil
}
// insertInitialData 插入初始数据
func (m *StatisticsMigrationComplete) insertInitialData() error {
fmt.Println("插入统计模块初始数据...")
// 插入默认仪表板
err := m.insertDefaultDashboards()
if err != nil {
return fmt.Errorf("插入默认仪表板失败: %w", err)
}
// 插入初始指标数据
err = m.insertInitialMetrics()
if err != nil {
return fmt.Errorf("插入初始指标数据失败: %w", err)
}
fmt.Println("统计模块初始数据插入完成")
return nil
}
// insertDefaultDashboards 插入默认仪表板
func (m *StatisticsMigrationComplete) insertDefaultDashboards() error {
// 检查是否已存在默认仪表板
var count int64
err := m.db.Model(&entities.StatisticsDashboard{}).Where("is_default = ?", true).Count(&count).Error
if err != nil {
return fmt.Errorf("检查默认仪表板失败: %w", err)
}
// 如果已存在默认仪表板,跳过插入
if count > 0 {
fmt.Println("默认仪表板已存在,跳过插入")
return nil
}
// 管理员默认仪表板
adminDashboard := &entities.StatisticsDashboard{
Name: "管理员仪表板",
Description: "系统管理员专用仪表板,包含所有统计信息",
UserRole: "admin",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 300,
CreatedBy: "system",
Layout: `{"columns": 3, "rows": 4}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}, {"type": "finance", "position": {"x": 2, "y": 0}}]`,
Settings: `{"theme": "dark", "auto_refresh": true}`,
}
err = m.db.Create(adminDashboard).Error
if err != nil {
return fmt.Errorf("创建管理员仪表板失败: %w", err)
}
// 用户默认仪表板
userDashboard := &entities.StatisticsDashboard{
Name: "用户仪表板",
Description: "普通用户专用仪表板,包含基础统计信息",
UserRole: "user",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 600,
CreatedBy: "system",
Layout: `{"columns": 2, "rows": 3}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}]`,
Settings: `{"theme": "light", "auto_refresh": false}`,
}
err = m.db.Create(userDashboard).Error
if err != nil {
return fmt.Errorf("创建用户仪表板失败: %w", err)
}
// 经理默认仪表板
managerDashboard := &entities.StatisticsDashboard{
Name: "经理仪表板",
Description: "经理专用仪表板,包含管理相关统计信息",
UserRole: "manager",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 300,
CreatedBy: "system",
Layout: `{"columns": 3, "rows": 3}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}, {"type": "finance", "position": {"x": 2, "y": 0}}]`,
Settings: `{"theme": "dark", "auto_refresh": true}`,
}
err = m.db.Create(managerDashboard).Error
if err != nil {
return fmt.Errorf("创建经理仪表板失败: %w", err)
}
// 分析师默认仪表板
analystDashboard := &entities.StatisticsDashboard{
Name: "分析师仪表板",
Description: "数据分析师专用仪表板,包含详细分析信息",
UserRole: "analyst",
IsDefault: true,
IsActive: true,
AccessLevel: "private",
RefreshInterval: 180,
CreatedBy: "system",
Layout: `{"columns": 4, "rows": 4}`,
Widgets: `[{"type": "api_calls", "position": {"x": 0, "y": 0}}, {"type": "users", "position": {"x": 1, "y": 0}}, {"type": "finance", "position": {"x": 2, "y": 0}}, {"type": "products", "position": {"x": 3, "y": 0}}]`,
Settings: `{"theme": "dark", "auto_refresh": true, "show_trends": true}`,
}
err = m.db.Create(analystDashboard).Error
if err != nil {
return fmt.Errorf("创建分析师仪表板失败: %w", err)
}
fmt.Println("默认仪表板创建完成")
return nil
}
// insertInitialMetrics 插入初始指标数据
func (m *StatisticsMigrationComplete) insertInitialMetrics() error {
now := time.Now()
today := now.Truncate(24 * time.Hour)
// 检查是否已存在今日指标数据
var count int64
err := m.db.Model(&entities.StatisticsMetric{}).Where("date = ?", today).Count(&count).Error
if err != nil {
return fmt.Errorf("检查指标数据失败: %w", err)
}
// 如果已存在今日指标数据,跳过插入
if count > 0 {
fmt.Println("今日指标数据已存在,跳过插入")
return nil
}
// 插入初始API调用指标
apiMetrics := []*entities.StatisticsMetric{
{
MetricType: "api_calls",
MetricName: "total_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "api_calls",
MetricName: "success_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "api_calls",
MetricName: "failed_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "api_calls",
MetricName: "response_time",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "api_calls",
MetricName: "avg_response_time",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始用户指标
userMetrics := []*entities.StatisticsMetric{
{
MetricType: "users",
MetricName: "total_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "users",
MetricName: "certified_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "users",
MetricName: "active_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "users",
MetricName: "new_users_today",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始财务指标
financeMetrics := []*entities.StatisticsMetric{
{
MetricType: "finance",
MetricName: "total_amount",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "finance",
MetricName: "recharge_amount",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "finance",
MetricName: "deduct_amount",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "finance",
MetricName: "recharge_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "finance",
MetricName: "deduct_count",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始产品指标
productMetrics := []*entities.StatisticsMetric{
{
MetricType: "products",
MetricName: "total_products",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "products",
MetricName: "active_products",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "products",
MetricName: "total_subscriptions",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "products",
MetricName: "active_subscriptions",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "products",
MetricName: "new_subscriptions_today",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 插入初始认证指标
certificationMetrics := []*entities.StatisticsMetric{
{
MetricType: "certification",
MetricName: "total_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "certification",
MetricName: "completed_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "certification",
MetricName: "pending_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "certification",
MetricName: "failed_certifications",
Dimension: "realtime",
Value: 0,
Date: today,
},
{
MetricType: "certification",
MetricName: "certification_rate",
Dimension: "realtime",
Value: 0,
Date: today,
},
}
// 批量插入所有指标
allMetrics := append(apiMetrics, userMetrics...)
allMetrics = append(allMetrics, financeMetrics...)
allMetrics = append(allMetrics, productMetrics...)
allMetrics = append(allMetrics, certificationMetrics...)
err = m.db.CreateInBatches(allMetrics, 100).Error
if err != nil {
return fmt.Errorf("批量插入初始指标失败: %w", err)
}
fmt.Println("初始指标数据创建完成")
return nil
}
// Rollback 回滚迁移
func (m *StatisticsMigrationComplete) Rollback() error {
fmt.Println("开始回滚统计模块数据迁移...")
// 删除表
err := m.db.Migrator().DropTable(&entities.StatisticsDashboard{})
if err != nil {
return fmt.Errorf("删除统计仪表板表失败: %w", err)
}
err = m.db.Migrator().DropTable(&entities.StatisticsReport{})
if err != nil {
return fmt.Errorf("删除统计报告表失败: %w", err)
}
err = m.db.Migrator().DropTable(&entities.StatisticsMetric{})
if err != nil {
return fmt.Errorf("删除统计指标表失败: %w", err)
}
fmt.Println("统计模块数据迁移回滚完成")
return nil
}
// GetTableInfo 获取表信息
func (m *StatisticsMigrationComplete) GetTableInfo() map[string]interface{} {
info := make(map[string]interface{})
// 获取表统计信息
tables := []string{"statistics_metrics", "statistics_reports", "statistics_dashboards"}
for _, table := range tables {
var count int64
m.db.Table(table).Count(&count)
info[table] = count
}
return info
}

View File

@@ -0,0 +1 @@

View File

@@ -1,97 +0,0 @@
package task
import (
"context"
"encoding/json"
"fmt"
"tyapi-server/internal/domains/article/repositories"
"github.com/hibiken/asynq"
"go.uber.org/zap"
)
// ArticlePublisher 文章发布接口
type ArticlePublisher interface {
PublishArticleByID(ctx context.Context, articleID string) error
}
// ArticleTaskHandler 文章任务处理器
type ArticleTaskHandler struct {
publisher ArticlePublisher
scheduledTaskRepo repositories.ScheduledTaskRepository
logger *zap.Logger
}
// NewArticleTaskHandler 创建文章任务处理器
func NewArticleTaskHandler(
publisher ArticlePublisher,
scheduledTaskRepo repositories.ScheduledTaskRepository,
logger *zap.Logger,
) *ArticleTaskHandler {
return &ArticleTaskHandler{
publisher: publisher,
scheduledTaskRepo: scheduledTaskRepo,
logger: logger,
}
}
// HandleArticlePublish 处理文章定时发布任务
func (h *ArticleTaskHandler) HandleArticlePublish(ctx context.Context, t *asynq.Task) error {
var payload map[string]interface{}
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析任务载荷失败", zap.Error(err))
return fmt.Errorf("解析任务载荷失败: %w", err)
}
articleID, ok := payload["article_id"].(string)
if !ok {
h.logger.Error("任务载荷中缺少文章ID")
return fmt.Errorf("任务载荷中缺少文章ID")
}
// 获取任务状态记录
task, err := h.scheduledTaskRepo.GetByTaskID(ctx, t.ResultWriter().TaskID())
if err != nil {
h.logger.Error("获取任务状态记录失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(err))
// 继续执行,不阻断任务
} else {
// 检查任务是否已取消
if task.IsCancelled() {
h.logger.Info("任务已取消,跳过执行", zap.String("task_id", t.ResultWriter().TaskID()))
return nil
}
// 标记任务为正在执行
task.MarkAsRunning()
if err := h.scheduledTaskRepo.Update(ctx, task); err != nil {
h.logger.Warn("更新任务状态失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(err))
}
}
// 执行文章发布
if err := h.publisher.PublishArticleByID(ctx, articleID); err != nil {
// 更新任务状态为失败
if task.ID != "" {
task.MarkAsFailed(err.Error())
if updateErr := h.scheduledTaskRepo.Update(ctx, task); updateErr != nil {
h.logger.Warn("更新任务失败状态失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(updateErr))
}
}
h.logger.Error("定时发布文章失败",
zap.String("article_id", articleID),
zap.Error(err))
return fmt.Errorf("定时发布文章失败: %w", err)
}
// 更新任务状态为已完成
if task.ID != "" {
task.MarkAsCompleted()
if err := h.scheduledTaskRepo.Update(ctx, task); err != nil {
h.logger.Warn("更新任务完成状态失败", zap.String("task_id", t.ResultWriter().TaskID()), zap.Error(err))
}
}
h.logger.Info("定时发布文章成功", zap.String("article_id", articleID))
return nil
}

View File

@@ -1,133 +0,0 @@
package task
import (
"context"
"encoding/json"
"fmt"
"time"
"tyapi-server/internal/domains/article/entities"
"tyapi-server/internal/domains/article/repositories"
"github.com/hibiken/asynq"
"go.uber.org/zap"
)
// AsynqClient Asynq 客户端
type AsynqClient struct {
client *asynq.Client
logger *zap.Logger
scheduledTaskRepo repositories.ScheduledTaskRepository
}
// NewAsynqClient 创建 Asynq 客户端
func NewAsynqClient(redisAddr string, scheduledTaskRepo repositories.ScheduledTaskRepository, logger *zap.Logger) *AsynqClient {
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
return &AsynqClient{
client: client,
logger: logger,
scheduledTaskRepo: scheduledTaskRepo,
}
}
// Close 关闭客户端
func (c *AsynqClient) Close() error {
return c.client.Close()
}
// ScheduleArticlePublish 调度文章定时发布任务
func (c *AsynqClient) ScheduleArticlePublish(ctx context.Context, articleID string, publishTime time.Time) (string, error) {
payload := map[string]interface{}{
"article_id": articleID,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
c.logger.Error("序列化任务载荷失败", zap.Error(err))
return "", fmt.Errorf("创建任务失败: %w", err)
}
task := asynq.NewTask(TaskTypeArticlePublish, payloadBytes)
// 计算延迟时间
delay := publishTime.Sub(time.Now())
if delay <= 0 {
return "", fmt.Errorf("定时发布时间不能早于当前时间")
}
// 设置任务选项
opts := []asynq.Option{
asynq.ProcessIn(delay),
asynq.MaxRetry(3),
asynq.Timeout(5 * time.Minute),
}
info, err := c.client.Enqueue(task, opts...)
if err != nil {
c.logger.Error("调度定时发布任务失败",
zap.String("article_id", articleID),
zap.Time("publish_time", publishTime),
zap.Error(err))
return "", fmt.Errorf("调度任务失败: %w", err)
}
// 创建任务状态记录
scheduledTask := entities.ScheduledTask{
TaskID: info.ID,
TaskType: TaskTypeArticlePublish,
ArticleID: articleID,
Status: entities.TaskStatusPending,
ScheduledAt: publishTime,
}
if _, err := c.scheduledTaskRepo.Create(ctx, scheduledTask); err != nil {
c.logger.Error("创建任务状态记录失败", zap.String("task_id", info.ID), zap.Error(err))
// 不返回错误因为Asynq任务已经创建成功
}
c.logger.Info("定时发布任务调度成功",
zap.String("article_id", articleID),
zap.Time("publish_time", publishTime),
zap.String("task_id", info.ID))
return info.ID, nil
}
// CancelScheduledTask 取消已调度的任务
func (c *AsynqClient) CancelScheduledTask(ctx context.Context, taskID string) error {
c.logger.Info("标记定时任务为已取消",
zap.String("task_id", taskID))
// 标记数据库中的任务状态为已取消
if err := c.scheduledTaskRepo.MarkAsCancelled(ctx, taskID); err != nil {
c.logger.Warn("标记任务状态为已取消失败", zap.String("task_id", taskID), zap.Error(err))
// 不返回错误因为Asynq任务可能已经执行完成
}
// Asynq不支持直接取消任务我们通过数据库状态来标记
// 任务执行时会检查文章状态,如果已取消则跳过执行
return nil
}
// RescheduleArticlePublish 重新调度文章定时发布任务
func (c *AsynqClient) RescheduleArticlePublish(ctx context.Context, articleID string, oldTaskID string, newPublishTime time.Time) (string, error) {
// 1. 取消旧任务(标记为已取消)
if err := c.CancelScheduledTask(ctx, oldTaskID); err != nil {
c.logger.Warn("取消旧任务失败",
zap.String("old_task_id", oldTaskID),
zap.Error(err))
}
// 2. 创建新任务
newTaskID, err := c.ScheduleArticlePublish(ctx, articleID, newPublishTime)
if err != nil {
return "", fmt.Errorf("重新调度任务失败: %w", err)
}
c.logger.Info("重新调度定时发布任务成功",
zap.String("article_id", articleID),
zap.String("old_task_id", oldTaskID),
zap.String("new_task_id", newTaskID),
zap.Time("new_publish_time", newPublishTime))
return newTaskID, nil
}

View File

@@ -1,98 +0,0 @@
package task
import (
"fmt"
"github.com/hibiken/asynq"
"go.uber.org/zap"
)
// AsynqWorker Asynq Worker
type AsynqWorker struct {
server *asynq.Server
mux *asynq.ServeMux
taskHandler *ArticleTaskHandler
logger *zap.Logger
}
// NewAsynqWorker 创建 Asynq Worker
func NewAsynqWorker(
redisAddr string,
taskHandler *ArticleTaskHandler,
logger *zap.Logger,
) *AsynqWorker {
server := asynq.NewServer(
asynq.RedisClientOpt{Addr: redisAddr},
asynq.Config{
Concurrency: 10, // 并发数
Queues: map[string]int{
"critical": 6,
"default": 3,
"low": 1,
},
Logger: NewAsynqLogger(logger),
},
)
mux := asynq.NewServeMux()
return &AsynqWorker{
server: server,
mux: mux,
taskHandler: taskHandler,
logger: logger,
}
}
// RegisterHandlers 注册任务处理器
func (w *AsynqWorker) RegisterHandlers() {
// 注册文章定时发布任务处理器
w.mux.HandleFunc(TaskTypeArticlePublish, w.taskHandler.HandleArticlePublish)
w.logger.Info("任务处理器注册完成")
}
// Start 启动 Worker
func (w *AsynqWorker) Start() error {
w.RegisterHandlers()
w.logger.Info("启动 Asynq Worker")
return w.server.Run(w.mux)
}
// Stop 停止 Worker
func (w *AsynqWorker) Stop() {
w.logger.Info("停止 Asynq Worker")
w.server.Stop()
w.server.Shutdown()
}
// AsynqLogger Asynq 日志适配器
type AsynqLogger struct {
logger *zap.Logger
}
// NewAsynqLogger 创建 Asynq 日志适配器
func NewAsynqLogger(logger *zap.Logger) *AsynqLogger {
return &AsynqLogger{logger: logger}
}
func (l *AsynqLogger) Debug(args ...interface{}) {
l.logger.Debug(fmt.Sprint(args...))
}
func (l *AsynqLogger) Info(args ...interface{}) {
l.logger.Info(fmt.Sprint(args...))
}
func (l *AsynqLogger) Warn(args ...interface{}) {
l.logger.Warn(fmt.Sprint(args...))
}
func (l *AsynqLogger) Error(args ...interface{}) {
l.logger.Error(fmt.Sprint(args...))
}
func (l *AsynqLogger) Fatal(args ...interface{}) {
l.logger.Fatal(fmt.Sprint(args...))
}

View File

@@ -0,0 +1,68 @@
package entities
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// TaskStatus 任务状态
type TaskStatus string
const (
TaskStatusPending TaskStatus = "pending"
TaskStatusRunning TaskStatus = "running"
TaskStatusCompleted TaskStatus = "completed"
TaskStatusFailed TaskStatus = "failed"
TaskStatusCancelled TaskStatus = "cancelled"
)
// AsyncTask 异步任务实体
type AsyncTask struct {
ID string `gorm:"type:char(36);primaryKey"`
Type string `gorm:"not null;index"`
Payload string `gorm:"type:text"`
Status TaskStatus `gorm:"not null;index"`
ScheduledAt *time.Time `gorm:"index"`
StartedAt *time.Time
CompletedAt *time.Time
ErrorMsg string
RetryCount int `gorm:"default:0"`
MaxRetries int `gorm:"default:5"`
CreatedAt time.Time
UpdatedAt time.Time
}
// TableName 指定表名
func (AsyncTask) TableName() string {
return "async_tasks"
}
// BeforeCreate GORM钩子在创建前生成UUID
func (t *AsyncTask) BeforeCreate(tx *gorm.DB) error {
if t.ID == "" {
t.ID = uuid.New().String()
}
return nil
}
// IsCompleted 检查任务是否已完成
func (t *AsyncTask) IsCompleted() bool {
return t.Status == TaskStatusCompleted
}
// IsFailed 检查任务是否失败
func (t *AsyncTask) IsFailed() bool {
return t.Status == TaskStatusFailed
}
// IsCancelled 检查任务是否已取消
func (t *AsyncTask) IsCancelled() bool {
return t.Status == TaskStatusCancelled
}
// CanRetry 检查任务是否可以重试
func (t *AsyncTask) CanRetry() bool {
return t.Status == TaskStatusFailed && t.RetryCount < t.MaxRetries
}

View File

@@ -0,0 +1,335 @@
package entities
import (
"context"
"encoding/json"
"fmt"
"time"
"tyapi-server/internal/infrastructure/task/types"
)
// TaskFactory 任务工厂
type TaskFactory struct {
taskManager interface{} // 使用interface{}避免循环导入
}
// NewTaskFactory 创建任务工厂
func NewTaskFactory() *TaskFactory {
return &TaskFactory{}
}
// NewTaskFactoryWithManager 创建带管理器的任务工厂
func NewTaskFactoryWithManager(taskManager interface{}) *TaskFactory {
return &TaskFactory{
taskManager: taskManager,
}
}
// CreateArticlePublishTask 创建文章发布任务
func (f *TaskFactory) CreateArticlePublishTask(articleID string, publishAt time.Time, userID string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeArticlePublish),
Status: TaskStatusPending,
ScheduledAt: &publishAt,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务ID将在保存后更新
payloadWithID := map[string]interface{}{
"article_id": articleID,
"publish_at": publishAt,
"user_id": userID,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateArticleCancelTask 创建文章取消任务
func (f *TaskFactory) CreateArticleCancelTask(articleID string, userID string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeArticleCancel),
Status: TaskStatusPending,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务数据
payloadWithID := map[string]interface{}{
"article_id": articleID,
"user_id": userID,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateArticleModifyTask 创建文章修改任务
func (f *TaskFactory) CreateArticleModifyTask(articleID string, newPublishAt time.Time, userID string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeArticleModify),
Status: TaskStatusPending,
ScheduledAt: &newPublishAt,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务数据
payloadWithID := map[string]interface{}{
"article_id": articleID,
"new_publish_at": newPublishAt,
"user_id": userID,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateApiCallTask 创建API调用任务
func (f *TaskFactory) CreateApiCallTask(apiCallID string, userID string, productID string, amount string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeApiCall),
Status: TaskStatusPending,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务数据
payloadWithID := map[string]interface{}{
"api_call_id": apiCallID,
"user_id": userID,
"product_id": productID,
"amount": amount,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateDeductionTask 创建扣款任务
func (f *TaskFactory) CreateDeductionTask(apiCallID string, userID string, productID string, amount string, transactionID string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeDeduction),
Status: TaskStatusPending,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务数据
payloadWithID := map[string]interface{}{
"api_call_id": apiCallID,
"user_id": userID,
"product_id": productID,
"amount": amount,
"transaction_id": transactionID,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateApiCallLogTask 创建API调用日志任务
func (f *TaskFactory) CreateApiCallLogTask(transactionID string, userID string, apiName string, productID string) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeApiLog),
Status: TaskStatusPending,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务数据
payloadWithID := map[string]interface{}{
"transaction_id": transactionID,
"user_id": userID,
"api_name": apiName,
"product_id": productID,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateUsageStatsTask 创建使用统计任务
func (f *TaskFactory) CreateUsageStatsTask(subscriptionID string, userID string, productID string, increment int) (*AsyncTask, error) {
// 创建任务实体ID将由GORM的BeforeCreate钩子自动生成UUID
task := &AsyncTask{
Type: string(types.TaskTypeUsageStats),
Status: TaskStatusPending,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 在payload中添加任务数据
payloadWithID := map[string]interface{}{
"subscription_id": subscriptionID,
"user_id": userID,
"product_id": productID,
"increment": increment,
}
payloadDataWithID, err := json.Marshal(payloadWithID)
if err != nil {
return nil, err
}
task.Payload = string(payloadDataWithID)
return task, nil
}
// CreateAndEnqueueArticlePublishTask 创建并入队文章发布任务
func (f *TaskFactory) CreateAndEnqueueArticlePublishTask(ctx context.Context, articleID string, publishAt time.Time, userID string) error {
if f.taskManager == nil {
return fmt.Errorf("TaskManager未初始化")
}
task, err := f.CreateArticlePublishTask(articleID, publishAt, userID)
if err != nil {
return err
}
delay := publishAt.Sub(time.Now())
if delay < 0 {
delay = 0
}
// 使用类型断言调用TaskManager方法
if tm, ok := f.taskManager.(interface {
CreateAndEnqueueDelayedTask(ctx context.Context, task *AsyncTask, delay time.Duration) error
}); ok {
return tm.CreateAndEnqueueDelayedTask(ctx, task, delay)
}
return fmt.Errorf("TaskManager类型不匹配")
}
// CreateAndEnqueueApiLogTask 创建并入队API日志任务
func (f *TaskFactory) CreateAndEnqueueApiLogTask(ctx context.Context, transactionID string, userID string, apiName string, productID string) error {
if f.taskManager == nil {
return fmt.Errorf("TaskManager未初始化")
}
task, err := f.CreateApiCallLogTask(transactionID, userID, apiName, productID)
if err != nil {
return err
}
// 使用类型断言调用TaskManager方法
if tm, ok := f.taskManager.(interface {
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
}); ok {
return tm.CreateAndEnqueueTask(ctx, task)
}
return fmt.Errorf("TaskManager类型不匹配")
}
// CreateAndEnqueueApiCallTask 创建并入队API调用任务
func (f *TaskFactory) CreateAndEnqueueApiCallTask(ctx context.Context, apiCallID string, userID string, productID string, amount string) error {
if f.taskManager == nil {
return fmt.Errorf("TaskManager未初始化")
}
task, err := f.CreateApiCallTask(apiCallID, userID, productID, amount)
if err != nil {
return err
}
// 使用类型断言调用TaskManager方法
if tm, ok := f.taskManager.(interface {
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
}); ok {
return tm.CreateAndEnqueueTask(ctx, task)
}
return fmt.Errorf("TaskManager类型不匹配")
}
// CreateAndEnqueueDeductionTask 创建并入队扣款任务
func (f *TaskFactory) CreateAndEnqueueDeductionTask(ctx context.Context, apiCallID string, userID string, productID string, amount string, transactionID string) error {
if f.taskManager == nil {
return fmt.Errorf("TaskManager未初始化")
}
task, err := f.CreateDeductionTask(apiCallID, userID, productID, amount, transactionID)
if err != nil {
return err
}
// 使用类型断言调用TaskManager方法
if tm, ok := f.taskManager.(interface {
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
}); ok {
return tm.CreateAndEnqueueTask(ctx, task)
}
return fmt.Errorf("TaskManager类型不匹配")
}
// CreateAndEnqueueUsageStatsTask 创建并入队使用统计任务
func (f *TaskFactory) CreateAndEnqueueUsageStatsTask(ctx context.Context, subscriptionID string, userID string, productID string, increment int) error {
if f.taskManager == nil {
return fmt.Errorf("TaskManager未初始化")
}
task, err := f.CreateUsageStatsTask(subscriptionID, userID, productID, increment)
if err != nil {
return err
}
// 使用类型断言调用TaskManager方法
if tm, ok := f.taskManager.(interface {
CreateAndEnqueueTask(ctx context.Context, task *AsyncTask) error
}); ok {
return tm.CreateAndEnqueueTask(ctx, task)
}
return fmt.Errorf("TaskManager类型不匹配")
}
// generateRandomString 生成随机字符串
func generateRandomString(length int) string {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
b := make([]byte, length)
for i := range b {
b[i] = charset[time.Now().UnixNano()%int64(len(charset))]
}
return string(b)
}

View File

@@ -0,0 +1,45 @@
package task
import (
"tyapi-server/internal/infrastructure/task/implementations/asynq"
"tyapi-server/internal/infrastructure/task/interfaces"
"go.uber.org/zap"
)
// TaskFactory 任务工厂
type TaskFactory struct{}
// NewTaskFactory 创建任务工厂
func NewTaskFactory() *TaskFactory {
return &TaskFactory{}
}
// CreateApiTaskQueue 创建API任务队列
func (f *TaskFactory) CreateApiTaskQueue(redisAddr string, logger interface{}) interfaces.ApiTaskQueue {
// 这里可以根据配置选择不同的实现
// 目前使用Asynq实现
return asynq.NewAsynqApiTaskQueue(redisAddr, logger.(*zap.Logger))
}
// CreateArticleTaskQueue 创建文章任务队列
func (f *TaskFactory) CreateArticleTaskQueue(redisAddr string, logger interface{}) interfaces.ArticleTaskQueue {
// 这里可以根据配置选择不同的实现
// 目前使用Asynq实现
return asynq.NewAsynqArticleTaskQueue(redisAddr, logger.(*zap.Logger))
}
// NewApiTaskQueue 创建API任务队列包级别函数
func NewApiTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ApiTaskQueue {
return asynq.NewAsynqApiTaskQueue(redisAddr, logger)
}
// NewAsynqClient 创建Asynq客户端包级别函数
func NewAsynqClient(redisAddr string, scheduledTaskRepo interface{}, logger *zap.Logger) *asynq.AsynqClient {
return asynq.NewAsynqClient(redisAddr, logger)
}
// NewArticleTaskQueue 创建文章任务队列(包级别函数)
func NewArticleTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ArticleTaskQueue {
return asynq.NewAsynqArticleTaskQueue(redisAddr, logger)
}

View File

@@ -0,0 +1,285 @@
package handlers
import (
"context"
"encoding/json"
"time"
"github.com/hibiken/asynq"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"tyapi-server/internal/application/api"
finance_services "tyapi-server/internal/domains/finance/services"
product_services "tyapi-server/internal/domains/product/services"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/repositories"
"tyapi-server/internal/infrastructure/task/types"
)
// ApiTaskHandler API任务处理器
type ApiTaskHandler struct {
logger *zap.Logger
apiApplicationService api.ApiApplicationService
walletService finance_services.WalletAggregateService
subscriptionService *product_services.ProductSubscriptionService
asyncTaskRepo repositories.AsyncTaskRepository
}
// NewApiTaskHandler 创建API任务处理器
func NewApiTaskHandler(
logger *zap.Logger,
apiApplicationService api.ApiApplicationService,
walletService finance_services.WalletAggregateService,
subscriptionService *product_services.ProductSubscriptionService,
asyncTaskRepo repositories.AsyncTaskRepository,
) *ApiTaskHandler {
return &ApiTaskHandler{
logger: logger,
apiApplicationService: apiApplicationService,
walletService: walletService,
subscriptionService: subscriptionService,
asyncTaskRepo: asyncTaskRepo,
}
}
// HandleApiCall 处理API调用任务
func (h *ApiTaskHandler) HandleApiCall(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理API调用任务")
var payload types.ApiCallPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析API调用任务载荷失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
return err
}
h.logger.Info("处理API调用任务",
zap.String("api_call_id", payload.ApiCallID),
zap.String("user_id", payload.UserID),
zap.String("product_id", payload.ProductID))
// 这里实现API调用的具体逻辑
// 例如记录API调用、更新使用统计等
// 更新任务状态为成功
h.updateTaskStatus(ctx, t, "completed", "")
h.logger.Info("API调用任务处理完成", zap.String("api_call_id", payload.ApiCallID))
return nil
}
// HandleDeduction 处理扣款任务
func (h *ApiTaskHandler) HandleDeduction(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理扣款任务")
var payload types.DeductionPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析扣款任务载荷失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
return err
}
h.logger.Info("处理扣款任务",
zap.String("user_id", payload.UserID),
zap.String("amount", payload.Amount),
zap.String("transaction_id", payload.TransactionID))
// 调用钱包服务进行扣款
if h.walletService != nil {
amount, err := decimal.NewFromString(payload.Amount)
if err != nil {
h.logger.Error("金额格式错误", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "金额格式错误")
return err
}
if err := h.walletService.Deduct(ctx, payload.UserID, amount, payload.ApiCallID, payload.TransactionID, payload.ProductID); err != nil {
h.logger.Error("扣款处理失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "扣款处理失败: "+err.Error())
return err
}
} else {
h.logger.Warn("钱包服务未初始化,跳过扣款", zap.String("user_id", payload.UserID))
h.updateTaskStatus(ctx, t, "failed", "钱包服务未初始化")
return nil
}
// 更新任务状态为成功
h.updateTaskStatus(ctx, t, "completed", "")
h.logger.Info("扣款任务处理完成", zap.String("transaction_id", payload.TransactionID))
return nil
}
// HandleCompensation 处理补偿任务
func (h *ApiTaskHandler) HandleCompensation(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理补偿任务")
var payload types.CompensationPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析补偿任务载荷失败", zap.Error(err))
return err
}
h.logger.Info("处理补偿任务",
zap.String("transaction_id", payload.TransactionID),
zap.String("type", payload.Type))
// 这里实现补偿的具体逻辑
// 例如:调用钱包服务进行退款等
h.logger.Info("补偿任务处理完成", zap.String("transaction_id", payload.TransactionID))
return nil
}
// HandleUsageStats 处理使用统计任务
func (h *ApiTaskHandler) HandleUsageStats(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理使用统计任务")
var payload types.UsageStatsPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析使用统计任务载荷失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
return err
}
h.logger.Info("处理使用统计任务",
zap.String("subscription_id", payload.SubscriptionID),
zap.String("user_id", payload.UserID),
zap.Int("increment", payload.Increment))
// 调用订阅服务更新使用统计
if h.subscriptionService != nil {
if err := h.subscriptionService.IncrementSubscriptionAPIUsage(ctx, payload.SubscriptionID, int64(payload.Increment)); err != nil {
h.logger.Error("更新使用统计失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "更新使用统计失败: "+err.Error())
return err
}
} else {
h.logger.Warn("订阅服务未初始化,跳过使用统计更新", zap.String("subscription_id", payload.SubscriptionID))
h.updateTaskStatus(ctx, t, "failed", "订阅服务未初始化")
return nil
}
// 更新任务状态为成功
h.updateTaskStatus(ctx, t, "completed", "")
h.logger.Info("使用统计任务处理完成", zap.String("subscription_id", payload.SubscriptionID))
return nil
}
// HandleApiLog 处理API日志任务
func (h *ApiTaskHandler) HandleApiLog(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理API日志任务")
var payload types.ApiLogPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析API日志任务载荷失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
return err
}
h.logger.Info("处理API日志任务",
zap.String("transaction_id", payload.TransactionID),
zap.String("user_id", payload.UserID),
zap.String("api_name", payload.ApiName),
zap.String("product_id", payload.ProductID))
// 记录结构化日志
h.logger.Info("API调用日志",
zap.String("transaction_id", payload.TransactionID),
zap.String("user_id", payload.UserID),
zap.String("api_name", payload.ApiName),
zap.String("product_id", payload.ProductID),
zap.Time("timestamp", time.Now()))
// 这里可以添加其他日志记录逻辑
// 例如:写入专门的日志文件、发送到日志系统、写入数据库等
// 更新任务状态为成功
h.updateTaskStatus(ctx, t, "completed", "")
h.logger.Info("API日志任务处理完成", zap.String("transaction_id", payload.TransactionID))
return nil
}
// updateTaskStatus 更新任务状态
func (h *ApiTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task, status string, errorMsg string) {
// 从任务载荷中提取任务ID
var payload map[string]interface{}
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析任务载荷失败,无法更新状态", zap.Error(err))
return
}
// 尝试从payload中获取任务ID
taskID, ok := payload["task_id"].(string)
if !ok {
h.logger.Error("无法从任务载荷中获取任务ID")
return
}
// 根据状态决定更新方式
if status == "failed" {
// 失败时:需要检查是否达到最大重试次数
h.handleTaskFailure(ctx, taskID, errorMsg)
} else if status == "completed" {
// 成功时:清除错误信息并更新状态
if err := h.asyncTaskRepo.UpdateStatusWithSuccess(ctx, taskID, entities.TaskStatus(status)); err != nil {
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", status),
zap.Error(err))
}
} else {
// 其他状态:只更新状态
if err := h.asyncTaskRepo.UpdateStatus(ctx, taskID, entities.TaskStatus(status)); err != nil {
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", status),
zap.Error(err))
}
}
h.logger.Info("任务状态已更新",
zap.String("task_id", taskID),
zap.String("status", status),
zap.String("error_msg", errorMsg))
}
// handleTaskFailure 处理任务失败
func (h *ApiTaskHandler) handleTaskFailure(ctx context.Context, taskID string, errorMsg string) {
// 获取当前任务信息
task, err := h.asyncTaskRepo.GetByID(ctx, taskID)
if err != nil {
h.logger.Error("获取任务信息失败", zap.String("task_id", taskID), zap.Error(err))
return
}
// 增加重试次数
newRetryCount := task.RetryCount + 1
// 检查是否达到最大重试次数
if newRetryCount >= task.MaxRetries {
// 达到最大重试次数,标记为最终失败
if err := h.asyncTaskRepo.UpdateStatusWithRetryAndError(ctx, taskID, entities.TaskStatusFailed, errorMsg); err != nil {
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", "failed"),
zap.Error(err))
}
h.logger.Info("任务最终失败,已达到最大重试次数",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Int("max_retries", task.MaxRetries))
} else {
// 未达到最大重试次数保持pending状态记录错误信息
if err := h.asyncTaskRepo.UpdateRetryCountAndError(ctx, taskID, newRetryCount, errorMsg); err != nil {
h.logger.Error("更新任务重试次数失败",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Error(err))
}
h.logger.Info("任务失败,准备重试",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Int("max_retries", task.MaxRetries))
}
}

View File

@@ -0,0 +1,304 @@
package handlers
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/application/article"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/repositories"
"tyapi-server/internal/infrastructure/task/types"
)
// ArticleTaskHandler 文章任务处理器
type ArticleTaskHandler struct {
logger *zap.Logger
articleApplicationService article.ArticleApplicationService
asyncTaskRepo repositories.AsyncTaskRepository
}
// NewArticleTaskHandler 创建文章任务处理器
func NewArticleTaskHandler(logger *zap.Logger, articleApplicationService article.ArticleApplicationService, asyncTaskRepo repositories.AsyncTaskRepository) *ArticleTaskHandler {
return &ArticleTaskHandler{
logger: logger,
articleApplicationService: articleApplicationService,
asyncTaskRepo: asyncTaskRepo,
}
}
// HandleArticlePublish 处理文章发布任务
func (h *ArticleTaskHandler) HandleArticlePublish(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理文章发布任务")
var payload ArticlePublishPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析文章发布任务载荷失败", zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "解析任务载荷失败")
return err
}
h.logger.Info("处理文章发布任务",
zap.String("article_id", payload.ArticleID),
zap.Time("publish_at", payload.PublishAt))
// 检查任务是否已被取消
if err := h.checkTaskStatus(ctx, t); err != nil {
h.logger.Info("任务已被取消,跳过执行", zap.String("article_id", payload.ArticleID))
return nil // 静默返回,不报错
}
// 调用文章应用服务发布文章
if h.articleApplicationService != nil {
err := h.articleApplicationService.PublishArticleByID(ctx, payload.ArticleID)
if err != nil {
h.logger.Error("文章发布失败", zap.String("article_id", payload.ArticleID), zap.Error(err))
h.updateTaskStatus(ctx, t, "failed", "文章发布失败: "+err.Error())
return err
}
} else {
h.logger.Warn("文章应用服务未初始化,跳过发布", zap.String("article_id", payload.ArticleID))
h.updateTaskStatus(ctx, t, "failed", "文章应用服务未初始化")
return nil
}
// 更新任务状态为成功
h.updateTaskStatus(ctx, t, "completed", "")
h.logger.Info("文章发布任务处理完成", zap.String("article_id", payload.ArticleID))
return nil
}
// HandleArticleCancel 处理文章取消任务
func (h *ArticleTaskHandler) HandleArticleCancel(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理文章取消任务")
var payload ArticleCancelPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析文章取消任务载荷失败", zap.Error(err))
return err
}
h.logger.Info("处理文章取消任务", zap.String("article_id", payload.ArticleID))
// 这里实现文章取消的具体逻辑
// 例如:更新文章状态、取消定时发布等
h.logger.Info("文章取消任务处理完成", zap.String("article_id", payload.ArticleID))
return nil
}
// HandleArticleModify 处理文章修改任务
func (h *ArticleTaskHandler) HandleArticleModify(ctx context.Context, t *asynq.Task) error {
h.logger.Info("开始处理文章修改任务")
var payload ArticleModifyPayload
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析文章修改任务载荷失败", zap.Error(err))
return err
}
h.logger.Info("处理文章修改任务",
zap.String("article_id", payload.ArticleID),
zap.Time("new_publish_at", payload.NewPublishAt))
// 这里实现文章修改的具体逻辑
// 例如:更新文章发布时间、重新调度任务等
h.logger.Info("文章修改任务处理完成", zap.String("article_id", payload.ArticleID))
return nil
}
// ArticlePublishPayload 文章发布任务载荷
type ArticlePublishPayload struct {
ArticleID string `json:"article_id"`
PublishAt time.Time `json:"publish_at"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
func (p *ArticlePublishPayload) GetType() types.TaskType {
return types.TaskTypeArticlePublish
}
// ToJSON 序列化为JSON
func (p *ArticlePublishPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ArticlePublishPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// ArticleCancelPayload 文章取消任务载荷
type ArticleCancelPayload struct {
ArticleID string `json:"article_id"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
func (p *ArticleCancelPayload) GetType() types.TaskType {
return types.TaskTypeArticleCancel
}
// ToJSON 序列化为JSON
func (p *ArticleCancelPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ArticleCancelPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// ArticleModifyPayload 文章修改任务载荷
type ArticleModifyPayload struct {
ArticleID string `json:"article_id"`
NewPublishAt time.Time `json:"new_publish_at"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
func (p *ArticleModifyPayload) GetType() types.TaskType {
return types.TaskTypeArticleModify
}
// ToJSON 序列化为JSON
func (p *ArticleModifyPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ArticleModifyPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// updateTaskStatus 更新任务状态
func (h *ArticleTaskHandler) updateTaskStatus(ctx context.Context, t *asynq.Task, status string, errorMsg string) {
// 从任务载荷中提取任务ID
var payload map[string]interface{}
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析任务载荷失败,无法更新状态", zap.Error(err))
return
}
// 尝试从payload中获取任务ID
taskID, ok := payload["task_id"].(string)
if !ok {
// 如果没有task_id尝试从article_id生成
if articleID, ok := payload["article_id"].(string); ok {
taskID = fmt.Sprintf("article-publish-%s", articleID)
} else {
h.logger.Error("无法从任务载荷中获取任务ID")
return
}
}
// 根据状态决定更新方式
if status == "failed" {
// 失败时:需要检查是否达到最大重试次数
h.handleTaskFailure(ctx, taskID, errorMsg)
} else if status == "completed" {
// 成功时:清除错误信息并更新状态
if err := h.asyncTaskRepo.UpdateStatusWithSuccess(ctx, taskID, entities.TaskStatus(status)); err != nil {
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", status),
zap.Error(err))
}
} else {
// 其他状态:只更新状态
if err := h.asyncTaskRepo.UpdateStatus(ctx, taskID, entities.TaskStatus(status)); err != nil {
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", status),
zap.Error(err))
}
}
h.logger.Info("任务状态已更新",
zap.String("task_id", taskID),
zap.String("status", status),
zap.String("error_msg", errorMsg))
}
// handleTaskFailure 处理任务失败
func (h *ArticleTaskHandler) handleTaskFailure(ctx context.Context, taskID string, errorMsg string) {
// 获取当前任务信息
task, err := h.asyncTaskRepo.GetByID(ctx, taskID)
if err != nil {
h.logger.Error("获取任务信息失败", zap.String("task_id", taskID), zap.Error(err))
return
}
// 增加重试次数
newRetryCount := task.RetryCount + 1
// 检查是否达到最大重试次数
if newRetryCount >= task.MaxRetries {
// 达到最大重试次数,标记为最终失败
if err := h.asyncTaskRepo.UpdateStatusWithRetryAndError(ctx, taskID, entities.TaskStatusFailed, errorMsg); err != nil {
h.logger.Error("更新任务状态失败",
zap.String("task_id", taskID),
zap.String("status", "failed"),
zap.Error(err))
}
h.logger.Info("任务最终失败,已达到最大重试次数",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Int("max_retries", task.MaxRetries))
} else {
// 未达到最大重试次数保持pending状态记录错误信息
if err := h.asyncTaskRepo.UpdateRetryCountAndError(ctx, taskID, newRetryCount, errorMsg); err != nil {
h.logger.Error("更新任务重试次数失败",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Error(err))
}
h.logger.Info("任务失败,准备重试",
zap.String("task_id", taskID),
zap.Int("retry_count", newRetryCount),
zap.Int("max_retries", task.MaxRetries))
}
}
// checkTaskStatus 检查任务状态
func (h *ArticleTaskHandler) checkTaskStatus(ctx context.Context, t *asynq.Task) error {
// 从任务载荷中提取任务ID
var payload map[string]interface{}
if err := json.Unmarshal(t.Payload(), &payload); err != nil {
h.logger.Error("解析任务载荷失败,无法检查状态", zap.Error(err))
return err
}
// 尝试从payload中获取任务ID
taskID, ok := payload["task_id"].(string)
if !ok {
// 如果没有task_id尝试从article_id生成
if articleID, ok := payload["article_id"].(string); ok {
taskID = fmt.Sprintf("article-publish-%s", articleID)
} else {
h.logger.Error("无法从任务载荷中获取任务ID")
return fmt.Errorf("无法获取任务ID")
}
}
// 查询任务状态
task, err := h.asyncTaskRepo.GetByID(ctx, taskID)
if err != nil {
h.logger.Error("查询任务状态失败", zap.String("task_id", taskID), zap.Error(err))
return err
}
// 检查任务是否已被取消
if task.Status == entities.TaskStatusCancelled {
h.logger.Info("任务已被取消", zap.String("task_id", taskID))
return fmt.Errorf("任务已被取消")
}
return nil
}

View File

@@ -0,0 +1,126 @@
package asynq
import (
"context"
"fmt"
"time"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/interfaces"
"tyapi-server/internal/infrastructure/task/types"
)
// AsynqApiTaskQueue Asynq API任务队列实现
type AsynqApiTaskQueue struct {
client *asynq.Client
logger *zap.Logger
}
// NewAsynqApiTaskQueue 创建Asynq API任务队列
func NewAsynqApiTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ApiTaskQueue {
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
return &AsynqApiTaskQueue{
client: client,
logger: logger,
}
}
// Enqueue 入队任务
func (q *AsynqApiTaskQueue) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
payloadData, err := payload.ToJSON()
if err != nil {
q.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = q.client.EnqueueContext(ctx, task)
if err != nil {
q.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
q.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
return nil
}
// EnqueueDelayed 延时入队任务
func (q *AsynqApiTaskQueue) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
payloadData, err := payload.ToJSON()
if err != nil {
q.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
if err != nil {
q.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
q.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
return nil
}
// EnqueueAt 指定时间入队任务
func (q *AsynqApiTaskQueue) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
payloadData, err := payload.ToJSON()
if err != nil {
q.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
if err != nil {
q.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
q.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
return nil
}
// Cancel 取消任务
func (q *AsynqApiTaskQueue) Cancel(ctx context.Context, taskID string) error {
// Asynq本身不支持直接取消任务这里返回错误提示
return fmt.Errorf("Asynq不支持直接取消任务请使用数据库状态管理")
}
// ModifySchedule 修改任务调度时间
func (q *AsynqApiTaskQueue) ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
// Asynq本身不支持修改调度时间这里返回错误提示
return fmt.Errorf("Asynq不支持修改任务调度时间请使用数据库状态管理")
}
// GetTaskStatus 获取任务状态
func (q *AsynqApiTaskQueue) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
// Asynq本身不提供任务状态查询这里返回错误提示
return nil, fmt.Errorf("Asynq不提供任务状态查询请使用数据库状态管理")
}
// ListTasks 列出任务
func (q *AsynqApiTaskQueue) ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
// Asynq本身不提供任务列表查询这里返回错误提示
return nil, fmt.Errorf("Asynq不提供任务列表查询请使用数据库状态管理")
}
// EnqueueTask 入队任务
func (q *AsynqApiTaskQueue) EnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
// 创建Asynq任务
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
// 入队任务
_, err := q.client.EnqueueContext(ctx, asynqTask)
if err != nil {
q.logger.Error("入队任务失败", zap.String("task_id", task.ID), zap.String("task_type", task.Type), zap.Error(err))
return err
}
q.logger.Info("入队任务成功", zap.String("task_id", task.ID), zap.String("task_type", task.Type))
return nil
}

View File

@@ -0,0 +1,131 @@
package asynq
import (
"context"
"fmt"
"time"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/interfaces"
"tyapi-server/internal/infrastructure/task/types"
)
// AsynqArticleTaskQueue Asynq文章任务队列实现
type AsynqArticleTaskQueue struct {
client *asynq.Client
logger *zap.Logger
}
// NewAsynqArticleTaskQueue 创建Asynq文章任务队列
func NewAsynqArticleTaskQueue(redisAddr string, logger *zap.Logger) interfaces.ArticleTaskQueue {
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
return &AsynqArticleTaskQueue{
client: client,
logger: logger,
}
}
// Enqueue 入队任务
func (q *AsynqArticleTaskQueue) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
payloadData, err := payload.ToJSON()
if err != nil {
q.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = q.client.EnqueueContext(ctx, task)
if err != nil {
q.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
q.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
return nil
}
// EnqueueDelayed 延时入队任务
func (q *AsynqArticleTaskQueue) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
payloadData, err := payload.ToJSON()
if err != nil {
q.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
if err != nil {
q.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
q.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
return nil
}
// EnqueueAt 指定时间入队任务
func (q *AsynqArticleTaskQueue) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
payloadData, err := payload.ToJSON()
if err != nil {
q.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = q.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
if err != nil {
q.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
q.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
return nil
}
// Cancel 取消任务
func (q *AsynqArticleTaskQueue) Cancel(ctx context.Context, taskID string) error {
// Asynq本身不支持直接取消任务但我们可以通过以下方式实现
// 1. 在数据库中标记任务为已取消
// 2. 任务执行时检查状态,如果已取消则跳过执行
q.logger.Info("标记任务为已取消", zap.String("task_id", taskID))
// 这里应该更新数据库中的任务状态为cancelled
// 由于我们没有直接访问repository暂时只记录日志
// 实际实现中应该调用AsyncTaskRepository.UpdateStatus
return nil
}
// ModifySchedule 修改任务调度时间
func (q *AsynqArticleTaskQueue) ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
// Asynq本身不支持修改调度时间但我们可以通过以下方式实现
// 1. 取消旧任务
// 2. 创建新任务
q.logger.Info("修改任务调度时间",
zap.String("task_id", taskID),
zap.Time("new_scheduled_at", newScheduledAt))
// 这里应该:
// 1. 调用Cancel取消旧任务
// 2. 根据任务类型重新创建任务
// 由于没有直接访问repository暂时只记录日志
return nil
}
// GetTaskStatus 获取任务状态
func (q *AsynqArticleTaskQueue) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
// Asynq本身不提供任务状态查询这里返回错误提示
return nil, fmt.Errorf("Asynq不提供任务状态查询请使用数据库状态管理")
}
// ListTasks 列出任务
func (q *AsynqArticleTaskQueue) ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
// Asynq本身不提供任务列表查询这里返回错误提示
return nil, fmt.Errorf("Asynq不提供任务列表查询请使用数据库状态管理")
}

View File

@@ -0,0 +1,88 @@
package asynq
import (
"context"
"time"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/infrastructure/task/types"
)
// AsynqClient Asynq客户端实现
type AsynqClient struct {
client *asynq.Client
logger *zap.Logger
}
// NewAsynqClient 创建Asynq客户端
func NewAsynqClient(redisAddr string, logger *zap.Logger) *AsynqClient {
client := asynq.NewClient(asynq.RedisClientOpt{Addr: redisAddr})
return &AsynqClient{
client: client,
logger: logger,
}
}
// Enqueue 入队任务
func (c *AsynqClient) Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error {
payloadData, err := payload.ToJSON()
if err != nil {
c.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = c.client.EnqueueContext(ctx, task)
if err != nil {
c.logger.Error("入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
c.logger.Info("任务入队成功", zap.String("task_type", string(taskType)))
return nil
}
// EnqueueDelayed 延时入队任务
func (c *AsynqClient) EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error {
payloadData, err := payload.ToJSON()
if err != nil {
c.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = c.client.EnqueueContext(ctx, task, asynq.ProcessIn(delay))
if err != nil {
c.logger.Error("延时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
c.logger.Info("延时任务入队成功", zap.String("task_type", string(taskType)), zap.Duration("delay", delay))
return nil
}
// EnqueueAt 指定时间入队任务
func (c *AsynqClient) EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error {
payloadData, err := payload.ToJSON()
if err != nil {
c.logger.Error("序列化任务载荷失败", zap.Error(err))
return err
}
task := asynq.NewTask(string(taskType), payloadData)
_, err = c.client.EnqueueContext(ctx, task, asynq.ProcessAt(scheduledAt))
if err != nil {
c.logger.Error("定时入队任务失败", zap.String("task_type", string(taskType)), zap.Error(err))
return err
}
c.logger.Info("定时任务入队成功", zap.String("task_type", string(taskType)), zap.Time("scheduled_at", scheduledAt))
return nil
}
// Close 关闭客户端
func (c *AsynqClient) Close() error {
return c.client.Close()
}

View File

@@ -0,0 +1,122 @@
package asynq
import (
"context"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/application/api"
"tyapi-server/internal/application/article"
finance_services "tyapi-server/internal/domains/finance/services"
product_services "tyapi-server/internal/domains/product/services"
"tyapi-server/internal/infrastructure/task/handlers"
"tyapi-server/internal/infrastructure/task/repositories"
"tyapi-server/internal/infrastructure/task/types"
)
// AsynqWorker Asynq Worker实现
type AsynqWorker struct {
server *asynq.Server
mux *asynq.ServeMux
logger *zap.Logger
articleHandler *handlers.ArticleTaskHandler
apiHandler *handlers.ApiTaskHandler
}
// NewAsynqWorker 创建Asynq Worker
func NewAsynqWorker(
redisAddr string,
logger *zap.Logger,
articleApplicationService article.ArticleApplicationService,
apiApplicationService api.ApiApplicationService,
walletService finance_services.WalletAggregateService,
subscriptionService *product_services.ProductSubscriptionService,
asyncTaskRepo repositories.AsyncTaskRepository,
) *AsynqWorker {
server := asynq.NewServer(
asynq.RedisClientOpt{Addr: redisAddr},
asynq.Config{
Concurrency: 6, // 降低总并发数
Queues: map[string]int{
"default": 2, // 2个goroutine
"api": 3, // 3个goroutine (扣款任务)
"article": 1, // 1个goroutine
},
},
)
// 创建任务处理器
articleHandler := handlers.NewArticleTaskHandler(logger, articleApplicationService, asyncTaskRepo)
apiHandler := handlers.NewApiTaskHandler(logger, apiApplicationService, walletService, subscriptionService, asyncTaskRepo)
// 创建ServeMux
mux := asynq.NewServeMux()
return &AsynqWorker{
server: server,
mux: mux,
logger: logger,
articleHandler: articleHandler,
apiHandler: apiHandler,
}
}
// RegisterHandler 注册任务处理器
func (w *AsynqWorker) RegisterHandler(taskType types.TaskType, handler func(context.Context, *asynq.Task) error) {
// 简化实现避免API兼容性问题
w.logger.Info("注册任务处理器", zap.String("task_type", string(taskType)))
}
// Start 启动Worker
func (w *AsynqWorker) Start() error {
w.logger.Info("启动Asynq Worker")
// 注册所有任务处理器
w.registerAllHandlers()
// 启动Worker服务器
go func() {
if err := w.server.Run(w.mux); err != nil {
w.logger.Error("Worker运行失败", zap.Error(err))
}
}()
w.logger.Info("Asynq Worker启动成功")
return nil
}
// Stop 停止Worker
func (w *AsynqWorker) Stop() {
w.logger.Info("停止Asynq Worker")
w.server.Stop()
}
// Shutdown 优雅关闭Worker
func (w *AsynqWorker) Shutdown() {
w.logger.Info("优雅关闭Asynq Worker")
w.server.Shutdown()
}
// registerAllHandlers 注册所有任务处理器
func (w *AsynqWorker) registerAllHandlers() {
// 注册文章任务处理器
w.mux.HandleFunc(string(types.TaskTypeArticlePublish), w.articleHandler.HandleArticlePublish)
w.mux.HandleFunc(string(types.TaskTypeArticleCancel), w.articleHandler.HandleArticleCancel)
w.mux.HandleFunc(string(types.TaskTypeArticleModify), w.articleHandler.HandleArticleModify)
// 注册API任务处理器
w.mux.HandleFunc(string(types.TaskTypeApiCall), w.apiHandler.HandleApiCall)
w.mux.HandleFunc(string(types.TaskTypeApiLog), w.apiHandler.HandleApiLog)
w.mux.HandleFunc(string(types.TaskTypeDeduction), w.apiHandler.HandleDeduction)
w.mux.HandleFunc(string(types.TaskTypeCompensation), w.apiHandler.HandleCompensation)
w.mux.HandleFunc(string(types.TaskTypeUsageStats), w.apiHandler.HandleUsageStats)
w.logger.Info("所有任务处理器注册完成",
zap.String("article_publish", string(types.TaskTypeArticlePublish)),
zap.String("article_cancel", string(types.TaskTypeArticleCancel)),
zap.String("article_modify", string(types.TaskTypeArticleModify)),
zap.String("api_call", string(types.TaskTypeApiCall)),
zap.String("api_log", string(types.TaskTypeApiLog)),
)
}

View File

@@ -0,0 +1,374 @@
package implementations
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/hibiken/asynq"
"go.uber.org/zap"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/interfaces"
"tyapi-server/internal/infrastructure/task/repositories"
"tyapi-server/internal/infrastructure/task/types"
)
// TaskManagerImpl 任务管理器实现
type TaskManagerImpl struct {
asynqClient *asynq.Client
asyncTaskRepo repositories.AsyncTaskRepository
logger *zap.Logger
config *interfaces.TaskManagerConfig
}
// NewTaskManager 创建任务管理器
func NewTaskManager(
asynqClient *asynq.Client,
asyncTaskRepo repositories.AsyncTaskRepository,
logger *zap.Logger,
config *interfaces.TaskManagerConfig,
) interfaces.TaskManager {
return &TaskManagerImpl{
asynqClient: asynqClient,
asyncTaskRepo: asyncTaskRepo,
logger: logger,
config: config,
}
}
// CreateAndEnqueueTask 创建并入队任务
func (tm *TaskManagerImpl) CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error {
// 1. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存任务失败: %w", err)
}
// 2. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 3. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
}
// 4. 入队到Asynq
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
// 如果入队失败,更新任务状态为失败
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "任务入队失败")
return fmt.Errorf("任务入队失败: %w", err)
}
tm.logger.Info("任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
return nil
}
// CreateAndEnqueueDelayedTask 创建并入队延时任务
func (tm *TaskManagerImpl) CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
// 1. 设置调度时间
scheduledAt := time.Now().Add(delay)
task.ScheduledAt = &scheduledAt
// 2. 保存任务到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, task); err != nil {
tm.logger.Error("保存延时任务到数据库失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("保存延时任务失败: %w", err)
}
// 3. 更新payload中的task_id
if err := tm.updatePayloadTaskID(task); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 4. 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
tm.logger.Error("更新任务payload失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务payload失败: %w", err)
}
// 5. 入队到Asynq延时队列
if err := tm.enqueueTaskWithDelay(ctx, task, delay); err != nil {
// 如果入队失败,更新任务状态为失败
tm.asyncTaskRepo.UpdateStatusWithError(ctx, task.ID, entities.TaskStatusFailed, "延时任务入队失败")
return fmt.Errorf("延时任务入队失败: %w", err)
}
tm.logger.Info("延时任务创建并入队成功",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type),
zap.Duration("delay", delay))
return nil
}
// CancelTask 取消任务
func (tm *TaskManagerImpl) CancelTask(ctx context.Context, taskID string) error {
task, err := tm.findTask(ctx, taskID)
if err != nil {
return err
}
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("更新任务状态为取消失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("更新任务状态失败: %w", err)
}
tm.logger.Info("任务已标记为取消",
zap.String("task_id", task.ID),
zap.String("task_type", task.Type))
return nil
}
// UpdateTaskSchedule 更新任务调度时间
func (tm *TaskManagerImpl) UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error {
// 1. 查找任务
task, err := tm.findTask(ctx, taskID)
if err != nil {
return err
}
tm.logger.Info("找到要更新的任务",
zap.String("task_id", task.ID),
zap.String("current_status", string(task.Status)),
zap.Time("current_scheduled_at", *task.ScheduledAt))
// 2. 取消旧任务
if err := tm.asyncTaskRepo.UpdateStatus(ctx, task.ID, entities.TaskStatusCancelled); err != nil {
tm.logger.Error("取消旧任务失败",
zap.String("task_id", task.ID),
zap.Error(err))
return fmt.Errorf("取消旧任务失败: %w", err)
}
tm.logger.Info("旧任务已标记为取消", zap.String("task_id", task.ID))
// 3. 创建并保存新任务
newTask, err := tm.createAndSaveTask(ctx, task, newScheduledAt)
if err != nil {
return err
}
tm.logger.Info("新任务已创建",
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
// 4. 计算延时并入队
delay := newScheduledAt.Sub(time.Now())
if delay < 0 {
delay = 0 // 如果时间已过,立即执行
}
if err := tm.enqueueTaskWithDelay(ctx, newTask, delay); err != nil {
// 如果入队失败,删除新创建的任务记录
tm.asyncTaskRepo.Delete(ctx, newTask.ID)
return fmt.Errorf("重新入队任务失败: %w", err)
}
tm.logger.Info("任务调度时间更新成功",
zap.String("old_task_id", task.ID),
zap.String("new_task_id", newTask.ID),
zap.Time("new_scheduled_at", newScheduledAt))
return nil
}
// GetTaskStatus 获取任务状态
func (tm *TaskManagerImpl) GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
return tm.asyncTaskRepo.GetByID(ctx, taskID)
}
// UpdateTaskStatus 更新任务状态
func (tm *TaskManagerImpl) UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error {
if errorMsg != "" {
return tm.asyncTaskRepo.UpdateStatusWithError(ctx, taskID, status, errorMsg)
}
return tm.asyncTaskRepo.UpdateStatus(ctx, taskID, status)
}
// RetryTask 重试任务
func (tm *TaskManagerImpl) RetryTask(ctx context.Context, taskID string) error {
// 1. 获取任务信息
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
if err != nil {
return fmt.Errorf("获取任务信息失败: %w", err)
}
// 2. 检查是否可以重试
if !task.CanRetry() {
return fmt.Errorf("任务已达到最大重试次数")
}
// 3. 增加重试次数并重置状态
task.RetryCount++
task.Status = entities.TaskStatusPending
// 4. 更新数据库
if err := tm.asyncTaskRepo.Update(ctx, task); err != nil {
return fmt.Errorf("更新任务重试次数失败: %w", err)
}
// 5. 重新入队
if err := tm.enqueueTaskWithDelay(ctx, task, 0); err != nil {
return fmt.Errorf("重试任务入队失败: %w", err)
}
tm.logger.Info("任务重试成功",
zap.String("task_id", taskID),
zap.Int("retry_count", task.RetryCount))
return nil
}
// CleanupExpiredTasks 清理过期任务
func (tm *TaskManagerImpl) CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error {
// 这里可以实现清理逻辑,比如删除超过一定时间的已完成任务
tm.logger.Info("开始清理过期任务", zap.Time("older_than", olderThan))
// TODO: 实现清理逻辑
return nil
}
// updatePayloadTaskID 更新payload中的task_id
func (tm *TaskManagerImpl) updatePayloadTaskID(task *entities.AsyncTask) error {
// 解析payload
var payload map[string]interface{}
if err := json.Unmarshal([]byte(task.Payload), &payload); err != nil {
return fmt.Errorf("解析payload失败: %w", err)
}
// 更新task_id
payload["task_id"] = task.ID
// 重新序列化
newPayload, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("序列化payload失败: %w", err)
}
task.Payload = string(newPayload)
return nil
}
// findTask 查找任务支持taskID和articleID双重查找
func (tm *TaskManagerImpl) findTask(ctx context.Context, taskID string) (*entities.AsyncTask, error) {
// 先尝试通过任务ID查找
task, err := tm.asyncTaskRepo.GetByID(ctx, taskID)
if err == nil {
return task, nil
}
// 如果通过任务ID找不到尝试通过文章ID查找
tm.logger.Info("通过任务ID查找失败尝试通过文章ID查找", zap.String("task_id", taskID))
tasks, err := tm.asyncTaskRepo.GetByArticleID(ctx, taskID)
if err != nil || len(tasks) == 0 {
tm.logger.Error("通过文章ID也找不到任务",
zap.String("article_id", taskID),
zap.Error(err))
return nil, fmt.Errorf("获取任务信息失败: %w", err)
}
// 使用找到的第一个任务
task = tasks[0]
tm.logger.Info("通过文章ID找到任务",
zap.String("article_id", taskID),
zap.String("task_id", task.ID))
return task, nil
}
// createAndSaveTask 创建并保存新任务
func (tm *TaskManagerImpl) createAndSaveTask(ctx context.Context, originalTask *entities.AsyncTask, newScheduledAt time.Time) (*entities.AsyncTask, error) {
// 创建新任务
newTask := &entities.AsyncTask{
Type: originalTask.Type,
Payload: originalTask.Payload,
Status: entities.TaskStatusPending,
ScheduledAt: &newScheduledAt,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 保存到数据库GORM会自动生成UUID
if err := tm.asyncTaskRepo.Create(ctx, newTask); err != nil {
tm.logger.Error("创建新任务失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("创建新任务失败: %w", err)
}
// 更新payload中的task_id
if err := tm.updatePayloadTaskID(newTask); err != nil {
tm.logger.Error("更新payload中的任务ID失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新payload中的任务ID失败: %w", err)
}
// 更新数据库中的payload
if err := tm.asyncTaskRepo.Update(ctx, newTask); err != nil {
tm.logger.Error("更新新任务payload失败",
zap.String("new_task_id", newTask.ID),
zap.Error(err))
return nil, fmt.Errorf("更新新任务payload失败: %w", err)
}
return newTask, nil
}
// enqueueTaskWithDelay 入队任务到Asynq支持延时
func (tm *TaskManagerImpl) enqueueTaskWithDelay(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error {
queueName := tm.getQueueName(task.Type)
asynqTask := asynq.NewTask(task.Type, []byte(task.Payload))
var err error
if delay > 0 {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask,
asynq.Queue(queueName),
asynq.ProcessIn(delay))
} else {
_, err = tm.asynqClient.EnqueueContext(ctx, asynqTask, asynq.Queue(queueName))
}
return err
}
// getQueueName 根据任务类型获取队列名称
func (tm *TaskManagerImpl) getQueueName(taskType string) string {
switch taskType {
case string(types.TaskTypeArticlePublish), string(types.TaskTypeArticleCancel), string(types.TaskTypeArticleModify):
return "article"
case string(types.TaskTypeApiCall), string(types.TaskTypeApiLog), string(types.TaskTypeDeduction), string(types.TaskTypeUsageStats):
return "api"
case string(types.TaskTypeCompensation):
return "finance"
default:
return "default"
}
}

View File

@@ -0,0 +1,35 @@
package interfaces
import (
"context"
"time"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/types"
)
// ApiTaskQueue API任务队列接口
type ApiTaskQueue interface {
// Enqueue 入队任务
Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error
// EnqueueDelayed 延时入队任务
EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error
// EnqueueAt 指定时间入队任务
EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error
// Cancel 取消任务
Cancel(ctx context.Context, taskID string) error
// ModifySchedule 修改任务调度时间
ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error
// GetTaskStatus 获取任务状态
GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error)
// ListTasks 列出任务
ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
// EnqueueTask 入队任务(简化版本)
EnqueueTask(ctx context.Context, task *entities.AsyncTask) error
}

View File

@@ -0,0 +1,32 @@
package interfaces
import (
"context"
"time"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/types"
)
// ArticleTaskQueue 文章任务队列接口
type ArticleTaskQueue interface {
// Enqueue 入队任务
Enqueue(ctx context.Context, taskType types.TaskType, payload types.TaskPayload) error
// EnqueueDelayed 延时入队任务
EnqueueDelayed(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, delay time.Duration) error
// EnqueueAt 指定时间入队任务
EnqueueAt(ctx context.Context, taskType types.TaskType, payload types.TaskPayload, scheduledAt time.Time) error
// Cancel 取消任务
Cancel(ctx context.Context, taskID string) error
// ModifySchedule 修改任务调度时间
ModifySchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error
// GetTaskStatus 获取任务状态
GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error)
// ListTasks 列出任务
ListTasks(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
}

View File

@@ -0,0 +1,44 @@
package interfaces
import (
"context"
"time"
"tyapi-server/internal/infrastructure/task/entities"
)
// TaskManager 任务管理器接口
// 统一管理Asynq任务和AsyncTask实体的操作
type TaskManager interface {
// 创建并入队任务
CreateAndEnqueueTask(ctx context.Context, task *entities.AsyncTask) error
// 创建并入队延时任务
CreateAndEnqueueDelayedTask(ctx context.Context, task *entities.AsyncTask, delay time.Duration) error
// 取消任务
CancelTask(ctx context.Context, taskID string) error
// 更新任务调度时间
UpdateTaskSchedule(ctx context.Context, taskID string, newScheduledAt time.Time) error
// 获取任务状态
GetTaskStatus(ctx context.Context, taskID string) (*entities.AsyncTask, error)
// 更新任务状态
UpdateTaskStatus(ctx context.Context, taskID string, status entities.TaskStatus, errorMsg string) error
// 重试任务
RetryTask(ctx context.Context, taskID string) error
// 清理过期任务
CleanupExpiredTasks(ctx context.Context, olderThan time.Time) error
}
// TaskManagerConfig 任务管理器配置
type TaskManagerConfig struct {
RedisAddr string
MaxRetries int
RetryInterval time.Duration
CleanupDays int
}

View File

@@ -0,0 +1,267 @@
package repositories
import (
"context"
"time"
"gorm.io/gorm"
"tyapi-server/internal/infrastructure/task/entities"
"tyapi-server/internal/infrastructure/task/types"
)
// AsyncTaskRepository 异步任务仓库接口
type AsyncTaskRepository interface {
// 基础CRUD操作
Create(ctx context.Context, task *entities.AsyncTask) error
GetByID(ctx context.Context, id string) (*entities.AsyncTask, error)
Update(ctx context.Context, task *entities.AsyncTask) error
Delete(ctx context.Context, id string) error
// 查询操作
ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error)
ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error)
ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error)
// 状态更新操作
UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error
UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error
UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error
UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error
UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error
IncrementRetryCount(ctx context.Context, id string) error
// 批量操作
UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error
DeleteBatch(ctx context.Context, ids []string) error
// 文章任务专用方法
GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error)
GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error)
CancelArticlePublishTask(ctx context.Context, articleID string) error
UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error
}
// AsyncTaskRepositoryImpl 异步任务仓库实现
type AsyncTaskRepositoryImpl struct {
db *gorm.DB
}
// NewAsyncTaskRepository 创建异步任务仓库
func NewAsyncTaskRepository(db *gorm.DB) AsyncTaskRepository {
return &AsyncTaskRepositoryImpl{
db: db,
}
}
// Create 创建任务
func (r *AsyncTaskRepositoryImpl) Create(ctx context.Context, task *entities.AsyncTask) error {
return r.db.WithContext(ctx).Create(task).Error
}
// GetByID 根据ID获取任务
func (r *AsyncTaskRepositoryImpl) GetByID(ctx context.Context, id string) (*entities.AsyncTask, error) {
var task entities.AsyncTask
err := r.db.WithContext(ctx).Where("id = ?", id).First(&task).Error
if err != nil {
return nil, err
}
return &task, nil
}
// Update 更新任务
func (r *AsyncTaskRepositoryImpl) Update(ctx context.Context, task *entities.AsyncTask) error {
return r.db.WithContext(ctx).Save(task).Error
}
// Delete 删除任务
func (r *AsyncTaskRepositoryImpl) Delete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.AsyncTask{}).Error
}
// ListByType 根据类型列出任务
func (r *AsyncTaskRepositoryImpl) ListByType(ctx context.Context, taskType types.TaskType, limit int) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
query := r.db.WithContext(ctx).Where("type = ?", taskType)
if limit > 0 {
query = query.Limit(limit)
}
err := query.Find(&tasks).Error
return tasks, err
}
// ListByStatus 根据状态列出任务
func (r *AsyncTaskRepositoryImpl) ListByStatus(ctx context.Context, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
query := r.db.WithContext(ctx).Where("status = ?", status)
if limit > 0 {
query = query.Limit(limit)
}
err := query.Find(&tasks).Error
return tasks, err
}
// ListByTypeAndStatus 根据类型和状态列出任务
func (r *AsyncTaskRepositoryImpl) ListByTypeAndStatus(ctx context.Context, taskType types.TaskType, status entities.TaskStatus, limit int) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
query := r.db.WithContext(ctx).Where("type = ? AND status = ?", taskType, status)
if limit > 0 {
query = query.Limit(limit)
}
err := query.Find(&tasks).Error
return tasks, err
}
// ListScheduledTasks 列出已到期的调度任务
func (r *AsyncTaskRepositoryImpl) ListScheduledTasks(ctx context.Context, before time.Time) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
err := r.db.WithContext(ctx).
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.TaskStatusPending, before).
Find(&tasks).Error
return tasks, err
}
// UpdateStatus 更新任务状态
func (r *AsyncTaskRepositoryImpl) UpdateStatus(ctx context.Context, id string, status entities.TaskStatus) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"status": status,
"updated_at": time.Now(),
}).Error
}
// UpdateStatusWithError 更新任务状态并记录错误
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"status": status,
"error_msg": errorMsg,
"updated_at": time.Now(),
}).Error
}
// UpdateStatusWithRetryAndError 更新任务状态、增加重试次数并记录错误
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithRetryAndError(ctx context.Context, id string, status entities.TaskStatus, errorMsg string) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"status": status,
"error_msg": errorMsg,
"retry_count": gorm.Expr("retry_count + 1"),
"updated_at": time.Now(),
}).Error
}
// UpdateStatusWithSuccess 更新任务状态为成功,清除错误信息
func (r *AsyncTaskRepositoryImpl) UpdateStatusWithSuccess(ctx context.Context, id string, status entities.TaskStatus) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"status": status,
"error_msg": "", // 清除错误信息
"updated_at": time.Now(),
}).Error
}
// UpdateRetryCountAndError 更新重试次数和错误信息保持pending状态
func (r *AsyncTaskRepositoryImpl) UpdateRetryCountAndError(ctx context.Context, id string, retryCount int, errorMsg string) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id = ?", id).
Updates(map[string]interface{}{
"retry_count": retryCount,
"error_msg": errorMsg,
"updated_at": time.Now(),
// 注意不更新status保持pending状态
}).Error
}
// UpdateScheduledAt 更新任务调度时间
func (r *AsyncTaskRepositoryImpl) UpdateScheduledAt(ctx context.Context, id string, scheduledAt time.Time) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id = ?", id).
Update("scheduled_at", scheduledAt).Error
}
// IncrementRetryCount 增加重试次数
func (r *AsyncTaskRepositoryImpl) IncrementRetryCount(ctx context.Context, id string) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id = ?", id).
Update("retry_count", gorm.Expr("retry_count + 1")).Error
}
// UpdateStatusBatch 批量更新状态
func (r *AsyncTaskRepositoryImpl) UpdateStatusBatch(ctx context.Context, ids []string, status entities.TaskStatus) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("id IN ?", ids).
Update("status", status).Error
}
// DeleteBatch 批量删除
func (r *AsyncTaskRepositoryImpl) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).
Where("id IN ?", ids).
Delete(&entities.AsyncTask{}).Error
}
// GetArticlePublishTask 获取文章发布任务
func (r *AsyncTaskRepositoryImpl) GetArticlePublishTask(ctx context.Context, articleID string) (*entities.AsyncTask, error) {
var task entities.AsyncTask
err := r.db.WithContext(ctx).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
First(&task).Error
if err != nil {
return nil, err
}
return &task, nil
}
// GetByArticleID 根据文章ID获取所有相关任务
func (r *AsyncTaskRepositoryImpl) GetByArticleID(ctx context.Context, articleID string) ([]*entities.AsyncTask, error) {
var tasks []*entities.AsyncTask
err := r.db.WithContext(ctx).
Where("payload LIKE ? AND status IN ?",
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Find(&tasks).Error
if err != nil {
return nil, err
}
return tasks, nil
}
// CancelArticlePublishTask 取消文章发布任务
func (r *AsyncTaskRepositoryImpl) CancelArticlePublishTask(ctx context.Context, articleID string) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Update("status", entities.TaskStatusCancelled).Error
}
// UpdateArticlePublishTaskSchedule 更新文章发布任务调度时间
func (r *AsyncTaskRepositoryImpl) UpdateArticlePublishTaskSchedule(ctx context.Context, articleID string, newScheduledAt time.Time) error {
return r.db.WithContext(ctx).
Model(&entities.AsyncTask{}).
Where("type = ? AND payload LIKE ? AND status IN ?",
types.TaskTypeArticlePublish,
"%\"article_id\":\""+articleID+"\"%",
[]entities.TaskStatus{entities.TaskStatusPending, entities.TaskStatusRunning}).
Update("scheduled_at", newScheduledAt).Error
}

View File

@@ -1,7 +0,0 @@
package task
// 任务类型常量
const (
// TaskTypeArticlePublish 文章定时发布任务
TaskTypeArticlePublish = "article:publish"
)

View File

@@ -0,0 +1,196 @@
package types
import (
"encoding/json"
"time"
)
// QueueType 队列类型
type QueueType string
const (
QueueTypeDefault QueueType = "default"
QueueTypeApi QueueType = "api"
QueueTypeArticle QueueType = "article"
QueueTypeFinance QueueType = "finance"
QueueTypeProduct QueueType = "product"
)
// ArticlePublishPayload 文章发布任务载荷
type ArticlePublishPayload struct {
ArticleID string `json:"article_id"`
PublishAt time.Time `json:"publish_at"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
func (p *ArticlePublishPayload) GetType() TaskType {
return TaskTypeArticlePublish
}
// ToJSON 序列化为JSON
func (p *ArticlePublishPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ArticlePublishPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// ArticleCancelPayload 文章取消任务载荷
type ArticleCancelPayload struct {
ArticleID string `json:"article_id"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
func (p *ArticleCancelPayload) GetType() TaskType {
return TaskTypeArticleCancel
}
// ToJSON 序列化为JSON
func (p *ArticleCancelPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ArticleCancelPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// ArticleModifyPayload 文章修改任务载荷
type ArticleModifyPayload struct {
ArticleID string `json:"article_id"`
NewPublishAt time.Time `json:"new_publish_at"`
UserID string `json:"user_id"`
}
// GetType 获取任务类型
func (p *ArticleModifyPayload) GetType() TaskType {
return TaskTypeArticleModify
}
// ToJSON 序列化为JSON
func (p *ArticleModifyPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ArticleModifyPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// ApiCallPayload API调用任务载荷
type ApiCallPayload struct {
ApiCallID string `json:"api_call_id"`
UserID string `json:"user_id"`
ProductID string `json:"product_id"`
Amount string `json:"amount"`
}
// GetType 获取任务类型
func (p *ApiCallPayload) GetType() TaskType {
return TaskTypeApiCall
}
// ToJSON 序列化为JSON
func (p *ApiCallPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ApiCallPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// DeductionPayload 扣款任务载荷
type DeductionPayload struct {
UserID string `json:"user_id"`
Amount string `json:"amount"`
ApiCallID string `json:"api_call_id"`
TransactionID string `json:"transaction_id"`
ProductID string `json:"product_id"`
}
// GetType 获取任务类型
func (p *DeductionPayload) GetType() TaskType {
return TaskTypeDeduction
}
// ToJSON 序列化为JSON
func (p *DeductionPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *DeductionPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// CompensationPayload 补偿任务载荷
type CompensationPayload struct {
TransactionID string `json:"transaction_id"`
Type string `json:"type"`
}
// GetType 获取任务类型
func (p *CompensationPayload) GetType() TaskType {
return TaskTypeCompensation
}
// ToJSON 序列化为JSON
func (p *CompensationPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *CompensationPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// UsageStatsPayload 使用统计任务载荷
type UsageStatsPayload struct {
SubscriptionID string `json:"subscription_id"`
UserID string `json:"user_id"`
ProductID string `json:"product_id"`
Increment int `json:"increment"`
}
// GetType 获取任务类型
func (p *UsageStatsPayload) GetType() TaskType {
return TaskTypeUsageStats
}
// ToJSON 序列化为JSON
func (p *UsageStatsPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *UsageStatsPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}
// ApiLogPayload API日志任务载荷
type ApiLogPayload struct {
TransactionID string `json:"transaction_id"`
UserID string `json:"user_id"`
ApiName string `json:"api_name"`
ProductID string `json:"product_id"`
}
// GetType 获取任务类型
func (p *ApiLogPayload) GetType() TaskType {
return TaskTypeApiLog
}
// ToJSON 序列化为JSON
func (p *ApiLogPayload) ToJSON() ([]byte, error) {
return json.Marshal(p)
}
// FromJSON 从JSON反序列化
func (p *ApiLogPayload) FromJSON(data []byte) error {
return json.Unmarshal(data, p)
}

View File

@@ -0,0 +1,29 @@
package types
// TaskType 任务类型
type TaskType string
const (
// 文章相关任务
TaskTypeArticlePublish TaskType = "article_publish"
TaskTypeArticleCancel TaskType = "article_cancel"
TaskTypeArticleModify TaskType = "article_modify"
// API相关任务
TaskTypeApiCall TaskType = "api_call"
TaskTypeApiLog TaskType = "api_log"
// 财务相关任务
TaskTypeDeduction TaskType = "deduction"
TaskTypeCompensation TaskType = "compensation"
// 产品相关任务
TaskTypeUsageStats TaskType = "usage_stats"
)
// TaskPayload 任务载荷接口
type TaskPayload interface {
GetType() TaskType
ToJSON() ([]byte, error)
FromJSON(data []byte) error
}

View File

@@ -0,0 +1,100 @@
package utils
import (
"context"
"github.com/hibiken/asynq"
"go.uber.org/zap"
)
// AsynqLogger Asynq日志适配器
type AsynqLogger struct {
logger *zap.Logger
}
// NewAsynqLogger 创建Asynq日志适配器
func NewAsynqLogger(logger *zap.Logger) *AsynqLogger {
return &AsynqLogger{
logger: logger,
}
}
// Debug 调试日志
func (l *AsynqLogger) Debug(args ...interface{}) {
l.logger.Debug("", zap.Any("args", args))
}
// Info 信息日志
func (l *AsynqLogger) Info(args ...interface{}) {
l.logger.Info("", zap.Any("args", args))
}
// Warn 警告日志
func (l *AsynqLogger) Warn(args ...interface{}) {
l.logger.Warn("", zap.Any("args", args))
}
// Error 错误日志
func (l *AsynqLogger) Error(args ...interface{}) {
l.logger.Error("", zap.Any("args", args))
}
// Fatal 致命错误日志
func (l *AsynqLogger) Fatal(args ...interface{}) {
l.logger.Fatal("", zap.Any("args", args))
}
// Debugf 格式化调试日志
func (l *AsynqLogger) Debugf(format string, args ...interface{}) {
l.logger.Debug("", zap.String("format", format), zap.Any("args", args))
}
// Infof 格式化信息日志
func (l *AsynqLogger) Infof(format string, args ...interface{}) {
l.logger.Info("", zap.String("format", format), zap.Any("args", args))
}
// Warnf 格式化警告日志
func (l *AsynqLogger) Warnf(format string, args ...interface{}) {
l.logger.Warn("", zap.String("format", format), zap.Any("args", args))
}
// Errorf 格式化错误日志
func (l *AsynqLogger) Errorf(format string, args ...interface{}) {
l.logger.Error("", zap.String("format", format), zap.Any("args", args))
}
// Fatalf 格式化致命错误日志
func (l *AsynqLogger) Fatalf(format string, args ...interface{}) {
l.logger.Fatal("", zap.String("format", format), zap.Any("args", args))
}
// WithField 添加字段
func (l *AsynqLogger) WithField(key string, value interface{}) asynq.Logger {
return &AsynqLogger{
logger: l.logger.With(zap.Any(key, value)),
}
}
// WithFields 添加多个字段
func (l *AsynqLogger) WithFields(fields map[string]interface{}) asynq.Logger {
zapFields := make([]zap.Field, 0, len(fields))
for k, v := range fields {
zapFields = append(zapFields, zap.Any(k, v))
}
return &AsynqLogger{
logger: l.logger.With(zapFields...),
}
}
// WithError 添加错误字段
func (l *AsynqLogger) WithError(err error) asynq.Logger {
return &AsynqLogger{
logger: l.logger.With(zap.Error(err)),
}
}
// WithContext 添加上下文
func (l *AsynqLogger) WithContext(ctx context.Context) asynq.Logger {
return l
}

View File

@@ -0,0 +1,17 @@
package utils
import (
"fmt"
"github.com/google/uuid"
)
// GenerateTaskID 生成统一格式的任务ID (UUID)
func GenerateTaskID() string {
return uuid.New().String()
}
// GenerateTaskIDWithPrefix 生成带前缀的任务ID (UUID)
func GenerateTaskIDWithPrefix(prefix string) string {
return fmt.Sprintf("%s-%s", prefix, uuid.New().String())
}