f
This commit is contained in:
@@ -0,0 +1,556 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/api/entities"
|
||||
"hyapi-server/internal/domains/api/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiCallsTable = "api_calls"
|
||||
ApiCallCacheTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// ApiCallWithProduct 包含产品名称的API调用记录
|
||||
type ApiCallWithProduct struct {
|
||||
entities.ApiCall
|
||||
ProductName string `json:"product_name" gorm:"column:product_name"`
|
||||
}
|
||||
|
||||
type GormApiCallRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ApiCallRepository = (*GormApiCallRepository)(nil)
|
||||
|
||||
func NewGormApiCallRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiCallRepository {
|
||||
return &GormApiCallRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiCallsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) Create(ctx context.Context, call *entities.ApiCall) error {
|
||||
return r.CreateEntity(ctx, call)
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) Update(ctx context.Context, call *entities.ApiCall) error {
|
||||
return r.UpdateEntity(ctx, call)
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindById(ctx context.Context, id string) (*entities.ApiCall, error) {
|
||||
var call entities.ApiCall
|
||||
err := r.SmartGetByID(ctx, id, &call)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &call, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindByUserId(ctx context.Context, userId string, limit, offset int) ([]*entities.ApiCall, error) {
|
||||
var calls []*entities.ApiCall
|
||||
options := database.CacheListOptions{
|
||||
Where: "user_id = ?",
|
||||
Args: []interface{}{userId},
|
||||
Order: "created_at DESC",
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
err := r.ListWithCache(ctx, &calls, ApiCallCacheTTL, options)
|
||||
return calls, err
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
|
||||
var calls []*entities.ApiCall
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
|
||||
return calls, total, err
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
|
||||
var calls []*entities.ApiCall
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品ID筛选
|
||||
if productId, ok := filters["product_id"].(string); ok && productId != "" {
|
||||
whereCondition += " AND product_id = ?"
|
||||
whereArgs = append(whereArgs, productId)
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
|
||||
return calls, total, err
|
||||
}
|
||||
|
||||
// ListByUserIdWithFiltersAndProductName 根据用户ID和筛选条件获取API调用记录(包含产品名称)
|
||||
func (r *GormApiCallRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
|
||||
var callsWithProduct []*ApiCallWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "ac.user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND ac.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND ac.status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
query := r.GetDB(ctx).Table("api_calls ac").
|
||||
Select("ac.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON ac.product_id = p.id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("ac." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("ac.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&callsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.ApiCall并构建产品名称映射
|
||||
var calls []*entities.ApiCall
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, c := range callsWithProduct {
|
||||
call := c.ApiCall
|
||||
calls = append(calls, &call)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if c.ProductName != "" {
|
||||
productNameMap[call.ID] = c.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, calls, total, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ?", userId)
|
||||
}
|
||||
|
||||
// CountByUserIdAndProductId 按用户ID和产品ID统计API调用次数
|
||||
func (r *GormApiCallRepository) CountByUserIdAndProductId(ctx context.Context, userId string, productId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ? AND product_id = ?", userId, productId)
|
||||
}
|
||||
|
||||
// CountByUserIdAndDateRange 按用户ID和日期范围统计API调用次数
|
||||
func (r *GormApiCallRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
|
||||
}
|
||||
|
||||
// GetDailyStatsByUserId 获取用户每日API调用统计
|
||||
func (r *GormApiCallRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE user_id = $1
|
||||
AND DATE(created_at) >= $2
|
||||
AND DATE(created_at) <= $3
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetMonthlyStatsByUserId 获取用户每月API调用统计
|
||||
func (r *GormApiCallRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE user_id = $1
|
||||
AND created_at >= $2
|
||||
AND created_at <= $3
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindByTransactionId(ctx context.Context, transactionId string) (*entities.ApiCall, error) {
|
||||
var call entities.ApiCall
|
||||
err := r.FindOne(ctx, &call, "transaction_id = ?", transactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &call, nil
|
||||
}
|
||||
|
||||
// ListWithFiltersAndProductName 管理端:根据条件筛选所有API调用记录(包含产品名称)
|
||||
func (r *GormApiCallRepository) ListWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
|
||||
var callsWithProduct []*ApiCallWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "1=1"
|
||||
whereArgs := []interface{}{}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 用户ID筛选(支持单个user_id和多个user_ids)
|
||||
// 如果同时存在,优先使用user_ids(批量查询)
|
||||
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
|
||||
// 解析逗号分隔的用户ID列表
|
||||
userIdsList := strings.Split(userIds, ",")
|
||||
// 去除空白字符
|
||||
var cleanUserIds []string
|
||||
for _, id := range userIdsList {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" {
|
||||
cleanUserIds = append(cleanUserIds, id)
|
||||
}
|
||||
}
|
||||
if len(cleanUserIds) > 0 {
|
||||
placeholders := strings.Repeat("?,", len(cleanUserIds))
|
||||
placeholders = placeholders[:len(placeholders)-1] // 移除最后一个逗号
|
||||
whereCondition += " AND ac.user_id IN (" + placeholders + ")"
|
||||
for _, id := range cleanUserIds {
|
||||
whereArgs = append(whereArgs, id)
|
||||
}
|
||||
}
|
||||
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
|
||||
// 单个用户ID筛选
|
||||
whereCondition += " AND ac.user_id = ?"
|
||||
whereArgs = append(whereArgs, userId)
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND ac.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 企业名称筛选
|
||||
if companyName, ok := filters["company_name"].(string); ok && companyName != "" {
|
||||
whereCondition += " AND ei.company_name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+companyName+"%")
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND ac.status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
// 需要JOIN product表获取产品名称,JOIN users和enterprise_infos表获取企业名称
|
||||
query := r.GetDB(ctx).Table("api_calls ac").
|
||||
Select("ac.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON ac.product_id = p.id").
|
||||
Joins("LEFT JOIN users u ON ac.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("ac." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("ac.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&callsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.ApiCall并构建产品名称映射
|
||||
var calls []*entities.ApiCall
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, c := range callsWithProduct {
|
||||
call := c.ApiCall
|
||||
calls = append(calls, &call)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if c.ProductName != "" {
|
||||
productNameMap[call.ID] = c.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, calls, total, nil
|
||||
}
|
||||
|
||||
// GetSystemTotalCalls 获取系统总API调用次数
|
||||
func (r *GormApiCallRepository) GetSystemTotalCalls(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ApiCall{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetSystemCallsByDateRange 获取系统指定时间范围内的API调用次数
|
||||
// endDate 应该是结束日期当天的次日00:00:00(日统计)或下个月1号00:00:00(月统计),使用 < 而不是 <=
|
||||
func (r *GormApiCallRepository) GetSystemCallsByDateRange(ctx context.Context, startDate, endDate time.Time) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ApiCall{}).
|
||||
Where("created_at >= ? AND created_at < ?", startDate, endDate).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetSystemDailyStats 获取系统每日API调用统计
|
||||
func (r *GormApiCallRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE DATE(created_at) >= $1
|
||||
AND DATE(created_at) <= $2
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyStats 获取系统每月API调用统计
|
||||
func (r *GormApiCallRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE created_at >= $1
|
||||
AND created_at < $2
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetApiPopularityRanking 获取API受欢迎程度排行榜
|
||||
func (r *GormApiCallRepository) GetApiPopularityRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
p.id as product_id,
|
||||
p.name as api_name,
|
||||
p.description as api_description,
|
||||
COUNT(ac.id) as call_count
|
||||
FROM product p
|
||||
LEFT JOIN api_calls ac ON p.id = ac.product_id
|
||||
AND DATE(ac.created_at) = CURRENT_DATE
|
||||
WHERE p.deleted_at IS NULL
|
||||
GROUP BY p.id, p.name, p.description
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY call_count DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
p.id as product_id,
|
||||
p.name as api_name,
|
||||
p.description as api_description,
|
||||
COUNT(ac.id) as call_count
|
||||
FROM product p
|
||||
LEFT JOIN api_calls ac ON p.id = ac.product_id
|
||||
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE p.deleted_at IS NULL
|
||||
GROUP BY p.id, p.name, p.description
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY call_count DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
p.id as product_id,
|
||||
p.name as api_name,
|
||||
p.description as api_description,
|
||||
COUNT(ac.id) as call_count
|
||||
FROM product p
|
||||
LEFT JOIN api_calls ac ON p.id = ac.product_id
|
||||
WHERE p.deleted_at IS NULL
|
||||
GROUP BY p.id, p.name, p.description
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY call_count DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hyapi-server/internal/domains/api/entities"
|
||||
"hyapi-server/internal/domains/api/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiUsersTable = "api_users"
|
||||
ApiUserCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
type GormApiUserRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ApiUserRepository = (*GormApiUserRepository)(nil)
|
||||
|
||||
func NewGormApiUserRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiUserRepository {
|
||||
return &GormApiUserRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiUsersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) Create(ctx context.Context, user *entities.ApiUser) error {
|
||||
return r.CreateEntity(ctx, user)
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) Update(ctx context.Context, user *entities.ApiUser) error {
|
||||
return r.UpdateEntity(ctx, user)
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) FindByAccessId(ctx context.Context, accessId string) (*entities.ApiUser, error) {
|
||||
var user entities.ApiUser
|
||||
err := r.SmartGetByField(ctx, &user, "access_id", accessId, ApiUserCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) FindByUserId(ctx context.Context, userId string) (*entities.ApiUser, error) {
|
||||
var user entities.ApiUser
|
||||
err := r.SmartGetByField(ctx, &user, "user_id", userId, ApiUserCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"hyapi-server/internal/domains/api/entities"
|
||||
"hyapi-server/internal/domains/api/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ReportsTable = "reports"
|
||||
)
|
||||
|
||||
// GormReportRepository 报告记录 GORM 仓储实现
|
||||
type GormReportRepository struct {
|
||||
*database.BaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ReportRepository = (*GormReportRepository)(nil)
|
||||
|
||||
// NewGormReportRepository 创建报告记录仓储实现
|
||||
func NewGormReportRepository(db *gorm.DB, logger *zap.Logger) repositories.ReportRepository {
|
||||
return &GormReportRepository{
|
||||
BaseRepositoryImpl: database.NewBaseRepositoryImpl(db, logger),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建报告记录
|
||||
func (r *GormReportRepository) Create(ctx context.Context, report *entities.Report) error {
|
||||
return r.CreateEntity(ctx, report)
|
||||
}
|
||||
|
||||
// FindByReportID 根据报告编号查询记录
|
||||
func (r *GormReportRepository) FindByReportID(ctx context.Context, reportID string) (*entities.Report, error) {
|
||||
var report entities.Report
|
||||
if err := r.FindOneByField(ctx, &report, "report_id", reportID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &report, nil
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
repoQueries "hyapi-server/internal/domains/article/repositories/queries"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormAnnouncementRepository GORM公告仓储实现
|
||||
type GormAnnouncementRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.AnnouncementRepository = (*GormAnnouncementRepository)(nil)
|
||||
|
||||
// NewGormAnnouncementRepository 创建GORM公告仓储
|
||||
func NewGormAnnouncementRepository(db *gorm.DB, logger *zap.Logger) *GormAnnouncementRepository {
|
||||
return &GormAnnouncementRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建公告
|
||||
func (r *GormAnnouncementRepository) Create(ctx context.Context, entity entities.Announcement) (entities.Announcement, error) {
|
||||
r.logger.Info("创建公告", zap.String("id", entity.ID), zap.String("title", entity.Title))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("创建公告失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取公告
|
||||
func (r *GormAnnouncementRepository) GetByID(ctx context.Context, id string) (entities.Announcement, error) {
|
||||
var entity entities.Announcement
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("id = ?", id).
|
||||
First(&entity).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("公告不存在")
|
||||
}
|
||||
r.logger.Error("获取公告失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新公告
|
||||
func (r *GormAnnouncementRepository) Update(ctx context.Context, entity entities.Announcement) error {
|
||||
r.logger.Info("更新公告", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新公告失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除公告
|
||||
func (r *GormAnnouncementRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除公告", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除公告失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找公告
|
||||
func (r *GormAnnouncementRepository) FindByStatus(ctx context.Context, status entities.AnnouncementStatus) ([]*entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&announcements).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据状态查找公告失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Announcement, len(announcements))
|
||||
for i := range announcements {
|
||||
result[i] = &announcements[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindScheduled 查找定时发布的公告
|
||||
func (r *GormAnnouncementRepository) FindScheduled(ctx context.Context) ([]*entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
now := time.Now()
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.AnnouncementStatusDraft, now).
|
||||
Order("scheduled_at ASC").
|
||||
Find(&announcements).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("查找定时发布公告失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Announcement, len(announcements))
|
||||
for i := range announcements {
|
||||
result[i] = &announcements[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListAnnouncements 获取公告列表
|
||||
func (r *GormAnnouncementRepository) ListAnnouncements(ctx context.Context, query *repoQueries.ListAnnouncementQuery) ([]*entities.Announcement, int64, error) {
|
||||
var announcements []entities.Announcement
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
if query.Title != "" {
|
||||
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取公告列表总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&announcements).Error; err != nil {
|
||||
r.logger.Error("获取公告列表失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Announcement, len(announcements))
|
||||
for i := range announcements {
|
||||
result[i] = &announcements[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// CountByStatus 根据状态统计公告数量
|
||||
func (r *GormAnnouncementRepository) CountByStatus(ctx context.Context, status entities.AnnouncementStatus) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.Announcement{}).
|
||||
Where("status = ?", status).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("统计公告数量失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// UpdateStatistics 更新统计信息
|
||||
// 注意:公告实体目前没有统计字段,此方法预留扩展
|
||||
func (r *GormAnnouncementRepository) UpdateStatistics(ctx context.Context, announcementID string) error {
|
||||
r.logger.Info("更新公告统计信息", zap.String("announcement_id", announcementID))
|
||||
// TODO: 如果将来需要统计字段(如阅读量等),可以在这里实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 实现 BaseRepository 接口的其他方法 ================
|
||||
|
||||
// Count 统计数量
|
||||
func (r *GormAnnouncementRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查是否存在
|
||||
func (r *GormAnnouncementRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Announcement{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除
|
||||
func (r *GormAnnouncementRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除
|
||||
func (r *GormAnnouncementRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Announcement{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建
|
||||
func (r *GormAnnouncementRepository) CreateBatch(ctx context.Context, entities []entities.Announcement) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取
|
||||
func (r *GormAnnouncementRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&announcements).Error
|
||||
return announcements, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新
|
||||
func (r *GormAnnouncementRepository) UpdateBatch(ctx context.Context, entities []entities.Announcement) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除
|
||||
func (r *GormAnnouncementRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 列表查询
|
||||
func (r *GormAnnouncementRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = strings.ToUpper(options.Order)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&announcements).Error
|
||||
return announcements, err
|
||||
}
|
||||
@@ -0,0 +1,592 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
repoQueries "hyapi-server/internal/domains/article/repositories/queries"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormArticleRepository GORM文章仓储实现
|
||||
type GormArticleRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.ArticleRepository = (*GormArticleRepository)(nil)
|
||||
|
||||
// NewGormArticleRepository 创建GORM文章仓储
|
||||
func NewGormArticleRepository(db *gorm.DB, logger *zap.Logger) *GormArticleRepository {
|
||||
return &GormArticleRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建文章
|
||||
func (r *GormArticleRepository) Create(ctx context.Context, entity entities.Article) (entities.Article, error) {
|
||||
r.logger.Info("创建文章", zap.String("id", entity.ID), zap.String("title", entity.Title))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建文章失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取文章
|
||||
func (r *GormArticleRepository) GetByID(ctx context.Context, id string) (entities.Article, error) {
|
||||
var entity entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("id = ?", id).
|
||||
First(&entity).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("文章不存在")
|
||||
}
|
||||
r.logger.Error("获取文章失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新文章
|
||||
func (r *GormArticleRepository) Update(ctx context.Context, entity entities.Article) error {
|
||||
r.logger.Info("更新文章", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新文章失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除文章
|
||||
func (r *GormArticleRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除文章", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Article{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除文章失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByAuthorID 根据作者ID查找文章
|
||||
func (r *GormArticleRepository) FindByAuthorID(ctx context.Context, authorID string) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("author_id = ?", authorID).
|
||||
Order("created_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据作者ID查找文章失败", zap.String("author_id", authorID), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByCategoryID 根据分类ID查找文章
|
||||
func (r *GormArticleRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("category_id = ?", categoryID).
|
||||
Order("created_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据分类ID查找文章失败", zap.String("category_id", categoryID), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找文章
|
||||
func (r *GormArticleRepository) FindByStatus(ctx context.Context, status entities.ArticleStatus) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据状态查找文章失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindFeatured 查找推荐文章
|
||||
func (r *GormArticleRepository) FindFeatured(ctx context.Context) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("is_featured = ? AND status = ?", true, entities.ArticleStatusPublished).
|
||||
Order("published_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("查找推荐文章失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Search 搜索文章
|
||||
func (r *GormArticleRepository) Search(ctx context.Context, query *repoQueries.SearchArticleQuery) ([]*entities.Article, int64, error) {
|
||||
var articles []entities.Article
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用搜索条件
|
||||
if query.Keyword != "" {
|
||||
keyword := "%" + query.Keyword + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ? OR summary LIKE ?", keyword, keyword, keyword)
|
||||
}
|
||||
|
||||
if query.CategoryID != "" {
|
||||
// 如果指定了分类ID,只查询该分类的文章(包括没有分类的文章,当CategoryID为空字符串时)
|
||||
if query.CategoryID == "null" || query.CategoryID == "" {
|
||||
// 查询没有分类的文章
|
||||
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
|
||||
} else {
|
||||
// 查询指定分类的文章
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
}
|
||||
|
||||
if query.AuthorID != "" {
|
||||
dbQuery = dbQuery.Where("author_id = ?", query.AuthorID)
|
||||
}
|
||||
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取搜索结果总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
dbQuery = dbQuery.Preload("Category").Preload("Tags")
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&articles).Error; err != nil {
|
||||
r.logger.Error("搜索文章失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// ListArticles 获取文章列表(用户端)
|
||||
func (r *GormArticleRepository) ListArticles(ctx context.Context, query *repoQueries.ListArticleQuery) ([]*entities.Article, int64, error) {
|
||||
var articles []entities.Article
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 用户端不显示归档文章
|
||||
dbQuery = dbQuery.Where("status != ?", entities.ArticleStatusArchived)
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
if query.CategoryID != "" {
|
||||
// 如果指定了分类ID,只查询该分类的文章(包括没有分类的文章,当CategoryID为空字符串时)
|
||||
if query.CategoryID == "null" || query.CategoryID == "" {
|
||||
// 查询没有分类的文章
|
||||
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
|
||||
} else {
|
||||
// 查询指定分类的文章
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
}
|
||||
|
||||
if query.TagID != "" {
|
||||
// 如果指定了标签ID,只查询有关联该标签的文章
|
||||
// 使用子查询而不是JOIN,避免影响其他查询条件
|
||||
subQuery := r.db.WithContext(ctx).Table("article_tag_relations").
|
||||
Select("article_id").
|
||||
Where("tag_id = ?", query.TagID)
|
||||
dbQuery = dbQuery.Where("id IN (?)", subQuery)
|
||||
}
|
||||
|
||||
if query.Title != "" {
|
||||
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
|
||||
}
|
||||
|
||||
if query.Summary != "" {
|
||||
dbQuery = dbQuery.Where("summary ILIKE ?", "%"+query.Summary+"%")
|
||||
}
|
||||
|
||||
if query.IsFeatured != nil {
|
||||
dbQuery = dbQuery.Where("is_featured = ?", *query.IsFeatured)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取文章列表总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
dbQuery = dbQuery.Preload("Category").Preload("Tags")
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&articles).Error; err != nil {
|
||||
r.logger.Error("获取文章列表失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// ListArticlesForAdmin 获取文章列表(管理员端)
|
||||
func (r *GormArticleRepository) ListArticlesForAdmin(ctx context.Context, query *repoQueries.ListArticleQuery) ([]*entities.Article, int64, error) {
|
||||
var articles []entities.Article
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
if query.CategoryID != "" {
|
||||
// 如果指定了分类ID,只查询该分类的文章(包括没有分类的文章,当CategoryID为空字符串时)
|
||||
if query.CategoryID == "null" || query.CategoryID == "" {
|
||||
// 查询没有分类的文章
|
||||
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
|
||||
} else {
|
||||
// 查询指定分类的文章
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
}
|
||||
|
||||
if query.TagID != "" {
|
||||
// 如果指定了标签ID,只查询有关联该标签的文章
|
||||
// 使用子查询而不是JOIN,避免影响其他查询条件
|
||||
subQuery := r.db.WithContext(ctx).Table("article_tag_relations").
|
||||
Select("article_id").
|
||||
Where("tag_id = ?", query.TagID)
|
||||
dbQuery = dbQuery.Where("id IN (?)", subQuery)
|
||||
}
|
||||
|
||||
if query.Title != "" {
|
||||
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
|
||||
}
|
||||
|
||||
if query.Summary != "" {
|
||||
dbQuery = dbQuery.Where("summary ILIKE ?", "%"+query.Summary+"%")
|
||||
}
|
||||
|
||||
if query.IsFeatured != nil {
|
||||
dbQuery = dbQuery.Where("is_featured = ?", *query.IsFeatured)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取文章列表总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
dbQuery = dbQuery.Preload("Category").Preload("Tags")
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&articles).Error; err != nil {
|
||||
r.logger.Error("获取文章列表失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CountByCategoryID 统计分类文章数量
|
||||
func (r *GormArticleRepository) CountByCategoryID(ctx context.Context, categoryID string) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.Article{}).
|
||||
Where("category_id = ?", categoryID).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("统计分类文章数量失败", zap.String("category_id", categoryID), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByStatus 统计状态文章数量
|
||||
func (r *GormArticleRepository) CountByStatus(ctx context.Context, status entities.ArticleStatus) (int64, error) {
|
||||
var count int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
if status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", status)
|
||||
}
|
||||
|
||||
err := dbQuery.Count(&count).Error
|
||||
if err != nil {
|
||||
r.logger.Error("统计状态文章数量失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// IncrementViewCount 增加阅读量
|
||||
func (r *GormArticleRepository) IncrementViewCount(ctx context.Context, articleID string) error {
|
||||
err := r.db.WithContext(ctx).Model(&entities.Article{}).
|
||||
Where("id = ?", articleID).
|
||||
UpdateColumn("view_count", gorm.Expr("view_count + ?", 1)).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("增加阅读量失败", zap.String("article_id", articleID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// 实现 BaseRepository 接口的其他方法
|
||||
func (r *GormArticleRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Article{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Article{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Article{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) CreateBatch(ctx context.Context, entities []entities.Article) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&articles).Error
|
||||
return articles, err
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) UpdateBatch(ctx context.Context, entities []entities.Article) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Article{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = strings.ToUpper(options.Order)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&articles).Error
|
||||
return articles, err
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormCategoryRepository GORM分类仓储实现
|
||||
type GormCategoryRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.CategoryRepository = (*GormCategoryRepository)(nil)
|
||||
|
||||
// NewGormCategoryRepository 创建GORM分类仓储
|
||||
func NewGormCategoryRepository(db *gorm.DB, logger *zap.Logger) *GormCategoryRepository {
|
||||
return &GormCategoryRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建分类
|
||||
func (r *GormCategoryRepository) Create(ctx context.Context, entity entities.Category) (entities.Category, error) {
|
||||
r.logger.Info("创建分类", zap.String("id", entity.ID), zap.String("name", entity.Name))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建分类失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取分类
|
||||
func (r *GormCategoryRepository) GetByID(ctx context.Context, id string) (entities.Category, error) {
|
||||
var entity entities.Category
|
||||
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("分类不存在")
|
||||
}
|
||||
r.logger.Error("获取分类失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新分类
|
||||
func (r *GormCategoryRepository) Update(ctx context.Context, entity entities.Category) error {
|
||||
r.logger.Info("更新分类", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新分类失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除分类
|
||||
func (r *GormCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除分类", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Category{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除分类失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindActive 查找启用的分类
|
||||
func (r *GormCategoryRepository) FindActive(ctx context.Context) ([]*entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("active = ?", true).
|
||||
Order("sort_order ASC, created_at ASC").
|
||||
Find(&categories).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("查找启用分类失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Category, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindBySortOrder 按排序查找分类
|
||||
func (r *GormCategoryRepository) FindBySortOrder(ctx context.Context) ([]*entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Order("sort_order ASC, created_at ASC").
|
||||
Find(&categories).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("按排序查找分类失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Category, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CountActive 统计启用分类数量
|
||||
func (r *GormCategoryRepository) CountActive(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.Category{}).
|
||||
Where("active = ?", true).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("统计启用分类数量失败", zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// 实现 BaseRepository 接口的其他方法
|
||||
func (r *GormCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Category{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ?", search, search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Category{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Category{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Category{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) CreateBatch(ctx context.Context, entities []entities.Category) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) UpdateBatch(ctx context.Context, entities []entities.Category) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Category{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Category{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ?", search, search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("sort_order ASC, created_at ASC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormScheduledTaskRepository GORM定时任务仓储实现
|
||||
type GormScheduledTaskRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.ScheduledTaskRepository = (*GormScheduledTaskRepository)(nil)
|
||||
|
||||
// NewGormScheduledTaskRepository 创建GORM定时任务仓储
|
||||
func NewGormScheduledTaskRepository(db *gorm.DB, logger *zap.Logger) *GormScheduledTaskRepository {
|
||||
return &GormScheduledTaskRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建定时任务记录
|
||||
func (r *GormScheduledTaskRepository) Create(ctx context.Context, task entities.ScheduledTask) (entities.ScheduledTask, error) {
|
||||
r.logger.Info("创建定时任务记录", zap.String("task_id", task.TaskID), zap.String("article_id", task.ArticleID))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&task).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建定时任务记录失败", zap.Error(err))
|
||||
return task, err
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// GetByTaskID 根据Asynq任务ID获取任务记录
|
||||
func (r *GormScheduledTaskRepository) GetByTaskID(ctx context.Context, taskID string) (entities.ScheduledTask, error) {
|
||||
var task entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("task_id = ?", taskID).
|
||||
First(&task).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return task, fmt.Errorf("定时任务不存在")
|
||||
}
|
||||
r.logger.Error("获取定时任务失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
return task, err
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// GetByArticleID 根据文章ID获取任务记录
|
||||
func (r *GormScheduledTaskRepository) GetByArticleID(ctx context.Context, articleID string) (entities.ScheduledTask, error) {
|
||||
var task entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("article_id = ? AND status IN (?)", articleID, []string{"pending", "running"}).
|
||||
First(&task).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return task, fmt.Errorf("文章没有活动的定时任务")
|
||||
}
|
||||
r.logger.Error("获取文章定时任务失败", zap.String("article_id", articleID), zap.Error(err))
|
||||
return task, err
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// Update 更新任务记录
|
||||
func (r *GormScheduledTaskRepository) Update(ctx context.Context, task entities.ScheduledTask) error {
|
||||
r.logger.Info("更新定时任务记录", zap.String("task_id", task.TaskID), zap.String("status", string(task.Status)))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&task).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新定时任务记录失败", zap.String("task_id", task.TaskID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除任务记录
|
||||
func (r *GormScheduledTaskRepository) Delete(ctx context.Context, taskID string) error {
|
||||
r.logger.Info("删除定时任务记录", zap.String("task_id", taskID))
|
||||
|
||||
err := r.db.WithContext(ctx).Where("task_id = ?", taskID).Delete(&entities.ScheduledTask{}).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除定时任务记录失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsCancelled 标记任务为已取消
|
||||
func (r *GormScheduledTaskRepository) MarkAsCancelled(ctx context.Context, taskID string) error {
|
||||
r.logger.Info("标记定时任务为已取消", zap.String("task_id", taskID))
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.ScheduledTask{}).
|
||||
Where("task_id = ? AND status IN (?)", taskID, []string{"pending", "running"}).
|
||||
Updates(map[string]interface{}{
|
||||
"status": entities.TaskStatusCancelled,
|
||||
"completed_at": time.Now(),
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
r.logger.Error("标记定时任务为已取消失败", zap.String("task_id", taskID), zap.Error(result.Error))
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
r.logger.Warn("没有找到需要取消的定时任务", zap.String("task_id", taskID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveTasks 获取活动状态的任务列表
|
||||
func (r *GormScheduledTaskRepository) GetActiveTasks(ctx context.Context) ([]entities.ScheduledTask, error) {
|
||||
var tasks []entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("status IN (?)", []string{"pending", "running"}).
|
||||
Order("scheduled_at ASC").
|
||||
Find(&tasks).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("获取活动定时任务列表失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// GetExpiredTasks 获取过期的任务列表
|
||||
func (r *GormScheduledTaskRepository) GetExpiredTasks(ctx context.Context) ([]entities.ScheduledTask, error) {
|
||||
var tasks []entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("status = ? AND scheduled_at < ?", entities.TaskStatusPending, time.Now()).
|
||||
Order("scheduled_at ASC").
|
||||
Find(&tasks).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("获取过期定时任务列表失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormTagRepository GORM标签仓储实现
|
||||
type GormTagRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.TagRepository = (*GormTagRepository)(nil)
|
||||
|
||||
// NewGormTagRepository 创建GORM标签仓储
|
||||
func NewGormTagRepository(db *gorm.DB, logger *zap.Logger) *GormTagRepository {
|
||||
return &GormTagRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建标签
|
||||
func (r *GormTagRepository) Create(ctx context.Context, entity entities.Tag) (entities.Tag, error) {
|
||||
r.logger.Info("创建标签", zap.String("id", entity.ID), zap.String("name", entity.Name))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建标签失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取标签
|
||||
func (r *GormTagRepository) GetByID(ctx context.Context, id string) (entities.Tag, error) {
|
||||
var entity entities.Tag
|
||||
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("标签不存在")
|
||||
}
|
||||
r.logger.Error("获取标签失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新标签
|
||||
func (r *GormTagRepository) Update(ctx context.Context, entity entities.Tag) error {
|
||||
r.logger.Info("更新标签", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新标签失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除标签
|
||||
func (r *GormTagRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除标签", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Tag{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除标签失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByArticleID 根据文章ID查找标签
|
||||
func (r *GormTagRepository) FindByArticleID(ctx context.Context, articleID string) ([]*entities.Tag, error) {
|
||||
var tags []entities.Tag
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN article_tag_relations ON article_tag_relations.tag_id = tags.id").
|
||||
Where("article_tag_relations.article_id = ?", articleID).
|
||||
Find(&tags).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据文章ID查找标签失败", zap.String("article_id", articleID), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Tag, len(tags))
|
||||
for i := range tags {
|
||||
result[i] = &tags[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByName 根据名称查找标签
|
||||
func (r *GormTagRepository) FindByName(ctx context.Context, name string) (*entities.Tag, error) {
|
||||
var tag entities.Tag
|
||||
|
||||
err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
r.logger.Error("根据名称查找标签失败", zap.String("name", name), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tag, nil
|
||||
}
|
||||
|
||||
// AddTagToArticle 为文章添加标签
|
||||
func (r *GormTagRepository) AddTagToArticle(ctx context.Context, articleID string, tagID string) error {
|
||||
// 检查关联是否已存在
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Table("article_tag_relations").
|
||||
Where("article_id = ? AND tag_id = ?", articleID, tagID).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("检查标签关联失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
// 关联已存在,不需要重复添加
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建关联
|
||||
err = r.db.WithContext(ctx).Exec(`
|
||||
INSERT INTO article_tag_relations (article_id, tag_id)
|
||||
VALUES (?, ?)
|
||||
`, articleID, tagID).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("添加标签到文章失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("添加标签到文章成功", zap.String("article_id", articleID), zap.String("tag_id", tagID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTagFromArticle 从文章移除标签
|
||||
func (r *GormTagRepository) RemoveTagFromArticle(ctx context.Context, articleID string, tagID string) error {
|
||||
err := r.db.WithContext(ctx).Exec(`
|
||||
DELETE FROM article_tag_relations
|
||||
WHERE article_id = ? AND tag_id = ?
|
||||
`, articleID, tagID).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("从文章移除标签失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("从文章移除标签成功", zap.String("article_id", articleID), zap.String("tag_id", tagID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetArticleTags 获取文章的所有标签
|
||||
func (r *GormTagRepository) GetArticleTags(ctx context.Context, articleID string) ([]*entities.Tag, error) {
|
||||
return r.FindByArticleID(ctx, articleID)
|
||||
}
|
||||
|
||||
// 实现 BaseRepository 接口的其他方法
|
||||
func (r *GormTagRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Tag{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ?", search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Tag{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Tag{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Tag{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) CreateBatch(ctx context.Context, entities []entities.Tag) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Tag, error) {
|
||||
var tags []entities.Tag
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&tags).Error
|
||||
return tags, err
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) UpdateBatch(ctx context.Context, entities []entities.Tag) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Tag{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Tag, error) {
|
||||
var tags []entities.Tag
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Tag{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ?", search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at ASC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&tags).Error
|
||||
return tags, err
|
||||
}
|
||||
@@ -0,0 +1,370 @@
|
||||
package certification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/certification/entities"
|
||||
"hyapi-server/internal/domains/certification/enums"
|
||||
"hyapi-server/internal/domains/certification/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// ================ 常量定义 ================
|
||||
|
||||
const (
|
||||
// 表名常量
|
||||
CertificationsTable = "certifications"
|
||||
|
||||
// 缓存时间常量
|
||||
CacheTTLPrimaryQuery = 30 * time.Minute // 主键查询缓存时间
|
||||
CacheTTLBusinessQuery = 15 * time.Minute // 业务查询缓存时间
|
||||
CacheTTLUserQuery = 10 * time.Minute // 用户相关查询缓存时间
|
||||
CacheTTLWarmupLong = 30 * time.Minute // 预热长期缓存
|
||||
CacheTTLWarmupMedium = 15 * time.Minute // 预热中期缓存
|
||||
|
||||
// 缓存键模式常量
|
||||
CachePatternTable = "gorm_cache:certifications:*"
|
||||
CachePatternUser = "certification:user_id:*"
|
||||
)
|
||||
|
||||
// ================ Repository 实现 ================
|
||||
|
||||
// GormCertificationCommandRepository 认证命令仓储GORM实现
|
||||
//
|
||||
// 特性说明:
|
||||
// - 基于 CachedBaseRepositoryImpl 实现自动缓存管理
|
||||
// - 支持多级缓存策略(主键查询30分钟,业务查询15分钟)
|
||||
// - 自动缓存失效:写操作时自动清理相关缓存
|
||||
// - 智能缓存选择:根据查询复杂度自动选择缓存策略
|
||||
// - 内置监控支持:提供缓存统计和性能监控
|
||||
type GormCertificationCommandRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.CertificationCommandRepository = (*GormCertificationCommandRepository)(nil)
|
||||
|
||||
// NewGormCertificationCommandRepository 创建认证命令仓储
|
||||
//
|
||||
// 参数:
|
||||
// - db: GORM数据库连接实例
|
||||
// - logger: 日志记录器
|
||||
//
|
||||
// 返回:
|
||||
// - repositories.CertificationCommandRepository: 仓储接口实现
|
||||
func NewGormCertificationCommandRepository(db *gorm.DB, logger *zap.Logger) repositories.CertificationCommandRepository {
|
||||
return &GormCertificationCommandRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, CertificationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建认证
|
||||
//
|
||||
// 业务说明:
|
||||
// - 创建新的认证申请
|
||||
// - 自动触发相关缓存失效
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - cert: 认证实体
|
||||
//
|
||||
// 返回:
|
||||
// - error: 创建失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) Create(ctx context.Context, cert entities.Certification) error {
|
||||
r.GetLogger().Info("创建认证申请",
|
||||
zap.String("user_id", cert.UserID),
|
||||
zap.String("status", string(cert.Status)))
|
||||
|
||||
return r.CreateEntity(ctx, &cert)
|
||||
}
|
||||
|
||||
// Update 更新认证
|
||||
//
|
||||
// 缓存影响:
|
||||
// - GORM缓存插件会自动失效相关缓存
|
||||
// - 无需手动管理缓存一致性
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - cert: 认证实体
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) Update(ctx context.Context, cert entities.Certification) error {
|
||||
r.GetLogger().Info("更新认证",
|
||||
zap.String("id", cert.ID),
|
||||
zap.String("status", string(cert.Status)))
|
||||
|
||||
return r.UpdateEntity(ctx, &cert)
|
||||
}
|
||||
|
||||
// Delete 删除认证
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
//
|
||||
// 返回:
|
||||
// - error: 删除失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) Delete(ctx context.Context, id string) error {
|
||||
r.GetLogger().Info("删除认证", zap.String("id", id))
|
||||
return r.DeleteEntity(ctx, id, &entities.Certification{})
|
||||
}
|
||||
|
||||
// ================ 业务特定的更新操作 ================
|
||||
|
||||
// UpdateStatus 更新认证状态
|
||||
//
|
||||
// 业务说明:
|
||||
// - 更新认证的状态
|
||||
// - 自动更新时间戳
|
||||
//
|
||||
// 缓存影响:
|
||||
// - GORM缓存插件会自动失效表相关的缓存
|
||||
// - 状态更新会影响列表查询和统计结果
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - status: 新状态
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateStatus(ctx context.Context, id string, status enums.CertificationStatus) error {
|
||||
r.GetLogger().Info("更新认证状态",
|
||||
zap.String("id", id),
|
||||
zap.String("status", string(status)))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateAuthFlowID 更新认证流程ID
|
||||
//
|
||||
// 业务说明:
|
||||
// - 记录e签宝企业认证流程ID
|
||||
// - 用于回调处理和状态跟踪
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - authFlowID: 认证流程ID
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateAuthFlowID(ctx context.Context, id string, authFlowID string) error {
|
||||
r.GetLogger().Info("更新认证流程ID",
|
||||
zap.String("id", id),
|
||||
zap.String("auth_flow_id", authFlowID))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"auth_flow_id": authFlowID,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateContractInfo 更新合同信息
|
||||
//
|
||||
// 业务说明:
|
||||
// - 记录合同相关的ID和URL信息
|
||||
// - 用于合同管理和用户下载
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - contractFileID: 合同文件ID
|
||||
// - esignFlowID: e签宝流程ID
|
||||
// - contractURL: 合同URL
|
||||
// - contractSignURL: 合同签署URL
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateContractInfo(ctx context.Context, id string, contractFileID, esignFlowID, contractURL, contractSignURL string) error {
|
||||
r.GetLogger().Info("更新合同信息",
|
||||
zap.String("id", id),
|
||||
zap.String("contract_file_id", contractFileID),
|
||||
zap.String("esign_flow_id", esignFlowID))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"contract_file_id": contractFileID,
|
||||
"esign_flow_id": esignFlowID,
|
||||
"contract_url": contractURL,
|
||||
"contract_sign_url": contractSignURL,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateFailureInfo 更新失败信息
|
||||
//
|
||||
// 业务说明:
|
||||
// - 记录认证失败的原因和详细信息
|
||||
// - 用于错误分析和用户提示
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - reason: 失败原因
|
||||
// - message: 失败详细信息
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateFailureInfo(ctx context.Context, id string, reason enums.FailureReason, message string) error {
|
||||
r.GetLogger().Info("更新失败信息",
|
||||
zap.String("id", id),
|
||||
zap.String("reason", string(reason)),
|
||||
zap.String("message", message))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"failure_reason": reason,
|
||||
"failure_message": message,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// ================ 批量操作 ================
|
||||
|
||||
// BatchUpdateStatus 批量更新状态
|
||||
//
|
||||
// 业务说明:
|
||||
// - 批量更新多个认证的状态
|
||||
// - 适用于管理员批量操作
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - ids: 认证ID列表
|
||||
// - status: 新状态
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) BatchUpdateStatus(ctx context.Context, ids []string, status enums.CertificationStatus) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("批量更新状态:ID列表不能为空")
|
||||
}
|
||||
|
||||
r.GetLogger().Info("批量更新认证状态",
|
||||
zap.Strings("ids", ids),
|
||||
zap.String("status", string(status)))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
result := r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id IN ?", ids).
|
||||
Updates(updates)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量更新认证状态失败: %w", result.Error)
|
||||
}
|
||||
|
||||
r.GetLogger().Info("批量更新完成", zap.Int64("affected_rows", result.RowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 事务支持 ================
|
||||
|
||||
// WithTx 使用事务
|
||||
//
|
||||
// 业务说明:
|
||||
// - 返回支持事务的仓储实例
|
||||
// - 用于复杂业务操作的事务一致性保证
|
||||
//
|
||||
// 参数:
|
||||
// - tx: 事务对象
|
||||
//
|
||||
// 返回:
|
||||
// - repositories.CertificationCommandRepository: 支持事务的仓储实例
|
||||
func (r *GormCertificationCommandRepository) WithTx(tx interfaces.Transaction) repositories.CertificationCommandRepository {
|
||||
// 获取事务的底层*gorm.DB
|
||||
txDB := tx.GetDB()
|
||||
if gormDB, ok := txDB.(*gorm.DB); ok {
|
||||
return &GormCertificationCommandRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormDB, r.GetLogger(), CertificationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
r.GetLogger().Warn("不支持的事务类型,返回原始仓储")
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 缓存管理方法 ================
|
||||
|
||||
// WarmupCache 预热认证缓存
|
||||
//
|
||||
// 业务说明:
|
||||
// - 系统启动时预热常用查询的缓存
|
||||
// - 提升首次访问的响应速度
|
||||
//
|
||||
// 预热策略:
|
||||
// - 活跃认证:30分钟长期缓存
|
||||
// - 最近创建:15分钟中期缓存
|
||||
func (r *GormCertificationCommandRepository) WarmupCache(ctx context.Context) error {
|
||||
r.GetLogger().Info("开始预热认证缓存")
|
||||
|
||||
queries := []database.WarmupQuery{
|
||||
{
|
||||
Name: "active_certifications",
|
||||
TTL: CacheTTLWarmupLong,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
{
|
||||
Name: "recent_certifications",
|
||||
TTL: CacheTTLWarmupMedium,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
}
|
||||
|
||||
return r.WarmupCommonQueries(ctx, queries)
|
||||
}
|
||||
|
||||
// RefreshCache 刷新认证缓存
|
||||
//
|
||||
// 业务说明:
|
||||
// - 手动刷新认证相关的所有缓存
|
||||
// - 适用于数据迁移或批量更新后的缓存清理
|
||||
func (r *GormCertificationCommandRepository) RefreshCache(ctx context.Context) error {
|
||||
r.GetLogger().Info("刷新认证缓存")
|
||||
return r.CachedBaseRepositoryImpl.RefreshCache(ctx, CachePatternTable)
|
||||
}
|
||||
|
||||
// GetCacheStats 获取缓存统计信息
|
||||
//
|
||||
// 返回当前Repository的缓存使用统计,包括:
|
||||
// - 基础缓存信息(命中率、键数量等)
|
||||
// - 特定的缓存模式列表
|
||||
// - 性能指标
|
||||
func (r *GormCertificationCommandRepository) GetCacheStats() map[string]interface{} {
|
||||
stats := r.GetCacheInfo()
|
||||
stats["specific_patterns"] = []string{
|
||||
CachePatternTable,
|
||||
CachePatternUser,
|
||||
}
|
||||
return stats
|
||||
}
|
||||
@@ -0,0 +1,469 @@
|
||||
package certification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/certification/entities"
|
||||
"hyapi-server/internal/domains/certification/enums"
|
||||
"hyapi-server/internal/domains/certification/repositories"
|
||||
"hyapi-server/internal/domains/certification/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ================ 常量定义 ================
|
||||
|
||||
const (
|
||||
// 缓存时间常量
|
||||
QueryCacheTTLPrimaryQuery = 30 * time.Minute // 主键查询缓存时间
|
||||
QueryCacheTTLBusinessQuery = 15 * time.Minute // 业务查询缓存时间
|
||||
QueryCacheTTLUserQuery = 10 * time.Minute // 用户相关查询缓存时间
|
||||
QueryCacheTTLSearchQuery = 2 * time.Minute // 搜索查询缓存时间
|
||||
QueryCacheTTLActiveRecords = 5 * time.Minute // 活跃记录查询缓存时间
|
||||
QueryCacheTTLWarmupLong = 30 * time.Minute // 预热长期缓存
|
||||
QueryCacheTTLWarmupMedium = 15 * time.Minute // 预热中期缓存
|
||||
|
||||
// 缓存键模式常量
|
||||
QueryCachePatternTable = "gorm_cache:certifications:*"
|
||||
QueryCachePatternUser = "certification:user_id:*"
|
||||
)
|
||||
|
||||
// ================ Repository 实现 ================
|
||||
|
||||
// GormCertificationQueryRepository 认证查询仓储GORM实现
|
||||
//
|
||||
// 特性说明:
|
||||
// - 基于 CachedBaseRepositoryImpl 实现自动缓存管理
|
||||
// - 支持多级缓存策略(主键查询30分钟,业务查询15分钟,搜索2分钟)
|
||||
// - 自动缓存失效:写操作时自动清理相关缓存
|
||||
// - 智能缓存选择:根据查询复杂度自动选择缓存策略
|
||||
// - 内置监控支持:提供缓存统计和性能监控
|
||||
type GormCertificationQueryRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.CertificationQueryRepository = (*GormCertificationQueryRepository)(nil)
|
||||
|
||||
// NewGormCertificationQueryRepository 创建认证查询仓储
|
||||
//
|
||||
// 参数:
|
||||
// - db: GORM数据库连接实例
|
||||
// - logger: 日志记录器
|
||||
//
|
||||
// 返回:
|
||||
// - repositories.CertificationQueryRepository: 仓储接口实现
|
||||
func NewGormCertificationQueryRepository(
|
||||
db *gorm.DB,
|
||||
logger *zap.Logger,
|
||||
) repositories.CertificationQueryRepository {
|
||||
return &GormCertificationQueryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, CertificationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础查询操作 ================
|
||||
|
||||
// GetByID 根据ID获取认证
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 使用智能主键查询,自动缓存30分钟
|
||||
// - 主键查询命中率高,适合长期缓存
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) GetByID(ctx context.Context, id string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
if err := r.SmartGetByID(ctx, id, &cert); err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("认证记录不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询认证记录失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取认证
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 业务查询,缓存15分钟
|
||||
// - 用户查询频率较高,适合中期缓存
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - userID: 用户ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) GetByUserID(ctx context.Context, userID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.SmartGetByField(ctx, &cert, "user_id", userID, QueryCacheTTLUserQuery)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("用户尚未创建认证申请")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户认证记录失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// Exists 检查认证是否存在
|
||||
func (r *GormCertificationQueryRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.Certification{})
|
||||
}
|
||||
|
||||
func (r *GormCertificationQueryRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Certification{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("查询用户认证是否存在失败: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ================ 列表查询 ================
|
||||
|
||||
// List 分页列表查询
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 搜索查询:短期缓存2分钟(避免频繁数据库查询但保证实时性)
|
||||
// - 常规列表:智能缓存(根据查询复杂度自动选择缓存策略)
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - query: 列表查询条件
|
||||
//
|
||||
// 返回:
|
||||
// - []*entities.Certification: 查询结果列表
|
||||
// - int64: 总记录数
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) List(ctx context.Context, query *queries.ListCertificationsQuery) ([]*entities.Certification, int64, error) {
|
||||
db := r.GetDB(ctx).Model(&entities.Certification{})
|
||||
|
||||
// 应用过滤条件
|
||||
if query.UserID != "" {
|
||||
db = db.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.Status != "" {
|
||||
db = db.Where("status = ?", query.Status)
|
||||
}
|
||||
if len(query.Statuses) > 0 {
|
||||
db = db.Where("status IN ?", query.Statuses)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
var total int64
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("查询认证总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 应用排序和分页
|
||||
if query.SortBy != "" {
|
||||
orderClause := query.SortBy
|
||||
if query.SortOrder != "" {
|
||||
orderClause += " " + strings.ToUpper(query.SortOrder)
|
||||
}
|
||||
db = db.Order(orderClause)
|
||||
} else {
|
||||
db = db.Order("created_at DESC")
|
||||
}
|
||||
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
db = db.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 执行查询
|
||||
var certifications []*entities.Certification
|
||||
if err := db.Find(&certifications).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("查询认证列表失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, total, nil
|
||||
}
|
||||
|
||||
// ListByUserIDs 根据用户ID列表查询
|
||||
func (r *GormCertificationQueryRepository) ListByUserIDs(ctx context.Context, userIDs []string) ([]*entities.Certification, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return []*entities.Certification{}, nil
|
||||
}
|
||||
|
||||
var certifications []*entities.Certification
|
||||
if err := r.GetDB(ctx).Where("user_id IN ?", userIDs).Order("created_at DESC").Find(&certifications).Error; err != nil {
|
||||
return nil, fmt.Errorf("根据用户ID列表查询认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态查询
|
||||
func (r *GormCertificationQueryRepository) ListByStatus(ctx context.Context, status enums.CertificationStatus, limit int) ([]*entities.Certification, error) {
|
||||
db := r.GetDB(ctx).Where("status = ?", status).Order("created_at DESC")
|
||||
if limit > 0 {
|
||||
db = db.Limit(limit)
|
||||
}
|
||||
|
||||
var certifications []*entities.Certification
|
||||
if err := db.Find(&certifications).Error; err != nil {
|
||||
return nil, fmt.Errorf("根据状态查询认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// ================ 业务查询 ================
|
||||
|
||||
// FindByAuthFlowID 根据认证流程ID查询
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 业务查询,缓存15分钟
|
||||
// - 回调查询频率较高
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - authFlowID: 认证流程ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) FindByAuthFlowID(ctx context.Context, authFlowID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.SmartGetByField(ctx, &cert, "auth_flow_id", authFlowID, QueryCacheTTLBusinessQuery)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("认证流程不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("根据认证流程ID查询失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// FindByEsignFlowID 根据e签宝流程ID查询
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 业务查询,缓存15分钟
|
||||
// - 回调查询频率较高
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - esignFlowID: e签宝流程ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) FindByEsignFlowID(ctx context.Context, esignFlowID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.SmartGetByField(ctx, &cert, "esign_flow_id", esignFlowID, QueryCacheTTLBusinessQuery)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("e签宝流程不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("根据e签宝流程ID查询失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// ListPendingRetry 查询待重试的认证
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 管理查询,不缓存保证数据实时性
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - maxRetryCount: 最大重试次数
|
||||
//
|
||||
// 返回:
|
||||
// - []*entities.Certification: 待重试的认证列表
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) ListPendingRetry(ctx context.Context, maxRetryCount int) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.WithoutCache().GetDB(ctx).
|
||||
Where("status IN ? AND retry_count < ?",
|
||||
[]enums.CertificationStatus{
|
||||
enums.StatusInfoRejected,
|
||||
enums.StatusContractRejected,
|
||||
enums.StatusContractExpired,
|
||||
},
|
||||
maxRetryCount).
|
||||
Order("created_at ASC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询待重试认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetPendingCertifications 获取待处理认证
|
||||
func (r *GormCertificationQueryRepository) GetPendingCertifications(ctx context.Context) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.WithoutCache().GetDB(ctx).
|
||||
Where("status IN ?", []enums.CertificationStatus{
|
||||
enums.StatusPending,
|
||||
enums.StatusInfoSubmitted,
|
||||
}).
|
||||
Order("created_at ASC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询待处理认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetExpiredContracts 获取过期合同
|
||||
func (r *GormCertificationQueryRepository) GetExpiredContracts(ctx context.Context) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.WithoutCache().GetDB(ctx).
|
||||
Where("status = ?", enums.StatusContractExpired).
|
||||
Order("updated_at DESC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询过期合同失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetCertificationsByDateRange 根据日期范围获取认证
|
||||
func (r *GormCertificationQueryRepository) GetCertificationsByDateRange(ctx context.Context, startDate, endDate time.Time) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.GetDB(ctx).
|
||||
Where("created_at BETWEEN ? AND ?", startDate, endDate).
|
||||
Order("created_at DESC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("根据日期范围查询认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetUserActiveCertification 获取用户当前活跃认证
|
||||
func (r *GormCertificationQueryRepository) GetUserActiveCertification(ctx context.Context, userID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND status NOT IN ?", userID, []enums.CertificationStatus{
|
||||
enums.StatusContractSigned,
|
||||
enums.StatusInfoRejected,
|
||||
enums.StatusContractRejected,
|
||||
enums.StatusContractExpired,
|
||||
}).
|
||||
First(&cert).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("用户没有活跃的认证申请")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户活跃认证失败: %w", err)
|
||||
}
|
||||
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// ================ 统计查询 ================
|
||||
|
||||
|
||||
|
||||
// CountByFailureReason 按失败原因统计
|
||||
func (r *GormCertificationQueryRepository) CountByFailureReason(ctx context.Context, reason enums.FailureReason) (int64, error) {
|
||||
var count int64
|
||||
if err := r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).Where("failure_reason = ?", reason).Count(&count).Error; err != nil {
|
||||
return 0, fmt.Errorf("按失败原因统计认证失败: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetProgressStatistics 获取进度统计
|
||||
func (r *GormCertificationQueryRepository) GetProgressStatistics(ctx context.Context) (*repositories.CertificationProgressStats, error) {
|
||||
// 简化实现
|
||||
return &repositories.CertificationProgressStats{
|
||||
StatusProgress: make(map[enums.CertificationStatus]int64),
|
||||
ProgressDistribution: make(map[int]int64),
|
||||
StageTimeStats: make(map[string]*repositories.CertificationStageTimeInfo),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SearchByCompanyName 按公司名搜索
|
||||
func (r *GormCertificationQueryRepository) SearchByCompanyName(ctx context.Context, companyName string, limit int) ([]*entities.Certification, error) {
|
||||
// 简化实现,暂时返回空结果
|
||||
r.GetLogger().Warn("按公司名搜索功能待实现,需要企业信息服务支持")
|
||||
return []*entities.Certification{}, nil
|
||||
}
|
||||
|
||||
// SearchByLegalPerson 按法人搜索
|
||||
func (r *GormCertificationQueryRepository) SearchByLegalPerson(ctx context.Context, legalPersonName string, limit int) ([]*entities.Certification, error) {
|
||||
// 简化实现,暂时返回空结果
|
||||
r.GetLogger().Warn("按法人搜索功能待实现,需要企业信息服务支持")
|
||||
return []*entities.Certification{}, nil
|
||||
}
|
||||
|
||||
// InvalidateCache 清除缓存
|
||||
func (r *GormCertificationQueryRepository) InvalidateCache(ctx context.Context, keys ...string) error {
|
||||
// 简化实现,暂不处理缓存
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshCache 刷新缓存
|
||||
func (r *GormCertificationQueryRepository) RefreshCache(ctx context.Context, certificationID string) error {
|
||||
// 简化实现,暂不处理缓存
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 缓存管理方法 ================
|
||||
|
||||
// WarmupCache 预热认证查询缓存
|
||||
//
|
||||
// 业务说明:
|
||||
// - 系统启动时预热常用查询的缓存
|
||||
// - 提升首次访问的响应速度
|
||||
//
|
||||
// 预热策略:
|
||||
// - 活跃认证:30分钟长期缓存
|
||||
// - 待处理认证:15分钟中期缓存
|
||||
func (r *GormCertificationQueryRepository) WarmupCache(ctx context.Context) error {
|
||||
r.GetLogger().Info("开始预热认证查询缓存")
|
||||
|
||||
queries := []database.WarmupQuery{
|
||||
{
|
||||
Name: "active_certifications",
|
||||
TTL: QueryCacheTTLWarmupLong,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
{
|
||||
Name: "pending_certifications",
|
||||
TTL: QueryCacheTTLWarmupMedium,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
}
|
||||
|
||||
return r.WarmupCommonQueries(ctx, queries)
|
||||
}
|
||||
|
||||
// GetCacheStats 获取缓存统计信息
|
||||
//
|
||||
// 返回当前Repository的缓存使用统计,包括:
|
||||
// - 基础缓存信息(命中率、键数量等)
|
||||
// - 特定的缓存模式列表
|
||||
// - 性能指标
|
||||
func (r *GormCertificationQueryRepository) GetCacheStats() map[string]interface{} {
|
||||
stats := r.GetCacheInfo()
|
||||
stats["specific_patterns"] = []string{
|
||||
QueryCachePatternTable,
|
||||
QueryCachePatternUser,
|
||||
}
|
||||
return stats
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package certification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hyapi-server/internal/domains/certification/entities"
|
||||
"hyapi-server/internal/domains/certification/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
EnterpriseInfoSubmitRecordsTable = "enterprise_info_submit_records"
|
||||
)
|
||||
|
||||
type GormEnterpriseInfoSubmitRecordRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.EnterpriseInfoSubmitRecord{})
|
||||
}
|
||||
|
||||
func NewGormEnterpriseInfoSubmitRecordRepository(db *gorm.DB, logger *zap.Logger) *GormEnterpriseInfoSubmitRecordRepository {
|
||||
return &GormEnterpriseInfoSubmitRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, EnterpriseInfoSubmitRecordsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Create(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
|
||||
return r.CreateEntity(ctx, record)
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Update(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
|
||||
return r.UpdateEntity(ctx, record)
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Exists(ctx context.Context, ID string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, ID, &entities.EnterpriseInfoSubmitRecord{})
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindByID(ctx context.Context, id string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).Where("id = ?", id).First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Order("submit_at DESC").
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestVerifiedByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND status = ?", userID, "verified").
|
||||
Order("verified_at DESC").
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// ExistsByUnifiedSocialCodeExcludeUser 检查该统一社会信用代码是否已被其他用户占用(已提交或已通过验证的记录)
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) ExistsByUnifiedSocialCodeExcludeUser(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
if unifiedSocialCode == "" {
|
||||
return false, nil
|
||||
}
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{}).
|
||||
Where("unified_social_code = ? AND status IN (?, ?)", unifiedSocialCode, "submitted", "verified")
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("user_id != ?", excludeUserID)
|
||||
}
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) List(ctx context.Context, filter repositories.ListSubmitRecordsFilter) (*repositories.ListSubmitRecordsResult, error) {
|
||||
base := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{})
|
||||
if filter.CertificationStatus != "" {
|
||||
base = base.Joins("JOIN certifications ON certifications.user_id = enterprise_info_submit_records.user_id AND certifications.deleted_at IS NULL").
|
||||
Where("certifications.status = ?", filter.CertificationStatus)
|
||||
}
|
||||
if filter.CompanyName != "" {
|
||||
base = base.Where("enterprise_info_submit_records.company_name LIKE ?", "%"+filter.CompanyName+"%")
|
||||
}
|
||||
if filter.LegalPersonPhone != "" {
|
||||
base = base.Where("enterprise_info_submit_records.legal_person_phone = ?", filter.LegalPersonPhone)
|
||||
}
|
||||
if filter.LegalPersonName != "" {
|
||||
base = base.Where("enterprise_info_submit_records.legal_person_name LIKE ?", "%"+filter.LegalPersonName+"%")
|
||||
}
|
||||
var total int64
|
||||
if err := base.Count(&total).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if filter.PageSize <= 0 {
|
||||
filter.PageSize = 10
|
||||
}
|
||||
if filter.Page <= 0 {
|
||||
filter.Page = 1
|
||||
}
|
||||
offset := (filter.Page - 1) * filter.PageSize
|
||||
var records []*entities.EnterpriseInfoSubmitRecord
|
||||
q := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{})
|
||||
if filter.CertificationStatus != "" {
|
||||
q = q.Joins("JOIN certifications ON certifications.user_id = enterprise_info_submit_records.user_id AND certifications.deleted_at IS NULL").
|
||||
Where("certifications.status = ?", filter.CertificationStatus)
|
||||
}
|
||||
if filter.CompanyName != "" {
|
||||
q = q.Where("enterprise_info_submit_records.company_name LIKE ?", "%"+filter.CompanyName+"%")
|
||||
}
|
||||
if filter.LegalPersonPhone != "" {
|
||||
q = q.Where("enterprise_info_submit_records.legal_person_phone = ?", filter.LegalPersonPhone)
|
||||
}
|
||||
if filter.LegalPersonName != "" {
|
||||
q = q.Where("enterprise_info_submit_records.legal_person_name LIKE ?", "%"+filter.LegalPersonName+"%")
|
||||
}
|
||||
err := q.Order("enterprise_info_submit_records.submit_at DESC").Offset(offset).Limit(filter.PageSize).Find(&records).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &repositories.ListSubmitRecordsResult{Records: records, Total: total}, nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
AlipayOrdersTable = "typay_orders"
|
||||
)
|
||||
|
||||
type GormAlipayOrderRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.AlipayOrderRepository = (*GormAlipayOrderRepository)(nil)
|
||||
|
||||
func NewGormAlipayOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.AlipayOrderRepository {
|
||||
return &GormAlipayOrderRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, AlipayOrdersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Create(ctx context.Context, order entities.AlipayOrder) (entities.AlipayOrder, error) {
|
||||
err := r.CreateEntity(ctx, &order)
|
||||
return order, err
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByID(ctx context.Context, id string) (entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.SmartGetByID(ctx, id, &order)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.AlipayOrder{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.AlipayOrder{}, err
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.AlipayOrder, error) {
|
||||
var orders []entities.AlipayOrder
|
||||
err := r.GetDB(ctx).
|
||||
Joins("JOIN recharge_records ON typay_orders.recharge_id = recharge_records.id").
|
||||
Where("recharge_records.user_id = ?", userID).
|
||||
Order("typay_orders.created_at DESC").
|
||||
Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Update(ctx context.Context, order entities.AlipayOrder) error {
|
||||
return r.UpdateEntity(ctx, &order)
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.AlipayOrderStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.AlipayOrder{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,352 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
"hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
PurchaseOrdersTable = "ty_purchase_orders"
|
||||
)
|
||||
|
||||
type GormPurchaseOrderRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.PurchaseOrderRepository = (*GormPurchaseOrderRepository)(nil)
|
||||
|
||||
func NewGormPurchaseOrderRepository(db *gorm.DB, logger *zap.Logger) repositories.PurchaseOrderRepository {
|
||||
return &GormPurchaseOrderRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, PurchaseOrdersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) Create(ctx context.Context, order *entities.PurchaseOrder) (*entities.PurchaseOrder, error) {
|
||||
err := r.CreateEntity(ctx, order)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) Update(ctx context.Context, order *entities.PurchaseOrder) error {
|
||||
return r.UpdateEntity(ctx, order)
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByID(ctx context.Context, id string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.SmartGetByID(ctx, id, &order)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByOrderNo(ctx context.Context, orderNo string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).Where("order_no = ?", orderNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.PurchaseOrder, int64, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
var count int64
|
||||
|
||||
db := r.GetDB(ctx).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
err := db.Model(&entities.PurchaseOrder{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
err = db.Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
result := make([]*entities.PurchaseOrder, len(orders))
|
||||
for i := range orders {
|
||||
result[i] = &orders[i]
|
||||
}
|
||||
|
||||
return result, count, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByUserIDAndProductID(ctx context.Context, userID, productID string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND product_id = ? AND status = ?", userID, productID, entities.PurchaseOrderStatusPaid).
|
||||
First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByPaymentTypeAndTransactionID(ctx context.Context, paymentType, transactionID string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Where("payment_type = ? AND trade_no = ?", paymentType, transactionID).
|
||||
First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByTradeNo(ctx context.Context, tradeNo string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).Where("trade_no = ?", tradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) UpdatePaymentStatus(ctx context.Context, orderID string, status entities.PurchaseOrderStatus, tradeNo *string, payAmount, receiptAmount *string, paymentTime *time.Time) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if tradeNo != nil {
|
||||
updates["trade_no"] = *tradeNo
|
||||
}
|
||||
|
||||
if payAmount != nil {
|
||||
updates["pay_amount"] = *payAmount
|
||||
}
|
||||
|
||||
if receiptAmount != nil {
|
||||
updates["receipt_amount"] = *receiptAmount
|
||||
}
|
||||
|
||||
if paymentTime != nil {
|
||||
updates["pay_time"] = *paymentTime
|
||||
updates["notify_time"] = *paymentTime
|
||||
}
|
||||
|
||||
err := r.GetDB(ctx).
|
||||
Model(&entities.PurchaseOrder{}).
|
||||
Where("id = ?", orderID).
|
||||
Updates(updates).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetUserPurchasedProductCodes(ctx context.Context, userID string) ([]string, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Select("product_code").
|
||||
Where("user_id = ? AND status = ?", userID, entities.PurchaseOrderStatusPaid).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
codesMap := make(map[string]bool)
|
||||
for _, order := range orders {
|
||||
// 添加主产品编号
|
||||
if order.ProductCode != "" {
|
||||
codesMap[order.ProductCode] = true
|
||||
}
|
||||
}
|
||||
|
||||
codes := make([]string, 0, len(codesMap))
|
||||
for code := range codesMap {
|
||||
codes = append(codes, code)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetUserPaidProductIDs(ctx context.Context, userID string) ([]string, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Select("product_id").
|
||||
Where("user_id = ? AND status = ?", userID, entities.PurchaseOrderStatusPaid).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idsMap := make(map[string]bool)
|
||||
for _, order := range orders {
|
||||
// 添加主产品ID
|
||||
if order.ProductID != "" {
|
||||
idsMap[order.ProductID] = true
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(idsMap))
|
||||
for id := range idsMap {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) HasUserPurchased(ctx context.Context, userID string, productCode string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.PurchaseOrder{}).
|
||||
Where("user_id = ? AND product_code = ? AND status = ?", userID, productCode, entities.PurchaseOrderStatusPaid).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetExpiringOrders(ctx context.Context, before time.Time, limit int) ([]*entities.PurchaseOrder, error) {
|
||||
// 购买订单实体没有过期时间字段,此方法返回空结果
|
||||
return []*entities.PurchaseOrder{}, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetExpiredOrders(ctx context.Context, limit int) ([]*entities.PurchaseOrder, error) {
|
||||
// 购买订单实体没有过期时间字段,此方法返回空结果
|
||||
return []*entities.PurchaseOrder{}, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByStatus(ctx context.Context, status entities.PurchaseOrderStatus, limit, offset int) ([]*entities.PurchaseOrder, int64, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
var count int64
|
||||
|
||||
db := r.GetDB(ctx).Where("status = ?", status)
|
||||
|
||||
// 获取总数
|
||||
err := db.Model(&entities.PurchaseOrder{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
err = db.Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
result := make([]*entities.PurchaseOrder, len(orders))
|
||||
for i := range orders {
|
||||
result[i] = &orders[i]
|
||||
}
|
||||
|
||||
return result, count, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByFilters(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.PurchaseOrder, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
if userID, ok := filters["user_id"]; ok {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if status, ok := filters["status"]; ok && status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if paymentType, ok := filters["payment_type"]; ok && paymentType != "" {
|
||||
db = db.Where("payment_type = ?", paymentType)
|
||||
}
|
||||
if payChannel, ok := filters["pay_channel"]; ok && payChannel != "" {
|
||||
db = db.Where("pay_channel = ?", payChannel)
|
||||
}
|
||||
if startTime, ok := filters["start_time"]; ok && startTime != "" {
|
||||
db = db.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"]; ok && endTime != "" {
|
||||
db = db.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用排序和分页
|
||||
// 默认按创建时间倒序
|
||||
db = db.Order("created_at DESC")
|
||||
|
||||
// 应用分页
|
||||
if options.PageSize > 0 {
|
||||
db = db.Limit(options.PageSize)
|
||||
}
|
||||
|
||||
if options.Page > 0 {
|
||||
db = db.Offset((options.Page - 1) * options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err := db.Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.PurchaseOrder, len(orders))
|
||||
for i := range orders {
|
||||
result[i] = &orders[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) CountByFilters(ctx context.Context, filters map[string]interface{}) (int64, error) {
|
||||
var count int64
|
||||
|
||||
db := r.GetDB(ctx).Model(&entities.PurchaseOrder{})
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
if userID, ok := filters["user_id"]; ok {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if status, ok := filters["status"]; ok && status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if paymentType, ok := filters["payment_type"]; ok && paymentType != "" {
|
||||
db = db.Where("payment_type = ?", paymentType)
|
||||
}
|
||||
if payChannel, ok := filters["pay_channel"]; ok && payChannel != "" {
|
||||
db = db.Where("pay_channel = ?", payChannel)
|
||||
}
|
||||
if startTime, ok := filters["start_time"]; ok && startTime != "" {
|
||||
db = db.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"]; ok && endTime != "" {
|
||||
db = db.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
|
||||
// 执行计数
|
||||
err := db.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,509 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
RechargeRecordsTable = "recharge_records"
|
||||
)
|
||||
|
||||
type GormRechargeRecordRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.RechargeRecordRepository = (*GormRechargeRecordRepository)(nil)
|
||||
|
||||
func NewGormRechargeRecordRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.RechargeRecordRepository {
|
||||
return &GormRechargeRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, RechargeRecordsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Create(ctx context.Context, record entities.RechargeRecord) (entities.RechargeRecord, error) {
|
||||
err := r.CreateEntity(ctx, &record)
|
||||
return record, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByID(ctx context.Context, id string) (entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.SmartGetByID(ctx, id, &record)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.RechargeRecord{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.RechargeRecord{}, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByUserID(ctx context.Context, userID string) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByAlipayOrderID(ctx context.Context, alipayOrderID string) (*entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("alipay_order_id = ?", alipayOrderID).First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByTransferOrderID(ctx context.Context, transferOrderID string) (*entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("transfer_order_id = ?", transferOrderID).First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Update(ctx context.Context, record entities.RechargeRecord) error {
|
||||
return r.UpdateEntity(ctx, &record)
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) UpdateStatus(ctx context.Context, id string, status entities.RechargeStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
|
||||
// 检查是否有 company_name 筛选,如果有则需要 JOIN 表
|
||||
hasCompanyNameFilter := false
|
||||
if options.Filters != nil {
|
||||
if companyName, ok := options.Filters["company_name"].(string); ok && companyName != "" {
|
||||
hasCompanyNameFilter = true
|
||||
}
|
||||
}
|
||||
|
||||
var query *gorm.DB
|
||||
if hasCompanyNameFilter {
|
||||
// 使用 JOIN 查询以支持企业名称筛选
|
||||
query = r.GetDB(ctx).Table("recharge_records rr").
|
||||
Joins("LEFT JOIN users u ON rr.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id")
|
||||
} else {
|
||||
// 普通查询
|
||||
query = r.GetDB(ctx).Model(&entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
// 特殊处理时间范围过滤器
|
||||
if key == "start_time" {
|
||||
if startTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at >= ?", startTime)
|
||||
} else {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
}
|
||||
} else if key == "end_time" {
|
||||
if endTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at <= ?", endTime)
|
||||
} else {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
} else if key == "company_name" {
|
||||
// 处理企业名称筛选
|
||||
if companyName, ok := value.(string); ok && companyName != "" {
|
||||
query = query.Where("ei.company_name LIKE ?", "%"+companyName+"%")
|
||||
}
|
||||
} else if key == "min_amount" {
|
||||
// 处理最小金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount >= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount >= ?", amount)
|
||||
}
|
||||
}
|
||||
} else if key == "max_amount" {
|
||||
// 处理最大金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount <= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount <= ?", amount)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 其他过滤器使用等值查询
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr."+key+" = ?", value)
|
||||
} else {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.user_id LIKE ? OR rr.transfer_order_id LIKE ? OR rr.alipay_order_id LIKE ? OR rr.wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
} else {
|
||||
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ? OR wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
}
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
|
||||
// 检查是否有 company_name 筛选,如果有则需要 JOIN 表
|
||||
hasCompanyNameFilter := false
|
||||
if options.Filters != nil {
|
||||
if companyName, ok := options.Filters["company_name"].(string); ok && companyName != "" {
|
||||
hasCompanyNameFilter = true
|
||||
}
|
||||
}
|
||||
|
||||
var query *gorm.DB
|
||||
if hasCompanyNameFilter {
|
||||
// 使用 JOIN 查询以支持企业名称筛选
|
||||
query = r.GetDB(ctx).Table("recharge_records rr").
|
||||
Select("rr.*").
|
||||
Joins("LEFT JOIN users u ON rr.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id")
|
||||
} else {
|
||||
// 普通查询
|
||||
query = r.GetDB(ctx).Model(&entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
// 特殊处理 user_ids 过滤器
|
||||
if key == "user_ids" {
|
||||
if userIds, ok := value.(string); ok && userIds != "" {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.user_id IN ?", strings.Split(userIds, ","))
|
||||
} else {
|
||||
query = query.Where("user_id IN ?", strings.Split(userIds, ","))
|
||||
}
|
||||
}
|
||||
} else if key == "company_name" {
|
||||
// 处理企业名称筛选
|
||||
if companyName, ok := value.(string); ok && companyName != "" {
|
||||
query = query.Where("ei.company_name LIKE ?", "%"+companyName+"%")
|
||||
}
|
||||
} else if key == "start_time" {
|
||||
// 处理开始时间范围
|
||||
if startTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at >= ?", startTime)
|
||||
} else {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
}
|
||||
} else if key == "end_time" {
|
||||
// 处理结束时间范围
|
||||
if endTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at <= ?", endTime)
|
||||
} else {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
} else if key == "min_amount" {
|
||||
// 处理最小金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount >= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount >= ?", amount)
|
||||
}
|
||||
}
|
||||
} else if key == "max_amount" {
|
||||
// 处理最大金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount <= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount <= ?", amount)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 其他过滤器使用等值查询
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr."+key+" = ?", value)
|
||||
} else {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.user_id LIKE ? OR rr.transfer_order_id LIKE ? OR rr.alipay_order_id LIKE ? OR rr.wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
} else {
|
||||
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ? OR wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Order("rr." + options.Sort + " " + order)
|
||||
} else {
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
} else {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Order("rr.created_at DESC")
|
||||
} else {
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) CreateBatch(ctx context.Context, records []entities.RechargeRecord) error {
|
||||
return r.GetDB(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) UpdateBatch(ctx context.Context, records []entities.RechargeRecord) error {
|
||||
return r.GetDB(ctx).Save(&records).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.RechargeRecord{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.RechargeRecord] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormRechargeRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), RechargeRecordsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserId 获取用户总充值金额(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ? AND status = ? AND recharge_type != ?", userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总充值金额(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ? AND status = ? AND recharge_type != ? AND created_at >= ? AND created_at < ?", userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetDailyStatsByUserId 获取用户每日充值统计(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE user_id = $1
|
||||
AND status = $2
|
||||
AND recharge_type != $3
|
||||
AND DATE(created_at) >= $4
|
||||
AND DATE(created_at) <= $5
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetMonthlyStatsByUserId 获取用户每月充值统计(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE user_id = $1
|
||||
AND status = $2
|
||||
AND recharge_type != $3
|
||||
AND created_at >= $4
|
||||
AND created_at <= $5
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemTotalAmount 获取系统总充值金额(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Where("status = ? AND recharge_type != ?", entities.RechargeStatusSuccess, entities.RechargeTypeGift).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemAmountByDateRange 获取系统指定时间范围内的充值金额(排除赠送)
|
||||
// endDate 应该是结束日期当天的次日00:00:00(日统计)或下个月1号00:00:00(月统计),使用 < 而不是 <=
|
||||
func (r *GormRechargeRecordRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Where("status = ? AND recharge_type != ? AND created_at >= ? AND created_at < ?", entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemDailyStats 获取系统每日充值统计(排除赠送)
|
||||
// startDate 和 endDate 应该是时间对象,endDate 应该是结束日期当天的次日00:00:00,使用 < 而不是 <=
|
||||
func (r *GormRechargeRecordRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE status = ?
|
||||
AND recharge_type != ?
|
||||
AND created_at >= ?
|
||||
AND created_at < ?
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyStats 获取系统每月充值统计(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE status = ?
|
||||
AND recharge_type != ?
|
||||
AND created_at >= ?
|
||||
AND created_at < ?
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// parseAmount 解析金额值,支持string、int、int64类型,转换为decimal.Decimal
|
||||
func (r *GormRechargeRecordRepository) parseAmount(value interface{}) (decimal.Decimal, error) {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return decimal.Zero, fmt.Errorf("empty string")
|
||||
}
|
||||
return decimal.NewFromString(v)
|
||||
case int:
|
||||
return decimal.NewFromInt(int64(v)), nil
|
||||
case int64:
|
||||
return decimal.NewFromInt(v), nil
|
||||
case float64:
|
||||
return decimal.NewFromFloat(v), nil
|
||||
case decimal.Decimal:
|
||||
return v, nil
|
||||
default:
|
||||
return decimal.Zero, fmt.Errorf("unsupported type: %T", value)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,348 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
WalletsTable = "wallets"
|
||||
)
|
||||
|
||||
type GormWalletRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
|
||||
|
||||
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
|
||||
return &GormWalletRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
|
||||
err := r.CreateEntity(ctx, &wallet)
|
||||
return wallet, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.SmartGetByID(ctx, id, &wallet)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Wallet{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Wallet{}, err
|
||||
}
|
||||
return wallet, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
|
||||
return r.UpdateEntity(ctx, &wallet)
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.Wallet{})
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
return r.GetDB(ctx).Create(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&wallets).Error
|
||||
return wallets, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
return r.GetDB(ctx).Save(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
query := r.GetDB(ctx).Model(&entities.Wallet{})
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
return wallets, query.Find(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
|
||||
var total decimal.Decimal
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ================ 接口要求的方法 ================
|
||||
|
||||
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// UpdateBalanceWithVersion 乐观锁自动重试,最大重试maxRetry次
|
||||
func (r *GormWalletRepository) UpdateBalanceWithVersion(ctx context.Context, walletID string, amount decimal.Decimal, operation string) (bool, error) {
|
||||
maxRetry := 10
|
||||
for i := 0; i < maxRetry; i++ {
|
||||
// 每次重试都重新获取最新的钱包信息
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("id = ?", walletID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("获取钱包信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 重新计算新余额
|
||||
var newBalance decimal.Decimal
|
||||
switch operation {
|
||||
case "add":
|
||||
newBalance = wallet.Balance.Add(amount)
|
||||
case "subtract":
|
||||
newBalance = wallet.Balance.Sub(amount)
|
||||
default:
|
||||
return false, fmt.Errorf("不支持的操作类型: %s", operation)
|
||||
}
|
||||
|
||||
// 乐观锁更新
|
||||
result := r.GetDB(ctx).Model(&entities.Wallet{}).
|
||||
Where("id = ? AND version = ?", walletID, wallet.Version).
|
||||
Updates(map[string]interface{}{
|
||||
"balance": newBalance.String(),
|
||||
"version": wallet.Version + 1,
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 1 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 乐观锁冲突,继续重试
|
||||
// 注意:这里可以添加日志记录,但需要确保logger可用
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
|
||||
}
|
||||
|
||||
// UpdateBalanceByUserID 乐观锁更新(通过用户ID直接更新,使用原生SQL)
|
||||
func (r *GormWalletRepository) UpdateBalanceByUserID(ctx context.Context, userID string, amount decimal.Decimal, operation string) (bool, error) {
|
||||
maxRetry := 20 // 增加重试次数
|
||||
baseDelay := 1 // 基础延迟毫秒
|
||||
|
||||
for i := 0; i < maxRetry; i++ {
|
||||
// 每次重试都重新获取最新的钱包信息
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("获取钱包信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 重新计算新余额
|
||||
var newBalance decimal.Decimal
|
||||
switch operation {
|
||||
case "add":
|
||||
newBalance = wallet.Balance.Add(amount)
|
||||
case "subtract":
|
||||
newBalance = wallet.Balance.Sub(amount)
|
||||
default:
|
||||
return false, fmt.Errorf("不支持的操作类型: %s", operation)
|
||||
}
|
||||
|
||||
// 使用原生SQL进行乐观锁更新
|
||||
newVersion := wallet.Version + 1
|
||||
result := r.GetDB(ctx).Exec(`
|
||||
UPDATE wallets
|
||||
SET balance = ?, version = ?, updated_at = NOW()
|
||||
WHERE user_id = ? AND version = ?
|
||||
`, newBalance.String(), newVersion, userID, wallet.Version)
|
||||
|
||||
if result.Error != nil {
|
||||
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 1 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 乐观锁冲突,添加指数退避延迟
|
||||
if i < maxRetry-1 {
|
||||
delay := baseDelay * (1 << i) // 指数退避: 1ms, 2ms, 4ms, 8ms...
|
||||
if delay > 50 {
|
||||
delay = 50 // 最大延迟50ms
|
||||
}
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 总钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 激活钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 用户钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户激活钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 用户今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
@@ -0,0 +1,643 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WalletTransactionWithProduct 包含产品名称的钱包交易记录
|
||||
type WalletTransactionWithProduct struct {
|
||||
entities.WalletTransaction
|
||||
ProductName string `json:"product_name" gorm:"column:product_name"`
|
||||
}
|
||||
|
||||
const (
|
||||
WalletTransactionsTable = "wallet_transactions"
|
||||
)
|
||||
|
||||
type GormWalletTransactionRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WalletTransactionRepository = (*GormWalletTransactionRepository)(nil)
|
||||
|
||||
func NewGormWalletTransactionRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletTransactionRepository {
|
||||
return &GormWalletTransactionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletTransactionsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Create(ctx context.Context, transaction entities.WalletTransaction) (entities.WalletTransaction, error) {
|
||||
err := r.CreateEntity(ctx, &transaction)
|
||||
return transaction, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Update(ctx context.Context, transaction entities.WalletTransaction) error {
|
||||
return r.UpdateEntity(ctx, &transaction)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByID(ctx context.Context, id string) (entities.WalletTransaction, error) {
|
||||
var transaction entities.WalletTransaction
|
||||
err := r.SmartGetByID(ctx, id, &transaction)
|
||||
return transaction, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.WalletTransaction, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
options := database.CacheListOptions{
|
||||
Where: "user_id = ?",
|
||||
Args: []interface{}{userID},
|
||||
Order: "created_at DESC",
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
err := r.ListWithCache(ctx, &transactions, 10*time.Minute, options)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByApiCallID(ctx context.Context, apiCallID string) (*entities.WalletTransaction, error) {
|
||||
var transaction entities.WalletTransaction
|
||||
err := r.FindOne(ctx, &transaction, "api_call_id = ?", apiCallID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &transaction, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 关键词筛选(支持transaction_id和product_name)
|
||||
if keyword, ok := filters["keyword"].(string); ok && keyword != "" {
|
||||
whereCondition += " AND (transaction_id LIKE ? OR product_id IN (SELECT id FROM product WHERE name LIKE ?))"
|
||||
whereArgs = append(whereArgs, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
// API调用ID筛选
|
||||
if apiCallId, ok := filters["api_call_id"].(string); ok && apiCallId != "" {
|
||||
whereCondition += " AND api_call_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+apiCallId+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ?", userId)
|
||||
}
|
||||
|
||||
// CountByUserIdAndDateRange 按用户ID和日期范围统计钱包交易次数
|
||||
func (r *GormWalletTransactionRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserId 获取用户总消费金额
|
||||
func (r *GormWalletTransactionRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ?", userId).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总消费金额
|
||||
func (r *GormWalletTransactionRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetDailyStatsByUserId 获取用户每日消费统计
|
||||
func (r *GormWalletTransactionRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE user_id = $1
|
||||
AND DATE(created_at) >= $2
|
||||
AND DATE(created_at) <= $3
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetMonthlyStatsByUserId 获取用户每月消费统计
|
||||
func (r *GormWalletTransactionRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE user_id = $1
|
||||
AND created_at >= $2
|
||||
AND created_at <= $3
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 实现interfaces.Repository接口的其他方法
|
||||
func (r *GormWalletTransactionRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.WalletTransaction, error) {
|
||||
var transactions []entities.WalletTransaction
|
||||
err := r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
return r.CountWithOptions(ctx, &entities.WalletTransaction{}, options)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) CreateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
|
||||
return r.CreateBatchEntity(ctx, &transactions)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.WalletTransaction, error) {
|
||||
var transactions []entities.WalletTransaction
|
||||
err := r.GetEntitiesByIDs(ctx, ids, &transactions)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) UpdateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
|
||||
return r.UpdateBatchEntity(ctx, &transactions)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.DeleteBatchEntity(ctx, ids, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) WithTx(tx interface{}) interfaces.Repository[entities.WalletTransaction] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletTransactionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletTransactionsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
|
||||
var transactionsWithProduct []*WalletTransactionWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "wt.user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND wt.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND wt.amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND wt.amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
query := r.GetDB(ctx).Table("wallet_transactions wt").
|
||||
Select("wt.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON wt.product_id = p.id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("wt." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("wt.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&transactionsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.WalletTransaction并构建产品名称映射
|
||||
var transactions []*entities.WalletTransaction
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, t := range transactionsWithProduct {
|
||||
transaction := t.WalletTransaction
|
||||
transactions = append(transactions, &transaction)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if t.ProductName != "" {
|
||||
productNameMap[transaction.ProductID] = t.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, transactions, total, nil
|
||||
}
|
||||
|
||||
// ListWithFiltersAndProductName 管理端:根据条件筛选所有钱包交易记录(包含产品名称)
|
||||
func (r *GormWalletTransactionRepository) ListWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
|
||||
var transactionsWithProduct []*WalletTransactionWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "1=1"
|
||||
whereArgs := []interface{}{}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 用户ID筛选(支持单个和多个)
|
||||
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
|
||||
// 多个用户ID,逗号分隔
|
||||
userIdsList := strings.Split(userIds, ",")
|
||||
whereCondition += " AND wt.user_id IN ?"
|
||||
whereArgs = append(whereArgs, userIdsList)
|
||||
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
|
||||
// 单个用户ID
|
||||
whereCondition += " AND wt.user_id = ?"
|
||||
whereArgs = append(whereArgs, userId)
|
||||
}
|
||||
|
||||
// 产品ID筛选(支持多个)
|
||||
if productIds, ok := filters["product_ids"].(string); ok && productIds != "" {
|
||||
// 多个产品ID,逗号分隔
|
||||
productIdsList := strings.Split(productIds, ",")
|
||||
whereCondition += " AND wt.product_id IN ?"
|
||||
whereArgs = append(whereArgs, productIdsList)
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND wt.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 企业名称筛选
|
||||
if companyName, ok := filters["company_name"].(string); ok && companyName != "" {
|
||||
whereCondition += " AND ei.company_name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+companyName+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND wt.amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND wt.amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
// 需要JOIN product表获取产品名称,JOIN users和enterprise_infos表获取企业名称
|
||||
query := r.GetDB(ctx).Table("wallet_transactions wt").
|
||||
Select("wt.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON wt.product_id = p.id").
|
||||
Joins("LEFT JOIN users u ON wt.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("wt." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("wt.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&transactionsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.WalletTransaction并构建产品名称映射
|
||||
var transactions []*entities.WalletTransaction
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, t := range transactionsWithProduct {
|
||||
transaction := t.WalletTransaction
|
||||
transactions = append(transactions, &transaction)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if t.ProductName != "" {
|
||||
productNameMap[transaction.ProductID] = t.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, transactions, total, nil
|
||||
}
|
||||
|
||||
// ExportWithFiltersAndProductName 导出钱包交易记录(包含产品名称和企业信息)
|
||||
func (r *GormWalletTransactionRepository) ExportWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}) ([]*entities.WalletTransaction, error) {
|
||||
var transactionsWithProduct []WalletTransactionWithProduct
|
||||
|
||||
// 构建查询
|
||||
query := r.GetDB(ctx).Table("wallet_transactions wt").
|
||||
Select("wt.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON wt.product_id = p.id")
|
||||
|
||||
// 构建WHERE条件
|
||||
var whereConditions []string
|
||||
var whereArgs []interface{}
|
||||
|
||||
// 用户ID筛选
|
||||
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
|
||||
whereConditions = append(whereConditions, "wt.user_id IN (?)")
|
||||
whereArgs = append(whereArgs, strings.Split(userIds, ","))
|
||||
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
|
||||
whereConditions = append(whereConditions, "wt.user_id = ?")
|
||||
whereArgs = append(whereArgs, userId)
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereConditions = append(whereConditions, "wt.created_at >= ?")
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereConditions = append(whereConditions, "wt.created_at <= ?")
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereConditions = append(whereConditions, "wt.transaction_id LIKE ?")
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereConditions = append(whereConditions, "p.name LIKE ?")
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 产品ID列表筛选
|
||||
if productIds, ok := filters["product_ids"].(string); ok && productIds != "" {
|
||||
whereConditions = append(whereConditions, "wt.product_id IN (?)")
|
||||
whereArgs = append(whereArgs, strings.Split(productIds, ","))
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereConditions = append(whereConditions, "wt.amount >= ?")
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereConditions = append(whereConditions, "wt.amount <= ?")
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
|
||||
// 应用WHERE条件
|
||||
if len(whereConditions) > 0 {
|
||||
query = query.Where(strings.Join(whereConditions, " AND "), whereArgs...)
|
||||
}
|
||||
|
||||
// 排序
|
||||
query = query.Order("wt.created_at DESC")
|
||||
|
||||
// 执行查询
|
||||
err := query.Find(&transactionsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为entities.WalletTransaction
|
||||
var transactions []*entities.WalletTransaction
|
||||
for _, t := range transactionsWithProduct {
|
||||
transaction := t.WalletTransaction
|
||||
transactions = append(transactions, &transaction)
|
||||
}
|
||||
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// GetSystemTotalAmount 获取系统总消费金额
|
||||
func (r *GormWalletTransactionRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemAmountByDateRange 获取系统指定时间范围内的消费金额
|
||||
// endDate 应该是结束日期当天的次日00:00:00(日统计)或下个月1号00:00:00(月统计),使用 < 而不是 <=
|
||||
func (r *GormWalletTransactionRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Where("created_at >= ? AND created_at < ?", startDate, endDate).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemDailyStats 获取系统每日消费统计
|
||||
func (r *GormWalletTransactionRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE DATE(created_at) >= ?
|
||||
AND DATE(created_at) <= ?
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyStats 获取系统每月消费统计
|
||||
func (r *GormWalletTransactionRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE created_at >= ?
|
||||
AND created_at < ?
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
WechatOrdersTable = "typay_orders"
|
||||
)
|
||||
|
||||
type GormWechatOrderRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WechatOrderRepository = (*GormWechatOrderRepository)(nil)
|
||||
|
||||
func NewGormWechatOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WechatOrderRepository {
|
||||
return &GormWechatOrderRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WechatOrdersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Create(ctx context.Context, order entities.WechatOrder) (entities.WechatOrder, error) {
|
||||
err := r.CreateEntity(ctx, &order)
|
||||
return order, err
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByID(ctx context.Context, id string) (entities.WechatOrder, error) {
|
||||
var order entities.WechatOrder
|
||||
err := r.SmartGetByID(ctx, id, &order)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.WechatOrder{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.WechatOrder{}, err
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.WechatOrder, error) {
|
||||
var order entities.WechatOrder
|
||||
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.WechatOrder, error) {
|
||||
var order entities.WechatOrder
|
||||
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.WechatOrder, error) {
|
||||
var orders []entities.WechatOrder
|
||||
// 需要通过充值记录关联查询,这里简化处理
|
||||
err := r.GetDB(ctx).Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Update(ctx context.Context, order entities.WechatOrder) error {
|
||||
return r.UpdateEntity(ctx, &order)
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.WechatOrderStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.WechatOrder{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.WechatOrder{})
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.WechatOrder{})
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
"hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/domains/finance/value_objects"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormInvoiceApplicationRepository 发票申请仓储的GORM实现
|
||||
type GormInvoiceApplicationRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormInvoiceApplicationRepository 创建发票申请仓储
|
||||
func NewGormInvoiceApplicationRepository(db *gorm.DB) repositories.InvoiceApplicationRepository {
|
||||
return &GormInvoiceApplicationRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建发票申请
|
||||
func (r *GormInvoiceApplicationRepository) Create(ctx context.Context, application *entities.InvoiceApplication) error {
|
||||
return r.db.WithContext(ctx).Create(application).Error
|
||||
}
|
||||
|
||||
// Update 更新发票申请
|
||||
func (r *GormInvoiceApplicationRepository) Update(ctx context.Context, application *entities.InvoiceApplication) error {
|
||||
return r.db.WithContext(ctx).Save(application).Error
|
||||
}
|
||||
|
||||
// Save 保存发票申请
|
||||
func (r *GormInvoiceApplicationRepository) Save(ctx context.Context, application *entities.InvoiceApplication) error {
|
||||
return r.db.WithContext(ctx).Save(application).Error
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByID(ctx context.Context, id string) (*entities.InvoiceApplication, error) {
|
||||
var application entities.InvoiceApplication
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&application).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &application, nil
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找发票申请列表
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
// 获取总数
|
||||
err := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = r.db.WithContext(ctx).Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindPendingApplications 查找待处理的发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindPendingApplications(ctx context.Context, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
// 获取总数
|
||||
err := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).
|
||||
Where("status = ?", entities.ApplicationStatusPending).
|
||||
Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = r.db.WithContext(ctx).
|
||||
Where("status = ?", entities.ApplicationStatusPending).
|
||||
Order("created_at ASC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByUserIDAndStatus 根据用户ID和状态查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserIDAndStatus(ctx context.Context, userID string, status entities.ApplicationStatus, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID)
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByUserIDAndStatusWithTimeRange 根据用户ID、状态和时间范围查找发票申请列表
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserIDAndStatusWithTimeRange(ctx context.Context, userID string, status entities.ApplicationStatus, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID)
|
||||
|
||||
// 添加状态筛选
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 添加时间范围筛选
|
||||
if startTime != nil {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime != nil {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByStatus(ctx context.Context, status entities.ApplicationStatus) ([]*entities.InvoiceApplication, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&applications).Error
|
||||
return applications, err
|
||||
}
|
||||
|
||||
// GetUserInvoiceInfo 获取用户发票信息
|
||||
|
||||
|
||||
|
||||
|
||||
// GetUserTotalInvoicedAmount 获取用户已开票总金额
|
||||
func (r *GormInvoiceApplicationRepository) GetUserTotalInvoicedAmount(ctx context.Context, userID string) (string, error) {
|
||||
var total string
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&entities.InvoiceApplication{}).
|
||||
Select("COALESCE(SUM(CAST(amount AS DECIMAL(10,2))), '0')").
|
||||
Where("user_id = ? AND status = ?", userID, entities.ApplicationStatusCompleted).
|
||||
Scan(&total).Error
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetUserTotalAppliedAmount 获取用户申请开票总金额
|
||||
func (r *GormInvoiceApplicationRepository) GetUserTotalAppliedAmount(ctx context.Context, userID string) (string, error) {
|
||||
var total string
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&entities.InvoiceApplication{}).
|
||||
Select("COALESCE(SUM(CAST(amount AS DECIMAL(10,2))), '0')").
|
||||
Where("user_id = ?", userID).
|
||||
Scan(&total).Error
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
// FindByUserIDAndInvoiceType 根据用户ID和发票类型查找申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserIDAndInvoiceType(ctx context.Context, userID string, invoiceType value_objects.InvoiceType, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ? AND invoice_type = ?", userID, invoiceType)
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByDateRange 根据日期范围查找申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByDateRange(ctx context.Context, startDate, endDate string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{})
|
||||
if startDate != "" {
|
||||
query = query.Where("DATE(created_at) >= ?", startDate)
|
||||
}
|
||||
if endDate != "" {
|
||||
query = query.Where("DATE(created_at) <= ?", endDate)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// SearchApplications 搜索发票申请
|
||||
func (r *GormInvoiceApplicationRepository) SearchApplications(ctx context.Context, keyword string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).
|
||||
Where("company_name LIKE ? OR email LIKE ? OR tax_number LIKE ?",
|
||||
fmt.Sprintf("%%%s%%", keyword),
|
||||
fmt.Sprintf("%%%s%%", keyword),
|
||||
fmt.Sprintf("%%%s%%", keyword))
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByStatusWithTimeRange 根据状态和时间范围查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByStatusWithTimeRange(ctx context.Context, status entities.ApplicationStatus, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("status = ?", status)
|
||||
|
||||
// 添加时间范围筛选
|
||||
if startTime != nil {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime != nil {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindAllWithTimeRange 根据时间范围查找所有发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindAllWithTimeRange(ctx context.Context, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{})
|
||||
|
||||
// 添加时间范围筛选
|
||||
if startTime != nil {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime != nil {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
"hyapi-server/internal/domains/finance/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormUserInvoiceInfoRepository 用户开票信息仓储的GORM实现
|
||||
type GormUserInvoiceInfoRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormUserInvoiceInfoRepository 创建用户开票信息仓储
|
||||
func NewGormUserInvoiceInfoRepository(db *gorm.DB) repositories.UserInvoiceInfoRepository {
|
||||
return &GormUserInvoiceInfoRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建用户开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) Create(ctx context.Context, info *entities.UserInvoiceInfo) error {
|
||||
return r.db.WithContext(ctx).Create(info).Error
|
||||
}
|
||||
|
||||
// Update 更新用户开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) Update(ctx context.Context, info *entities.UserInvoiceInfo) error {
|
||||
return r.db.WithContext(ctx).Save(info).Error
|
||||
}
|
||||
|
||||
// Save 保存用户开票信息(创建或更新)
|
||||
func (r *GormUserInvoiceInfoRepository) Save(ctx context.Context, info *entities.UserInvoiceInfo) error {
|
||||
return r.db.WithContext(ctx).Save(info).Error
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) FindByUserID(ctx context.Context, userID string) (*entities.UserInvoiceInfo, error) {
|
||||
var info entities.UserInvoiceInfo
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&info).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) FindByID(ctx context.Context, id string) (*entities.UserInvoiceInfo, error) {
|
||||
var info entities.UserInvoiceInfo
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&info).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// Delete 删除用户开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) Delete(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&entities.UserInvoiceInfo{}).Error
|
||||
}
|
||||
|
||||
// Exists 检查用户开票信息是否存在
|
||||
func (r *GormUserInvoiceInfoRepository) Exists(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserInvoiceInfo{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ComponentReportDownloadsTable = "component_report_downloads"
|
||||
)
|
||||
|
||||
type GormComponentReportRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ComponentReportRepository = (*GormComponentReportRepository)(nil)
|
||||
|
||||
func NewGormComponentReportRepository(db *gorm.DB, logger *zap.Logger) repositories.ComponentReportRepository {
|
||||
return &GormComponentReportRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ComponentReportDownloadsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) Create(ctx context.Context, download *entities.ComponentReportDownload) error {
|
||||
return r.CreateEntity(ctx, download)
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) UpdateDownload(ctx context.Context, download *entities.ComponentReportDownload) error {
|
||||
return r.UpdateEntity(ctx, download)
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetDownloadByID(ctx context.Context, id string) (*entities.ComponentReportDownload, error) {
|
||||
var download entities.ComponentReportDownload
|
||||
err := r.SmartGetByID(ctx, id, &download)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &download, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetUserDownloads(ctx context.Context, userID string, productID *string) ([]*entities.ComponentReportDownload, error) {
|
||||
var downloads []entities.ComponentReportDownload
|
||||
query := r.GetDB(ctx).Where("user_id = ?", userID)
|
||||
|
||||
if productID != nil && *productID != "" {
|
||||
query = query.Where("product_id = ?", *productID)
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&downloads).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ComponentReportDownload, len(downloads))
|
||||
for i := range downloads {
|
||||
result[i] = &downloads[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) HasUserDownloaded(ctx context.Context, userID string, productCode string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).
|
||||
Where("user_id = ? AND product_code = ?", userID, productCode).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetUserDownloadedProductCodes(ctx context.Context, userID string) ([]string, error) {
|
||||
var downloads []entities.ComponentReportDownload
|
||||
err := r.GetDB(ctx).
|
||||
Select("DISTINCT sub_product_codes").
|
||||
Where("user_id = ?", userID).
|
||||
Find(&downloads).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
codesMap := make(map[string]bool)
|
||||
for _, download := range downloads {
|
||||
if download.SubProductCodes != "" {
|
||||
var codes []string
|
||||
if err := json.Unmarshal([]byte(download.SubProductCodes), &codes); err == nil {
|
||||
for _, code := range codes {
|
||||
codesMap[code] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 也添加主产品编号
|
||||
if download.ProductCode != "" {
|
||||
codesMap[download.ProductCode] = true
|
||||
}
|
||||
}
|
||||
|
||||
codes := make([]string, 0, len(codesMap))
|
||||
for code := range codesMap {
|
||||
codes = append(codes, code)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetDownloadByPaymentOrderID(ctx context.Context, orderID string) (*entities.ComponentReportDownload, error) {
|
||||
var download entities.ComponentReportDownload
|
||||
err := r.GetDB(ctx).Where("order_id = ?", orderID).First(&download).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &download, nil
|
||||
}
|
||||
|
||||
// GetActiveDownload 获取用户有效的下载记录
|
||||
func (r *GormComponentReportRepository) GetActiveDownload(ctx context.Context, userID, productID string) (*entities.ComponentReportDownload, error) {
|
||||
var download entities.ComponentReportDownload
|
||||
|
||||
// 先尝试查找有支付订单号的下载记录(已支付)
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND product_id = ? AND order_number IS NOT NULL AND deleted_at IS NULL", userID, productID).
|
||||
Order("created_at DESC").
|
||||
First(&download).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 如果没有找到有支付订单号的记录,尝试查找任何有效的下载记录
|
||||
err = r.GetDB(ctx).
|
||||
Where("user_id = ? AND product_id = ? AND deleted_at IS NULL", userID, productID).
|
||||
Order("created_at DESC").
|
||||
First(&download).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到了下载记录,检查关联的购买订单状态
|
||||
if download.OrderID != nil {
|
||||
// 这里需要查询购买订单状态,但当前仓库没有依赖购买订单仓库
|
||||
// 所以只检查是否有过期时间设置,如果有则认为已支付
|
||||
if download.ExpiresAt == nil {
|
||||
return nil, nil // 没有过期时间,表示未支付
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否已过期
|
||||
if download.IsExpired() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &download, nil
|
||||
}
|
||||
|
||||
// UpdateFilePath 更新下载记录文件路径
|
||||
func (r *GormComponentReportRepository) UpdateFilePath(ctx context.Context, downloadID, filePath string) error {
|
||||
return r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).Where("id = ?", downloadID).Update("file_path", filePath).Error
|
||||
}
|
||||
|
||||
// IncrementDownloadCount 增加下载次数
|
||||
func (r *GormComponentReportRepository) IncrementDownloadCount(ctx context.Context, downloadID string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).
|
||||
Where("id = ?", downloadID).
|
||||
Updates(map[string]interface{}{
|
||||
"download_count": gorm.Expr("download_count + 1"),
|
||||
"last_download_at": &now,
|
||||
}).Error
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductApiConfigsTable = "product_api_configs"
|
||||
ProductApiConfigCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
type GormProductApiConfigRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ProductApiConfigRepository = (*GormProductApiConfigRepository)(nil)
|
||||
|
||||
func NewGormProductApiConfigRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductApiConfigRepository {
|
||||
return &GormProductApiConfigRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductApiConfigsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Create(ctx context.Context, config entities.ProductApiConfig) error {
|
||||
return r.CreateEntity(ctx, &config)
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Update(ctx context.Context, config entities.ProductApiConfig) error {
|
||||
return r.UpdateEntity(ctx, &config)
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductApiConfig{})
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) GetByID(ctx context.Context, id string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.SmartGetByID(ctx, id, &config)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.SmartGetByField(ctx, &config, "product_id", productID, ProductApiConfigCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductCode(ctx context.Context, productCode string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.GetDB(ctx).Joins("JOIN products ON products.id = product_api_configs.product_id").
|
||||
Where("products.code = ?", productCode).
|
||||
First(&config).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductApiConfig, error) {
|
||||
var configs []*entities.ProductApiConfig
|
||||
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&configs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) ExistsByProductID(ctx context.Context, productID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductApiConfig{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/domains/product/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductCategoriesTable = "product_categories"
|
||||
)
|
||||
|
||||
type GormProductCategoryRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductCategory{})
|
||||
}
|
||||
|
||||
var _ repositories.ProductCategoryRepository = (*GormProductCategoryRepository)(nil)
|
||||
|
||||
func NewGormProductCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductCategoryRepository {
|
||||
return &GormProductCategoryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductCategoriesTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) Create(ctx context.Context, entity entities.ProductCategory) (entities.ProductCategory, error) {
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) GetByID(ctx context.Context, id string) (entities.ProductCategory, error) {
|
||||
var entity entities.ProductCategory
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.ProductCategory{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.ProductCategory{}, err
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) Update(ctx context.Context, entity entities.ProductCategory) error {
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// FindByCode 根据编号查找产品分类
|
||||
func (r *GormProductCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductCategory, error) {
|
||||
var entity entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindVisible 查找可见分类
|
||||
func (r *GormProductCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindEnabled 查找启用分类
|
||||
func (r *GormProductCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("is_enabled = ?", true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListCategories 获取分类列表
|
||||
func (r *GormProductCategoryRepository) ListCategories(ctx context.Context, query *queries.ListCategoriesQuery) ([]*entities.ProductCategory, int64, error) {
|
||||
var categories []entities.ProductCategory
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.IsEnabled != nil {
|
||||
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
|
||||
}
|
||||
if query.IsVisible != nil {
|
||||
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
// 默认按排序字段和创建时间排序
|
||||
dbQuery = dbQuery.Order("sort ASC, created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&categories).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// CountEnabled 统计启用分类数量
|
||||
func (r *GormProductCategoryRepository) CountEnabled(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountVisible 统计可见分类数量
|
||||
func (r *GormProductCategoryRepository) CountVisible(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// 基础Repository接口方法
|
||||
|
||||
// Count 返回分类总数
|
||||
func (r *GormProductCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取分类
|
||||
func (r *GormProductCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建分类
|
||||
func (r *GormProductCategoryRepository) CreateBatch(ctx context.Context, categories []entities.ProductCategory) error {
|
||||
return r.GetDB(ctx).Create(&categories).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新分类
|
||||
func (r *GormProductCategoryRepository) UpdateBatch(ctx context.Context, categories []entities.ProductCategory) error {
|
||||
return r.GetDB(ctx).Save(&categories).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除分类
|
||||
func (r *GormProductCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取分类列表(基础方法)
|
||||
func (r *GormProductCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按排序字段和创建时间倒序
|
||||
query = query.Order("sort ASC, created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
// Exists 检查分类是否存在
|
||||
func (r *GormProductCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除分类
|
||||
func (r *GormProductCategoryRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的分类
|
||||
func (r *GormProductCategoryRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.ProductCategory{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormProductCategoryRepository) WithTx(tx interface{}) interfaces.Repository[entities.ProductCategory] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormProductCategoryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductCategoriesTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductDocumentationsTable = "product_documentations"
|
||||
)
|
||||
|
||||
type GormProductDocumentationRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductDocumentationRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductDocumentation{})
|
||||
}
|
||||
|
||||
var _ repositories.ProductDocumentationRepository = (*GormProductDocumentationRepository)(nil)
|
||||
|
||||
func NewGormProductDocumentationRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductDocumentationRepository {
|
||||
return &GormProductDocumentationRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductDocumentationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建文档
|
||||
func (r *GormProductDocumentationRepository) Create(ctx context.Context, documentation *entities.ProductDocumentation) error {
|
||||
return r.CreateEntity(ctx, documentation)
|
||||
}
|
||||
|
||||
// Update 更新文档
|
||||
func (r *GormProductDocumentationRepository) Update(ctx context.Context, documentation *entities.ProductDocumentation) error {
|
||||
return r.UpdateEntity(ctx, documentation)
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找文档
|
||||
func (r *GormProductDocumentationRepository) FindByID(ctx context.Context, id string) (*entities.ProductDocumentation, error) {
|
||||
var entity entities.ProductDocumentation
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByProductID 根据产品ID查找文档
|
||||
func (r *GormProductDocumentationRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductDocumentation, error) {
|
||||
var entity entities.ProductDocumentation
|
||||
err := r.GetDB(ctx).Where("product_id = ?", productID).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByProductIDs 根据产品ID列表批量查找文档
|
||||
func (r *GormProductDocumentationRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductDocumentation, error) {
|
||||
var documentations []entities.ProductDocumentation
|
||||
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&documentations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductDocumentation, len(documentations))
|
||||
for i := range documentations {
|
||||
result[i] = &documentations[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新文档
|
||||
func (r *GormProductDocumentationRepository) UpdateBatch(ctx context.Context, documentations []*entities.ProductDocumentation) error {
|
||||
if len(documentations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用事务进行批量更新
|
||||
return r.GetDB(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for _, doc := range documentations {
|
||||
if err := tx.Save(doc).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CountByProductID 统计指定产品的文档数量
|
||||
func (r *GormProductDocumentationRepository) CountByProductID(ctx context.Context, productID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductDocumentation{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,521 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/domains/product/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductsTable = "products"
|
||||
)
|
||||
|
||||
type GormProductRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Product{})
|
||||
}
|
||||
|
||||
var _ repositories.ProductRepository = (*GormProductRepository)(nil)
|
||||
|
||||
func NewGormProductRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductRepository {
|
||||
return &GormProductRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) Create(ctx context.Context, entity entities.Product) (entities.Product, error) {
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) GetByID(ctx context.Context, id string) (entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Product{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Product{}, err
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) Update(ctx context.Context, entity entities.Product) error {
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// 其它方法同理迁移,全部用r.GetDB(ctx)
|
||||
|
||||
// FindByCode 根据编号查找产品
|
||||
func (r *GormProductRepository) FindByCode(ctx context.Context, code string) (*entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.SmartGetByField(ctx, &entity, "code", code) // 自动缓存
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByOldID 根据旧ID查找产品
|
||||
func (r *GormProductRepository) FindByOldID(ctx context.Context, oldID string) (*entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.GetDB(ctx).Where("old_id = ?", oldID).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByCategoryID 根据分类ID查找产品
|
||||
func (r *GormProductRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Preload("Category").Where("category_id = ?", categoryID).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindVisible 查找可见产品
|
||||
func (r *GormProductRepository) FindVisible(ctx context.Context) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Preload("Category").Where("is_visible = ? AND is_enabled = ?", true, true).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindEnabled 查找启用产品
|
||||
func (r *GormProductRepository) FindEnabled(ctx context.Context) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Preload("Category").Where("is_enabled = ?", true).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListProducts 获取产品列表
|
||||
func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, int64, error) {
|
||||
var productEntities []entities.Product
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Keyword != "" {
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
|
||||
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
if query.CategoryID != "" {
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
if query.MinPrice != nil {
|
||||
dbQuery = dbQuery.Where("price >= ?", *query.MinPrice)
|
||||
}
|
||||
if query.MaxPrice != nil {
|
||||
dbQuery = dbQuery.Where("price <= ?", *query.MaxPrice)
|
||||
}
|
||||
if query.IsEnabled != nil {
|
||||
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
|
||||
}
|
||||
if query.IsVisible != nil {
|
||||
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
|
||||
}
|
||||
if query.IsPackage != nil {
|
||||
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载分类信息并获取数据
|
||||
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// ListProductsWithSubscriptionStatus 获取产品列表(包含订阅状态)
|
||||
func (r *GormProductRepository) ListProductsWithSubscriptionStatus(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, map[string]bool, int64, error) {
|
||||
var productEntities []entities.Product
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Keyword != "" {
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
|
||||
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
if query.CategoryID != "" {
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
if query.MinPrice != nil {
|
||||
dbQuery = dbQuery.Where("price >= ?", *query.MinPrice)
|
||||
}
|
||||
if query.MaxPrice != nil {
|
||||
dbQuery = dbQuery.Where("price <= ?", *query.MaxPrice)
|
||||
}
|
||||
if query.IsEnabled != nil {
|
||||
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
|
||||
}
|
||||
if query.IsVisible != nil {
|
||||
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
|
||||
}
|
||||
if query.IsPackage != nil {
|
||||
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
|
||||
}
|
||||
|
||||
// 如果指定了用户ID,添加订阅状态筛选
|
||||
if query.UserID != "" && query.IsSubscribed != nil {
|
||||
if *query.IsSubscribed {
|
||||
// 筛选已订阅的产品
|
||||
dbQuery = dbQuery.Where("EXISTS (SELECT 1 FROM subscription WHERE subscription.product_id = product.id AND subscription.user_id = ?)", query.UserID)
|
||||
} else {
|
||||
// 筛选未订阅的产品
|
||||
dbQuery = dbQuery.Where("NOT EXISTS (SELECT 1 FROM subscription WHERE subscription.product_id = product.id AND subscription.user_id = ?)", query.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载分类信息并获取数据
|
||||
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
|
||||
// 获取订阅状态映射
|
||||
subscriptionStatusMap := make(map[string]bool)
|
||||
if query.UserID != "" && len(result) > 0 {
|
||||
productIDs := make([]string, len(result))
|
||||
for i, product := range result {
|
||||
productIDs[i] = product.ID
|
||||
}
|
||||
|
||||
// 查询用户的订阅状态
|
||||
var subscriptions []struct {
|
||||
ProductID string `gorm:"column:product_id"`
|
||||
}
|
||||
err := r.GetDB(ctx).Table("subscription").
|
||||
Select("product_id").
|
||||
Where("user_id = ? AND product_id IN ?", query.UserID, productIDs).
|
||||
Find(&subscriptions).Error
|
||||
|
||||
if err == nil {
|
||||
for _, sub := range subscriptions {
|
||||
subscriptionStatusMap[sub.ProductID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, subscriptionStatusMap, total, nil
|
||||
}
|
||||
|
||||
// FindSubscribableProducts 查找可订阅产品
|
||||
func (r *GormProductRepository) FindSubscribableProducts(ctx context.Context, userID string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Where("is_enabled = ? AND is_visible = ?", true, true).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindProductsByIDs 根据ID列表查找产品
|
||||
func (r *GormProductRepository) FindProductsByIDs(ctx context.Context, ids []string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CountByCategory 统计分类下的产品数量
|
||||
func (r *GormProductRepository) CountByCategory(ctx context.Context, categoryID string) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
if categoryID != "" {
|
||||
query = query.Where("category_id = ?", categoryID)
|
||||
}
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountEnabled 统计启用产品数量
|
||||
func (r *GormProductRepository) CountEnabled(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountVisible 统计可见产品数量
|
||||
func (r *GormProductRepository) CountVisible(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Count 返回产品总数
|
||||
func (r *GormProductRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取产品
|
||||
func (r *GormProductRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Product, error) {
|
||||
var products []entities.Product
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建产品
|
||||
func (r *GormProductRepository) CreateBatch(ctx context.Context, products []entities.Product) error {
|
||||
return r.GetDB(ctx).Create(&products).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新产品
|
||||
func (r *GormProductRepository) UpdateBatch(ctx context.Context, products []entities.Product) error {
|
||||
return r.GetDB(ctx).Save(&products).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除产品
|
||||
func (r *GormProductRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.Product{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取产品列表(基础方法)
|
||||
func (r *GormProductRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Product, error) {
|
||||
var products []entities.Product
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// Exists 检查产品是否存在
|
||||
func (r *GormProductRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除产品
|
||||
func (r *GormProductRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.Product{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的产品
|
||||
func (r *GormProductRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.Product{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// GetPackageItems 获取组合包项目
|
||||
func (r *GormProductRepository) GetPackageItems(ctx context.Context, packageID string) ([]*entities.ProductPackageItem, error) {
|
||||
var packageItems []entities.ProductPackageItem
|
||||
err := r.GetDB(ctx).
|
||||
Preload("Product").
|
||||
Where("package_id = ?", packageID).
|
||||
Order("sort_order ASC").
|
||||
Find(&packageItems).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductPackageItem, len(packageItems))
|
||||
for i := range packageItems {
|
||||
result[i] = &packageItems[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CreatePackageItem 创建组合包项目
|
||||
func (r *GormProductRepository) CreatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
|
||||
return r.GetDB(ctx).Create(packageItem).Error
|
||||
}
|
||||
|
||||
// GetPackageItemByID 根据ID获取组合包项目
|
||||
func (r *GormProductRepository) GetPackageItemByID(ctx context.Context, itemID string) (*entities.ProductPackageItem, error) {
|
||||
var packageItem entities.ProductPackageItem
|
||||
err := r.GetDB(ctx).
|
||||
Preload("Product").
|
||||
Preload("Package").
|
||||
Where("id = ?", itemID).
|
||||
First(&packageItem).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &packageItem, nil
|
||||
}
|
||||
|
||||
// UpdatePackageItem 更新组合包项目
|
||||
func (r *GormProductRepository) UpdatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
|
||||
return r.GetDB(ctx).Save(packageItem).Error
|
||||
}
|
||||
|
||||
// DeletePackageItem 删除组合包项目(硬删除)
|
||||
func (r *GormProductRepository) DeletePackageItem(ctx context.Context, itemID string) error {
|
||||
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "id = ?", itemID).Error
|
||||
}
|
||||
|
||||
// DeletePackageItemsByPackageID 根据组合包ID删除所有子产品(硬删除)
|
||||
func (r *GormProductRepository) DeletePackageItemsByPackageID(ctx context.Context, packageID string) error {
|
||||
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "package_id = ?", packageID).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormProductRepository) WithTx(tx interface{}) interfaces.Repository[entities.Product] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormProductRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductSubCategoriesTable = "product_sub_categories"
|
||||
)
|
||||
|
||||
type GormProductSubCategoryRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ProductSubCategoryRepository = (*GormProductSubCategoryRepository)(nil)
|
||||
|
||||
func NewGormProductSubCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductSubCategoryRepository {
|
||||
return &GormProductSubCategoryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductSubCategoriesTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建二级分类
|
||||
func (r *GormProductSubCategoryRepository) Create(ctx context.Context, category entities.ProductSubCategory) (*entities.ProductSubCategory, error) {
|
||||
err := r.CreateEntity(ctx, &category)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &category, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取二级分类
|
||||
func (r *GormProductSubCategoryRepository) GetByID(ctx context.Context, id string) (*entities.ProductSubCategory, error) {
|
||||
var entity entities.ProductSubCategory
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// Update 更新二级分类
|
||||
func (r *GormProductSubCategoryRepository) Update(ctx context.Context, category entities.ProductSubCategory) error {
|
||||
return r.UpdateEntity(ctx, &category)
|
||||
}
|
||||
|
||||
// Delete 删除二级分类
|
||||
func (r *GormProductSubCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductSubCategory{})
|
||||
}
|
||||
|
||||
// List 获取所有二级分类
|
||||
func (r *GormProductSubCategoryRepository) List(ctx context.Context) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByCode 根据编号查找二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductSubCategory, error) {
|
||||
var entity entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByCategoryID 根据一级分类ID查找二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("category_id = ?", categoryID).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindVisible 查找可见的二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindEnabled 查找启用的二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("is_enabled = ?", true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormProductUIComponentRepository 产品UI组件关联仓储实现
|
||||
type GormProductUIComponentRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormProductUIComponentRepository 创建产品UI组件关联仓储实例
|
||||
func NewGormProductUIComponentRepository(db *gorm.DB) repositories.ProductUIComponentRepository {
|
||||
return &GormProductUIComponentRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建产品UI组件关联
|
||||
func (r *GormProductUIComponentRepository) Create(ctx context.Context, relation entities.ProductUIComponent) (entities.ProductUIComponent, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&relation).Error; err != nil {
|
||||
return entities.ProductUIComponent{}, fmt.Errorf("创建产品UI组件关联失败: %w", err)
|
||||
}
|
||||
return relation, nil
|
||||
}
|
||||
|
||||
// GetByProductID 根据产品ID获取UI组件关联列表
|
||||
func (r *GormProductUIComponentRepository) GetByProductID(ctx context.Context, productID string) ([]entities.ProductUIComponent, error) {
|
||||
var relations []entities.ProductUIComponent
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("UIComponent").
|
||||
Where("product_id = ?", productID).
|
||||
Find(&relations).Error; err != nil {
|
||||
return nil, fmt.Errorf("获取产品UI组件关联列表失败: %w", err)
|
||||
}
|
||||
return relations, nil
|
||||
}
|
||||
|
||||
// GetByUIComponentID 根据UI组件ID获取产品关联列表
|
||||
func (r *GormProductUIComponentRepository) GetByUIComponentID(ctx context.Context, componentID string) ([]entities.ProductUIComponent, error) {
|
||||
var relations []entities.ProductUIComponent
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("Product").
|
||||
Where("ui_component_id = ?", componentID).
|
||||
Find(&relations).Error; err != nil {
|
||||
return nil, fmt.Errorf("获取UI组件产品关联列表失败: %w", err)
|
||||
}
|
||||
return relations, nil
|
||||
}
|
||||
|
||||
// Delete 删除产品UI组件关联
|
||||
func (r *GormProductUIComponentRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.ProductUIComponent{}, id).Error; err != nil {
|
||||
return fmt.Errorf("删除产品UI组件关联失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByProductID 根据产品ID删除所有关联
|
||||
func (r *GormProductUIComponentRepository) DeleteByProductID(ctx context.Context, productID string) error {
|
||||
if err := r.db.WithContext(ctx).Where("product_id = ?", productID).Delete(&entities.ProductUIComponent{}).Error; err != nil {
|
||||
return fmt.Errorf("根据产品ID删除UI组件关联失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建产品UI组件关联
|
||||
func (r *GormProductUIComponentRepository) BatchCreate(ctx context.Context, relations []entities.ProductUIComponent) error {
|
||||
if len(relations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).CreateInBatches(relations, 100).Error; err != nil {
|
||||
return fmt.Errorf("批量创建产品UI组件关联失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/domains/product/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
SubscriptionsTable = "subscription"
|
||||
SubscriptionCacheTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
type GormSubscriptionRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Subscription{})
|
||||
}
|
||||
|
||||
var _ repositories.SubscriptionRepository = (*GormSubscriptionRepository)(nil)
|
||||
|
||||
func NewGormSubscriptionRepository(db *gorm.DB, logger *zap.Logger) repositories.SubscriptionRepository {
|
||||
return &GormSubscriptionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SubscriptionsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) Create(ctx context.Context, entity entities.Subscription) (entities.Subscription, error) {
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) GetByID(ctx context.Context, id string) (entities.Subscription, error) {
|
||||
var entity entities.Subscription
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Subscription{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Subscription{}, err
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) Update(ctx context.Context, entity entities.Subscription) error {
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&subscriptions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Subscription, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
result[i] = &subscriptions[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByProductID 根据产品ID查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByProductID(ctx context.Context, productID string) ([]*entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("product_id = ?", productID).Order("created_at DESC").Find(&subscriptions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Subscription, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
result[i] = &subscriptions[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByUserAndProduct 根据用户和产品查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByUserAndProduct(ctx context.Context, userID, productID string) (*entities.Subscription, error) {
|
||||
var entity entities.Subscription
|
||||
// 组合缓存key的条件
|
||||
where := "user_id = ? AND product_id = ?"
|
||||
ttl := SubscriptionCacheTTL // 缓存10分钟,可根据业务调整
|
||||
err := r.GetWithCache(ctx, &entity, ttl, where, userID, productID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// ListSubscriptions 获取订阅列表
|
||||
func (r *GormSubscriptionRepository) ListSubscriptions(ctx context.Context, query *queries.ListSubscriptionsQuery) ([]*entities.Subscription, int64, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("subscription.user_id = ?", query.UserID)
|
||||
}
|
||||
|
||||
// 关键词搜索(产品名称或编码)
|
||||
if query.Keyword != "" {
|
||||
dbQuery = dbQuery.Joins("LEFT JOIN product ON product.id = subscription.product_id").
|
||||
Where("product.name LIKE ? OR product.code LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if query.ProductName != "" {
|
||||
dbQuery = dbQuery.Joins("LEFT JOIN product ON product.id = subscription.product_id").
|
||||
Where("product.name LIKE ?", "%"+query.ProductName+"%")
|
||||
}
|
||||
|
||||
// 企业名称筛选(需要关联用户和企业信息)
|
||||
if query.CompanyName != "" {
|
||||
dbQuery = dbQuery.Joins("LEFT JOIN users ON users.id = subscription.user_id").
|
||||
Joins("LEFT JOIN enterprise_infos ON enterprise_infos.user_id = users.id").
|
||||
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if query.StartTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", query.StartTime); err == nil {
|
||||
dbQuery = dbQuery.Where("subscription.created_at >= ?", t)
|
||||
}
|
||||
}
|
||||
if query.EndTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", query.EndTime); err == nil {
|
||||
dbQuery = dbQuery.Where("subscription.created_at <= ?", t)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("subscription.created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载Product的id、name、code、price、cost_price、is_package字段,并同时预加载ProductCategory的id、name、code字段
|
||||
if err := dbQuery.
|
||||
Preload("Product", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Select("id", "name", "code", "price", "cost_price", "is_package", "category_id").
|
||||
Preload("Category", func(db2 *gorm.DB) *gorm.DB {
|
||||
return db2.Select("id", "name", "code")
|
||||
})
|
||||
}).
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Subscription, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
result[i] = &subscriptions[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// CountByUser 统计用户订阅数量
|
||||
func (r *GormSubscriptionRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByProduct 统计产品的订阅数量
|
||||
func (r *GormSubscriptionRepository) CountByProduct(ctx context.Context, productID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetTotalRevenue 获取总收入
|
||||
func (r *GormSubscriptionRepository) GetTotalRevenue(ctx context.Context) (float64, error) {
|
||||
var total decimal.Decimal
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Select("COALESCE(SUM(price), 0)").Scan(&total).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total.InexactFloat64(), nil
|
||||
}
|
||||
|
||||
// 基础Repository接口方法
|
||||
|
||||
// Count 返回订阅总数
|
||||
func (r *GormSubscriptionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR product_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取订阅
|
||||
func (r *GormSubscriptionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&subscriptions).Error
|
||||
return subscriptions, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建订阅
|
||||
func (r *GormSubscriptionRepository) CreateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Create(&subscriptions).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新订阅
|
||||
func (r *GormSubscriptionRepository) UpdateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Save(&subscriptions).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除订阅
|
||||
func (r *GormSubscriptionRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取订阅列表(基础方法)
|
||||
func (r *GormSubscriptionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR product_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&subscriptions).Error
|
||||
return subscriptions, err
|
||||
}
|
||||
|
||||
// Exists 检查订阅是否存在
|
||||
func (r *GormSubscriptionRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除订阅
|
||||
func (r *GormSubscriptionRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的订阅
|
||||
func (r *GormSubscriptionRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Unscoped().Model(&entities.Subscription{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormSubscriptionRepository) WithTx(tx interface{}) interfaces.Repository[entities.Subscription] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormSubscriptionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), SubscriptionsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// IncrementAPIUsageWithOptimisticLock 使用乐观锁增加API使用次数
|
||||
func (r *GormSubscriptionRepository) IncrementAPIUsageWithOptimisticLock(ctx context.Context, subscriptionID string, increment int64) error {
|
||||
// 使用原生SQL进行乐观锁更新
|
||||
result := r.GetDB(ctx).WithContext(ctx).Exec(`
|
||||
UPDATE subscription
|
||||
SET api_used = api_used + ?, version = version + 1, updated_at = NOW()
|
||||
WHERE id = ? AND version = (
|
||||
SELECT version FROM subscription WHERE id = ?
|
||||
)
|
||||
`, increment, subscriptionID, subscriptionID)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// 检查是否有行被更新
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormUIComponentRepository UI组件仓储实现
|
||||
type GormUIComponentRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormUIComponentRepository 创建UI组件仓储实例
|
||||
func NewGormUIComponentRepository(db *gorm.DB) repositories.UIComponentRepository {
|
||||
return &GormUIComponentRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建UI组件
|
||||
func (r *GormUIComponentRepository) Create(ctx context.Context, component entities.UIComponent) (entities.UIComponent, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&component).Error; err != nil {
|
||||
return entities.UIComponent{}, fmt.Errorf("创建UI组件失败: %w", err)
|
||||
}
|
||||
return component, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取UI组件
|
||||
func (r *GormUIComponentRepository) GetByID(ctx context.Context, id string) (*entities.UIComponent, error) {
|
||||
var component entities.UIComponent
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&component).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("获取UI组件失败: %w", err)
|
||||
}
|
||||
return &component, nil
|
||||
}
|
||||
|
||||
// GetByCode 根据编码获取UI组件
|
||||
func (r *GormUIComponentRepository) GetByCode(ctx context.Context, code string) (*entities.UIComponent, error) {
|
||||
var component entities.UIComponent
|
||||
if err := r.db.WithContext(ctx).Where("component_code = ?", code).First(&component).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("获取UI组件失败: %w", err)
|
||||
}
|
||||
return &component, nil
|
||||
}
|
||||
|
||||
// List 获取UI组件列表
|
||||
func (r *GormUIComponentRepository) List(ctx context.Context, filters map[string]interface{}) ([]entities.UIComponent, int64, error) {
|
||||
var components []entities.UIComponent
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.UIComponent{})
|
||||
|
||||
// 应用过滤条件
|
||||
if isActive, ok := filters["is_active"]; ok {
|
||||
query = query.Where("is_active = ?", isActive)
|
||||
}
|
||||
|
||||
if keyword, ok := filters["keyword"]; ok && keyword != "" {
|
||||
query = query.Where("component_name LIKE ? OR component_code LIKE ? OR description LIKE ?",
|
||||
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("获取UI组件总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页
|
||||
if page, ok := filters["page"]; ok {
|
||||
if pageSize, ok := filters["page_size"]; ok {
|
||||
offset := (page.(int) - 1) * pageSize.(int)
|
||||
query = query.Offset(offset).Limit(pageSize.(int))
|
||||
}
|
||||
}
|
||||
|
||||
// 排序
|
||||
if sortBy, ok := filters["sort_by"]; ok {
|
||||
if sortOrder, ok := filters["sort_order"]; ok {
|
||||
query = query.Order(fmt.Sprintf("%s %s", sortBy, sortOrder))
|
||||
}
|
||||
} else {
|
||||
query = query.Order("sort_order ASC, created_at DESC")
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
if err := query.Find(&components).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("获取UI组件列表失败: %w", err)
|
||||
}
|
||||
|
||||
return components, total, nil
|
||||
}
|
||||
|
||||
// Update 更新UI组件
|
||||
func (r *GormUIComponentRepository) Update(ctx context.Context, component entities.UIComponent) error {
|
||||
if err := r.db.WithContext(ctx).Save(&component).Error; err != nil {
|
||||
return fmt.Errorf("更新UI组件失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除UI组件
|
||||
func (r *GormUIComponentRepository) Delete(ctx context.Context, id string) error {
|
||||
// 记录删除操作的详细信息
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.UIComponent{}).Error; err != nil {
|
||||
return fmt.Errorf("删除UI组件失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByCodes 根据编码列表获取UI组件
|
||||
func (r *GormUIComponentRepository) GetByCodes(ctx context.Context, codes []string) ([]entities.UIComponent, error) {
|
||||
var components []entities.UIComponent
|
||||
if len(codes) == 0 {
|
||||
return components, nil
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("component_code IN ?", codes).Find(&components).Error; err != nil {
|
||||
return nil, fmt.Errorf("根据编码列表获取UI组件失败: %w", err)
|
||||
}
|
||||
|
||||
return components, nil
|
||||
}
|
||||
@@ -0,0 +1,461 @@
|
||||
package statistics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/statistics/entities"
|
||||
"hyapi-server/internal/domains/statistics/repositories"
|
||||
)
|
||||
|
||||
// GormStatisticsDashboardRepository GORM统计仪表板仓储实现
|
||||
type GormStatisticsDashboardRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormStatisticsDashboardRepository 创建GORM统计仪表板仓储
|
||||
func NewGormStatisticsDashboardRepository(db *gorm.DB) repositories.StatisticsDashboardRepository {
|
||||
return &GormStatisticsDashboardRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) Save(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
|
||||
if dashboard == nil {
|
||||
return fmt.Errorf("统计仪表板不能为空")
|
||||
}
|
||||
|
||||
// 验证仪表板
|
||||
if err := dashboard.Validate(); err != nil {
|
||||
return fmt.Errorf("统计仪表板验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
result := r.db.WithContext(ctx).Save(dashboard)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("保存统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsDashboard, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
var dashboard entities.StatisticsDashboard
|
||||
result := r.db.WithContext(ctx).Where("id = ?", id).First(&dashboard)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("统计仪表板不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &dashboard, nil
|
||||
}
|
||||
|
||||
// FindByUser 根据用户查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("created_by = ?", userID)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindByUserRole 根据用户角色查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByUserRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// Update 更新统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) Update(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
|
||||
if dashboard == nil {
|
||||
return fmt.Errorf("统计仪表板不能为空")
|
||||
}
|
||||
|
||||
if dashboard.ID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
// 验证仪表板
|
||||
if err := dashboard.Validate(); err != nil {
|
||||
return fmt.Errorf("统计仪表板验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
result := r.db.WithContext(ctx).Save(dashboard)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("统计仪表板不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByRole 根据角色查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindDefaultByRole 根据角色查找默认统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindDefaultByRole(ctx context.Context, userRole string) (*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboard entities.StatisticsDashboard
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("user_role = ? AND is_default = ? AND is_active = ?", userRole, true, true).
|
||||
First(&dashboard)
|
||||
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("默认统计仪表板不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询默认统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &dashboard, nil
|
||||
}
|
||||
|
||||
// FindActiveByRole 根据角色查找激活的统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindActiveByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("user_role = ? AND is_active = ?", userRole, true)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询激活统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByStatus(ctx context.Context, isActive bool, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("is_active = ?", isActive)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindByAccessLevel 根据访问级别查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByAccessLevel(ctx context.Context, accessLevel string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if accessLevel == "" {
|
||||
return nil, fmt.Errorf("访问级别不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("access_level = ?", accessLevel)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// CountByUser 根据用户统计数量
|
||||
func (r *GormStatisticsDashboardRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
if userID == "" {
|
||||
return 0, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("created_by = ?", userID).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByRole 根据角色统计数量
|
||||
func (r *GormStatisticsDashboardRepository) CountByRole(ctx context.Context, userRole string) (int64, error) {
|
||||
if userRole == "" {
|
||||
return 0, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("user_role = ?", userRole).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByStatus 根据状态统计数量
|
||||
func (r *GormStatisticsDashboardRepository) CountByStatus(ctx context.Context, isActive bool) (int64, error) {
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("is_active = ?", isActive).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BatchSave 批量保存统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) BatchSave(ctx context.Context, dashboards []*entities.StatisticsDashboard) error {
|
||||
if len(dashboards) == 0 {
|
||||
return fmt.Errorf("统计仪表板列表不能为空")
|
||||
}
|
||||
|
||||
// 验证所有仪表板
|
||||
for _, dashboard := range dashboards {
|
||||
if err := dashboard.Validate(); err != nil {
|
||||
return fmt.Errorf("统计仪表板验证失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存
|
||||
result := r.db.WithContext(ctx).CreateInBatches(dashboards, 100)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量保存统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) BatchDelete(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("仪表板ID列表不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id IN ?", ids)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量删除统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultDashboard 设置默认仪表板
|
||||
func (r *GormStatisticsDashboardRepository) SetDefaultDashboard(ctx context.Context, dashboardID string) error {
|
||||
if dashboardID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
tx := r.db.WithContext(ctx).Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 先取消同角色的所有默认状态
|
||||
var dashboard entities.StatisticsDashboard
|
||||
if err := tx.Where("id = ?", dashboardID).First(&dashboard).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("查询仪表板失败: %w", err)
|
||||
}
|
||||
|
||||
// 取消同角色的所有默认状态
|
||||
if err := tx.Model(&entities.StatisticsDashboard{}).
|
||||
Where("user_role = ? AND is_default = ?", dashboard.UserRole, true).
|
||||
Update("is_default", false).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("取消默认状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置新的默认状态
|
||||
if err := tx.Model(&entities.StatisticsDashboard{}).
|
||||
Where("id = ?", dashboardID).
|
||||
Update("is_default", true).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("设置默认状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("提交事务失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveDefaultDashboard 移除默认仪表板
|
||||
func (r *GormStatisticsDashboardRepository) RemoveDefaultDashboard(ctx context.Context, userRole string) error {
|
||||
if userRole == "" {
|
||||
return fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("user_role = ? AND is_default = ?", userRole, true).
|
||||
Update("is_default", false)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("移除默认仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActivateDashboard 激活仪表板
|
||||
func (r *GormStatisticsDashboardRepository) ActivateDashboard(ctx context.Context, dashboardID string) error {
|
||||
if dashboardID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("id = ?", dashboardID).
|
||||
Update("is_active", true)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("激活仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("仪表板不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeactivateDashboard 停用仪表板
|
||||
func (r *GormStatisticsDashboardRepository) DeactivateDashboard(ctx context.Context, dashboardID string) error {
|
||||
if dashboardID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("id = ?", dashboardID).
|
||||
Update("is_active", false)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("停用仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("仪表板不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package statistics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/statistics/entities"
|
||||
"hyapi-server/internal/domains/statistics/repositories"
|
||||
)
|
||||
|
||||
// GormStatisticsReportRepository GORM统计报告仓储实现
|
||||
type GormStatisticsReportRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormStatisticsReportRepository 创建GORM统计报告仓储
|
||||
func NewGormStatisticsReportRepository(db *gorm.DB) repositories.StatisticsReportRepository {
|
||||
return &GormStatisticsReportRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存统计报告
|
||||
func (r *GormStatisticsReportRepository) Save(ctx context.Context, report *entities.StatisticsReport) error {
|
||||
if report == nil {
|
||||
return fmt.Errorf("统计报告不能为空")
|
||||
}
|
||||
|
||||
// 验证报告
|
||||
if err := report.Validate(); err != nil {
|
||||
return fmt.Errorf("统计报告验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
result := r.db.WithContext(ctx).Save(report)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("保存统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsReport, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("报告ID不能为空")
|
||||
}
|
||||
|
||||
var report entities.StatisticsReport
|
||||
result := r.db.WithContext(ctx).Where("id = ?", id).First(&report)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("统计报告不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &report, nil
|
||||
}
|
||||
|
||||
// FindByUser 根据用户查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).Where("generated_by = ?", userID)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByStatus(ctx context.Context, status string) ([]*entities.StatisticsReport, error) {
|
||||
if status == "" {
|
||||
return nil, fmt.Errorf("报告状态不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&reports)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// Update 更新统计报告
|
||||
func (r *GormStatisticsReportRepository) Update(ctx context.Context, report *entities.StatisticsReport) error {
|
||||
if report == nil {
|
||||
return fmt.Errorf("统计报告不能为空")
|
||||
}
|
||||
|
||||
if report.ID == "" {
|
||||
return fmt.Errorf("报告ID不能为空")
|
||||
}
|
||||
|
||||
// 验证报告
|
||||
if err := report.Validate(); err != nil {
|
||||
return fmt.Errorf("统计报告验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
result := r.db.WithContext(ctx).Save(report)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除统计报告
|
||||
func (r *GormStatisticsReportRepository) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("报告ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("统计报告不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByType 根据类型查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByType(ctx context.Context, reportType string, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if reportType == "" {
|
||||
return nil, fmt.Errorf("报告类型不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).Where("report_type = ?", reportType)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByTypeAndPeriod 根据类型和周期查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByTypeAndPeriod(ctx context.Context, reportType, period string, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if reportType == "" {
|
||||
return nil, fmt.Errorf("报告类型不能为空")
|
||||
}
|
||||
|
||||
if period == "" {
|
||||
return nil, fmt.Errorf("统计周期不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("report_type = ? AND period = ?", reportType, period)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByDateRange 根据日期范围查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByDateRange(ctx context.Context, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("created_at >= ? AND created_at < ?", startDate, endDate)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByUserAndDateRange 根据用户和日期范围查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByUserAndDateRange(ctx context.Context, userID string, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("generated_by = ? AND created_at >= ? AND created_at < ?", userID, startDate, endDate)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// CountByUser 根据用户统计数量
|
||||
func (r *GormStatisticsReportRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
if userID == "" {
|
||||
return 0, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsReport{}).
|
||||
Where("generated_by = ?", userID).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByType 根据类型统计数量
|
||||
func (r *GormStatisticsReportRepository) CountByType(ctx context.Context, reportType string) (int64, error) {
|
||||
if reportType == "" {
|
||||
return 0, fmt.Errorf("报告类型不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsReport{}).
|
||||
Where("report_type = ?", reportType).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByStatus 根据状态统计数量
|
||||
func (r *GormStatisticsReportRepository) CountByStatus(ctx context.Context, status string) (int64, error) {
|
||||
if status == "" {
|
||||
return 0, fmt.Errorf("报告状态不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsReport{}).
|
||||
Where("status = ?", status).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BatchSave 批量保存统计报告
|
||||
func (r *GormStatisticsReportRepository) BatchSave(ctx context.Context, reports []*entities.StatisticsReport) error {
|
||||
if len(reports) == 0 {
|
||||
return fmt.Errorf("统计报告列表不能为空")
|
||||
}
|
||||
|
||||
// 验证所有报告
|
||||
for _, report := range reports {
|
||||
if err := report.Validate(); err != nil {
|
||||
return fmt.Errorf("统计报告验证失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存
|
||||
result := r.db.WithContext(ctx).CreateInBatches(reports, 100)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量保存统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除统计报告
|
||||
func (r *GormStatisticsReportRepository) BatchDelete(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("报告ID列表不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id IN ?", ids)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量删除统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpiredReports 删除过期报告
|
||||
func (r *GormStatisticsReportRepository) DeleteExpiredReports(ctx context.Context, expiredBefore time.Time) error {
|
||||
if expiredBefore.IsZero() {
|
||||
return fmt.Errorf("过期时间不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsReport{}, "expires_at IS NOT NULL AND expires_at < ?", expiredBefore)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除过期报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByStatus 根据状态删除统计报告
|
||||
func (r *GormStatisticsReportRepository) DeleteByStatus(ctx context.Context, status string) error {
|
||||
if status == "" {
|
||||
return fmt.Errorf("报告状态不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsReport{}, "status = ?", status)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("根据状态删除统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,381 @@
|
||||
package statistics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/statistics/entities"
|
||||
"hyapi-server/internal/domains/statistics/repositories"
|
||||
)
|
||||
|
||||
// GormStatisticsRepository GORM统计指标仓储实现
|
||||
type GormStatisticsRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormStatisticsRepository 创建GORM统计指标仓储
|
||||
func NewGormStatisticsRepository(db *gorm.DB) repositories.StatisticsRepository {
|
||||
return &GormStatisticsRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存统计指标
|
||||
func (r *GormStatisticsRepository) Save(ctx context.Context, metric *entities.StatisticsMetric) error {
|
||||
if metric == nil {
|
||||
return fmt.Errorf("统计指标不能为空")
|
||||
}
|
||||
|
||||
// 验证指标
|
||||
if err := metric.Validate(); err != nil {
|
||||
return fmt.Errorf("统计指标验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
result := r.db.WithContext(ctx).Create(metric)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("保存统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsMetric, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("指标ID不能为空")
|
||||
}
|
||||
|
||||
var metric entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).Where("id = ?", id).First(&metric)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("统计指标不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &metric, nil
|
||||
}
|
||||
|
||||
// FindByType 根据类型查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByType(ctx context.Context, metricType string, limit, offset int) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
query := r.db.WithContext(ctx).Where("metric_type = ?", metricType)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&metrics)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// Update 更新统计指标
|
||||
func (r *GormStatisticsRepository) Update(ctx context.Context, metric *entities.StatisticsMetric) error {
|
||||
if metric == nil {
|
||||
return fmt.Errorf("统计指标不能为空")
|
||||
}
|
||||
|
||||
if metric.ID == "" {
|
||||
return fmt.Errorf("指标ID不能为空")
|
||||
}
|
||||
|
||||
// 验证指标
|
||||
if err := metric.Validate(); err != nil {
|
||||
return fmt.Errorf("统计指标验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
result := r.db.WithContext(ctx).Save(metric)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除统计指标
|
||||
func (r *GormStatisticsRepository) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("指标ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("统计指标不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByTypeAndDateRange 根据类型和日期范围查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
|
||||
Order("date ASC").
|
||||
Find(&metrics)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// FindByTypeDimensionAndDateRange 根据类型、维度和日期范围查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByTypeDimensionAndDateRange(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate)
|
||||
|
||||
if dimension != "" {
|
||||
query = query.Where("dimension = ?", dimension)
|
||||
}
|
||||
|
||||
result := query.Order("date ASC").Find(&metrics)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// FindByTypeNameAndDateRange 根据类型、名称和日期范围查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByTypeNameAndDateRange(ctx context.Context, metricType, metricName string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if metricName == "" {
|
||||
return nil, fmt.Errorf("指标名称不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("metric_type = ? AND metric_name = ? AND date >= ? AND date < ?",
|
||||
metricType, metricName, startDate, endDate).
|
||||
Order("date ASC").
|
||||
Find(&metrics)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// GetAggregatedMetrics 获取聚合指标
|
||||
func (r *GormStatisticsRepository) GetAggregatedMetrics(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) (map[string]float64, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
type AggregatedResult struct {
|
||||
MetricName string `json:"metric_name"`
|
||||
TotalValue float64 `json:"total_value"`
|
||||
}
|
||||
|
||||
var results []AggregatedResult
|
||||
query := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsMetric{}).
|
||||
Select("metric_name, SUM(value) as total_value").
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
|
||||
Group("metric_name")
|
||||
|
||||
if dimension != "" {
|
||||
query = query.Where("dimension = ?", dimension)
|
||||
}
|
||||
|
||||
result := query.Find(&results)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询聚合指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
// 转换为map
|
||||
aggregated := make(map[string]float64)
|
||||
for _, res := range results {
|
||||
aggregated[res.MetricName] = res.TotalValue
|
||||
}
|
||||
|
||||
return aggregated, nil
|
||||
}
|
||||
|
||||
// GetMetricsByDimension 根据维度获取指标
|
||||
func (r *GormStatisticsRepository) GetMetricsByDimension(ctx context.Context, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if dimension == "" {
|
||||
return nil, fmt.Errorf("统计维度不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("dimension = ? AND date >= ? AND date < ?", dimension, startDate, endDate).
|
||||
Order("date ASC").
|
||||
Find(&metrics)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// CountByType 根据类型统计数量
|
||||
func (r *GormStatisticsRepository) CountByType(ctx context.Context, metricType string) (int64, error) {
|
||||
if metricType == "" {
|
||||
return 0, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsMetric{}).
|
||||
Where("metric_type = ?", metricType).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByTypeAndDateRange 根据类型和日期范围统计数量
|
||||
func (r *GormStatisticsRepository) CountByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) (int64, error) {
|
||||
if metricType == "" {
|
||||
return 0, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return 0, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsMetric{}).
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BatchSave 批量保存统计指标
|
||||
func (r *GormStatisticsRepository) BatchSave(ctx context.Context, metrics []*entities.StatisticsMetric) error {
|
||||
if len(metrics) == 0 {
|
||||
return fmt.Errorf("统计指标列表不能为空")
|
||||
}
|
||||
|
||||
// 验证所有指标
|
||||
for _, metric := range metrics {
|
||||
if err := metric.Validate(); err != nil {
|
||||
return fmt.Errorf("统计指标验证失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存
|
||||
result := r.db.WithContext(ctx).CreateInBatches(metrics, 100)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量保存统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除统计指标
|
||||
func (r *GormStatisticsRepository) BatchDelete(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("指标ID列表不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id IN ?", ids)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByDateRange 根据日期范围删除统计指标
|
||||
func (r *GormStatisticsRepository) DeleteByDateRange(ctx context.Context, startDate, endDate time.Time) error {
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsMetric{}, "date >= ? AND date < ?", startDate, endDate)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("根据日期范围删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByTypeAndDateRange 根据类型和日期范围删除统计指标
|
||||
func (r *GormStatisticsRepository) DeleteByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) error {
|
||||
if metricType == "" {
|
||||
return fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsMetric{}, "metric_type = ? AND date >= ? AND date < ?",
|
||||
metricType, startDate, endDate)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("根据类型和日期范围删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
// internal/infrastructure/database/repositories/user/gorm_contract_info_repository.go
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ContractInfosTable = "contract_infos"
|
||||
)
|
||||
|
||||
type GormContractInfoRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func NewGormContractInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractInfoRepository {
|
||||
return &GormContractInfoRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ContractInfosTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Save(ctx context.Context, contract *entities.ContractInfo) error {
|
||||
return r.CreateEntity(ctx, contract)
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByID(ctx context.Context, contractID string) (*entities.ContractInfo, error) {
|
||||
var contract entities.ContractInfo
|
||||
err := r.SmartGetByID(ctx, contractID, &contract)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &contract, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Delete(ctx context.Context, contractID string) error {
|
||||
return r.DeleteEntity(ctx, contractID, &entities.ContractInfo{})
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByEnterpriseInfoID(ctx context.Context, enterpriseInfoID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ?", enterpriseInfoID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByContractType(ctx context.Context, enterpriseInfoID string, contractType entities.ContractType) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ? AND contract_type = ?", enterpriseInfoID, contractType).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileID(ctx context.Context, contractFileID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ?", contractFileID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileIDExcludeID(ctx context.Context, contractFileID, excludeID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ? AND id != ?", contractFileID, excludeID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormEnterpriseInfoRepository 企业信息GORM仓储实现
|
||||
type GormEnterpriseInfoRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewGormEnterpriseInfoRepository 创建企业信息GORM仓储
|
||||
func NewGormEnterpriseInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.EnterpriseInfoRepository {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Create(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) (entities.EnterpriseInfo, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("创建企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("创建企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// Update 更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Update(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) error {
|
||||
if err := r.db.WithContext(ctx).Save(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("更新企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("更新企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("软删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("软删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("恢复企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByUnifiedSocialCode 根据统一社会信用代码获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("unified_social_code = ?", unifiedSocialCode).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// CheckUnifiedSocialCodeExists 检查统一社会信用代码是否已存在
|
||||
func (r *GormEnterpriseInfoRepository) CheckUnifiedSocialCodeExists(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("user_id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
r.logger.Error("检查统一社会信用代码失败", zap.Error(err))
|
||||
return false, fmt.Errorf("检查统一社会信用代码失败: %w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// UpdateVerificationStatus 更新验证状态
|
||||
func (r *GormEnterpriseInfoRepository) UpdateVerificationStatus(ctx context.Context, userID string, isOCRVerified, isFaceVerified, isCertified bool) error {
|
||||
updates := map[string]interface{}{
|
||||
"is_ocr_verified": isOCRVerified,
|
||||
"is_face_verified": isFaceVerified,
|
||||
"is_certified": isCertified,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新验证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新验证状态失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOCRData 更新OCR数据
|
||||
func (r *GormEnterpriseInfoRepository) UpdateOCRData(ctx context.Context, userID string, rawData string, confidence float64) error {
|
||||
updates := map[string]interface{}{
|
||||
"ocr_raw_data": rawData,
|
||||
"ocr_confidence": confidence,
|
||||
"is_ocr_verified": true,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新OCR数据失败", zap.Error(err))
|
||||
return fmt.Errorf("更新OCR数据失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompleteCertification 完成认证
|
||||
func (r *GormEnterpriseInfoRepository) CompleteCertification(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
updates := map[string]interface{}{
|
||||
"is_certified": true,
|
||||
"certified_at": &now,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("完成认证失败", zap.Error(err))
|
||||
return fmt.Errorf("完成认证失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计企业信息数量
|
||||
func (r *GormEnterpriseInfoRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查企业信息是否存在
|
||||
func (r *GormEnterpriseInfoRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) CreateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Create(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) UpdateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Save(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取企业信息列表
|
||||
func (r *GormEnterpriseInfoRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormEnterpriseInfoRepository) WithTx(tx interface{}) interfaces.Repository[entities.EnterpriseInfo] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/domains/user/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
SMSCodesTable = "sms_codes"
|
||||
)
|
||||
|
||||
// GormSMSCodeRepository 短信验证码GORM仓储实现(无缓存,确保安全性)
|
||||
type GormSMSCodeRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// NewGormSMSCodeRepository 创建短信验证码仓储
|
||||
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
|
||||
return &GormSMSCodeRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SMSCodesTable),
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 GormSMSCodeRepository 实现了 SMSCodeRepository 接口
|
||||
var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
|
||||
|
||||
// ================ Repository[T] 接口实现 ================
|
||||
|
||||
// Create 创建短信验证码记录(不缓存,确保安全性)
|
||||
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
|
||||
err := r.GetDB(ctx).Create(&smsCode).Error
|
||||
return smsCode, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
err := r.GetDB(ctx).Where("id = ?", id).First(&smsCode).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// Update 更新验证码记录
|
||||
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Save(&smsCode).Error
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建短信验证码
|
||||
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Create(&smsCodes).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&smsCodes).Error
|
||||
return smsCodes, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新短信验证码
|
||||
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Save(&smsCodes).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除短信验证码
|
||||
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取短信验证码列表
|
||||
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用预加载
|
||||
for _, include := range options.Include {
|
||||
query = query.Preload(include)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return smsCodes, query.Find(&smsCodes).Error
|
||||
}
|
||||
|
||||
// ================ BaseRepository 接口实现 ================
|
||||
|
||||
// Delete 删除短信验证码
|
||||
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Exists 检查短信验证码是否存在
|
||||
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// Count 统计短信验证码数量
|
||||
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除短信验证码
|
||||
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复短信验证码
|
||||
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
|
||||
// GetByPhone 根据手机号获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetLatestByPhone 根据手机号获取最新短信验证码
|
||||
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetValidByPhone 根据手机号获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("有效的短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("有效的短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// ListSMSCodes 获取短信验证码列表(带分页和筛选)
|
||||
func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries.ListSMSCodesQuery) ([]*entities.SMSCode, int64, error) {
|
||||
var smsCodes []*entities.SMSCode
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
db := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
db = db.Where("phone = ?", query.Phone)
|
||||
}
|
||||
if query.Purpose != "" {
|
||||
db = db.Where("scene = ?", query.Purpose)
|
||||
}
|
||||
if query.Status != "" {
|
||||
db = db.Where("used = ?", query.Status == "used")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
db = db.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
db = db.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
if err := db.Offset(offset).Limit(query.PageSize).Order("created_at DESC").Find(&smsCodes).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return smsCodes, total, nil
|
||||
}
|
||||
|
||||
// CreateCode 创建验证码
|
||||
func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, code string, purpose string) (entities.SMSCode, error) {
|
||||
smsCode := entities.SMSCode{
|
||||
Phone: phone,
|
||||
Code: code,
|
||||
Scene: entities.SMSScene(purpose), // 使用Scene字段
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效期
|
||||
}
|
||||
|
||||
if err := r.GetDB(ctx).Create(&smsCode).Error; err != nil {
|
||||
r.GetLogger().Error("创建短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// ValidateCode 验证验证码
|
||||
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// InvalidateCode 使验证码失效
|
||||
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND used_at IS NULL", phone).
|
||||
Update("used_at", &now).Error
|
||||
}
|
||||
|
||||
// CheckSendFrequency 检查发送频率
|
||||
func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone string, purpose string) (bool, error) {
|
||||
// 检查1分钟内是否已发送
|
||||
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
|
||||
var count int64
|
||||
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
|
||||
Count(&count).Error
|
||||
|
||||
// 如果1分钟内已发送,则返回false(不允许发送)
|
||||
return count == 0, err
|
||||
}
|
||||
|
||||
// GetTodaySendCount 获取今日发送数量
|
||||
func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone string) (int64, error) {
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
var count int64
|
||||
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetCodeStats 获取验证码统计
|
||||
func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string, days int) (*repositories.SMSCodeStats, error) {
|
||||
var stats repositories.SMSCodeStats
|
||||
|
||||
// 计算指定天数前的日期
|
||||
startDate := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
// 总发送数
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, startDate).
|
||||
Count(&stats.TotalSent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总验证数
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
|
||||
Count(&stats.TotalValidated).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 成功率
|
||||
if stats.TotalSent > 0 {
|
||||
stats.SuccessRate = float64(stats.TotalValidated) / float64(stats.TotalSent) * 100
|
||||
}
|
||||
|
||||
// 今日发送数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&stats.TodaySent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
@@ -0,0 +1,720 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/domains/user/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
UsersTable = "users"
|
||||
UserCacheTTL = 30 * 60 // 30分钟
|
||||
)
|
||||
|
||||
// 定义错误常量
|
||||
var (
|
||||
// ErrUserNotFound 用户不存在错误
|
||||
ErrUserNotFound = errors.New("用户不存在")
|
||||
)
|
||||
|
||||
type GormUserRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.UserRepository = (*GormUserRepository)(nil)
|
||||
|
||||
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
|
||||
return &GormUserRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, UsersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
|
||||
err := r.CreateEntity(ctx, &user)
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.SmartGetByID(ctx, id, &user)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, errors.New("用户不存在")
|
||||
}
|
||||
return entities.User{}, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据ID查询用户失败", zap.Error(err))
|
||||
return entities.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) BatchGetByIDsWithEnterpriseInfo(ctx context.Context, ids []string) ([]*entities.User, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*entities.User{}, nil
|
||||
}
|
||||
|
||||
var users []*entities.User
|
||||
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id IN ?", ids).Find(&users).Error; err != nil {
|
||||
r.GetLogger().Error("批量查询用户失败", zap.Error(err), zap.Strings("ids", ids))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Joins("JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
// 如果指定了排除的用户ID,则排除该用户的记录
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("users.id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
r.GetLogger().Error("检查统一社会信用代码是否存在失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
|
||||
return r.UpdateEntity(ctx, &user)
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量创建用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Create(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量更新用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Save(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.GetLogger().Info("批量删除用户", zap.Strings("ids", ids))
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.SmartList(ctx, &users, options)
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
|
||||
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据手机号查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据用户名查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
|
||||
var users []*entities.User
|
||||
err := r.GetDB(ctx).Where("user_type = ?", userType).Order("created_at DESC").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
|
||||
var users []*entities.User
|
||||
var total int64
|
||||
|
||||
// 构建查询条件,预加载企业信息
|
||||
db := r.GetDB(ctx).Model(&entities.User{}).Preload("EnterpriseInfo")
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
db = db.Where("users.phone LIKE ?", "%"+query.Phone+"%")
|
||||
}
|
||||
if query.UserType != "" {
|
||||
db = db.Where("users.user_type = ?", query.UserType)
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
db = db.Where("users.active = ?", *query.IsActive)
|
||||
}
|
||||
if query.IsCertified != nil {
|
||||
db = db.Where("users.is_certified = ?", *query.IsCertified)
|
||||
}
|
||||
if query.CompanyName != "" {
|
||||
db = db.Joins("LEFT JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
db = db.Where("users.created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
db = db.Where("users.created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序(默认按创建时间倒序)
|
||||
db = db.Order("users.created_at DESC")
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
if err := db.Offset(offset).Limit(query.PageSize).Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.GetDB(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_login_at": &now,
|
||||
"updated_at": now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("password", newPassword).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ? AND password = ?", userID, password).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", true).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", false).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"login_count": gorm.Expr("login_count + 1"),
|
||||
"last_login_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 总用户数
|
||||
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数
|
||||
if err := db.Model(&entities.User{}).Where("active = ?", true).Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已认证用户数
|
||||
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 指定时间范围内的注册数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 指定时间范围内的登录数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemUserStats 获取系统用户统计信息
|
||||
func (r *GormUserRepository) GetSystemUserStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 总用户数
|
||||
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数(最近30天有登录)
|
||||
thirtyDaysAgo := time.Now().AddDate(0, 0, -30)
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", thirtyDaysAgo).Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已认证用户数
|
||||
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemUserStatsByDateRange 获取系统指定时间范围内的用户统计信息
|
||||
func (r *GormUserRepository) GetSystemUserStatsByDateRange(ctx context.Context, startDate, endDate time.Time) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 指定时间范围内的注册数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 指定时间范围内的登录数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemDailyUserStats 获取系统每日用户统计
|
||||
func (r *GormUserRepository) GetSystemDailyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE DATE(created_at) >= $1
|
||||
AND DATE(created_at) <= $2
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyUserStats 获取系统每月用户统计
|
||||
func (r *GormUserRepository) GetSystemMonthlyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE created_at >= $1
|
||||
AND created_at <= $2
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemDailyCertificationStats 获取系统每日认证用户统计(基于is_certified字段)
|
||||
func (r *GormUserRepository) GetSystemDailyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(updated_at) as date,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_certified = true
|
||||
AND DATE(updated_at) >= $1
|
||||
AND DATE(updated_at) <= $2
|
||||
GROUP BY DATE(updated_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyCertificationStats 获取系统每月认证用户统计(基于is_certified字段)
|
||||
func (r *GormUserRepository) GetSystemMonthlyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(updated_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_certified = true
|
||||
AND updated_at >= $1
|
||||
AND updated_at <= $2
|
||||
GROUP BY TO_CHAR(updated_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserCallRankingByCalls 按调用次数获取用户排行
|
||||
func (r *GormUserRepository) GetUserCallRankingByCalls(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
AND DATE(ac.created_at) = CURRENT_DATE
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserCallRankingByConsumption 按消费金额获取用户排行
|
||||
func (r *GormUserRepository) GetUserCallRankingByConsumption(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
AND DATE(wt.created_at) = CURRENT_DATE
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
AND DATE_TRUNC('month', wt.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetRechargeRanking 获取充值排行(排除赠送,只统计成功状态)
|
||||
func (r *GormUserRepository) GetRechargeRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND DATE(rr.created_at) = CURRENT_DATE
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND DATE_TRUNC('month', rr.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
Reference in New Issue
Block a user