Files
tyapi-server/internal/domains/product/services/product_subscription_service.go
2025-08-02 02:54:21 +08:00

323 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package services
import (
"context"
"errors"
"fmt"
"time"
"gorm.io/gorm"
"go.uber.org/zap"
"tyapi-server/internal/domains/product/entities"
"tyapi-server/internal/domains/product/repositories"
"tyapi-server/internal/domains/product/repositories/queries"
"tyapi-server/internal/shared/interfaces"
"github.com/shopspring/decimal"
)
// ProductSubscriptionService is the product-subscription domain service.
// It holds the subscription business logic: eligibility checks, creation,
// cancellation, price updates, usage accounting and statistics.
type ProductSubscriptionService struct {
	// productRepo is used for product lookups and product counts.
	productRepo repositories.ProductRepository
	// subscriptionRepo is used to persist and query subscriptions.
	subscriptionRepo repositories.SubscriptionRepository
	// logger records the outcome of operations.
	logger *zap.Logger
}
// NewProductSubscriptionService builds a ProductSubscriptionService from its
// repositories and logger. The arguments are stored as given; none of them is
// validated.
func NewProductSubscriptionService(
	productRepo repositories.ProductRepository,
	subscriptionRepo repositories.SubscriptionRepository,
	logger *zap.Logger,
) *ProductSubscriptionService {
	svc := new(ProductSubscriptionService)
	svc.productRepo = productRepo
	svc.subscriptionRepo = subscriptionRepo
	svc.logger = logger
	return svc
}
// UserSubscribedProductByCode returns the user's subscription to the product
// identified by productCode. The product is resolved by code first, then the
// subscription is looked up by user and product ID. Any repository error is
// returned unchanged (not wrapped).
func (s *ProductSubscriptionService) UserSubscribedProductByCode(ctx context.Context, userID string, productCode string) (*entities.Subscription, error) {
	product, lookupErr := s.productRepo.FindByCode(ctx, productCode)
	if lookupErr != nil {
		return nil, lookupErr
	}

	sub, findErr := s.subscriptionRepo.FindByUserAndProduct(ctx, userID, product.ID)
	if findErr != nil {
		return nil, findErr
	}
	return sub, nil
}
// GetUserSubscribedProduct returns the user's subscription to productID.
// A gorm.ErrRecordNotFound from the repository is reported as (nil, nil) so
// that "not subscribed" is not treated as a failure; any other error is
// returned unchanged.
func (s *ProductSubscriptionService) GetUserSubscribedProduct(ctx context.Context, userID string, productID string) (*entities.Subscription, error) {
	sub, err := s.subscriptionRepo.FindByUserAndProduct(ctx, userID, productID)
	switch {
	case err == nil:
		return sub, nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	default:
		return nil, err
	}
}
// CanUserSubscribeProduct reports whether userID may subscribe to productID.
//
// It returns (true, nil) only when the product exists, product.CanBeSubscribed()
// is true, and the user has no existing subscription to it. Every negative
// outcome is returned as (false, non-nil error) describing the reason:
//   - the product lookup failed (wrapped as "产品不存在");
//   - the product is not subscribable;
//   - the user already holds a subscription to the product;
//   - the subscription lookup failed with an error other than
//     gorm.ErrRecordNotFound (wrapped as "检查用户订阅失败").
func (s *ProductSubscriptionService) CanUserSubscribeProduct(ctx context.Context, userID string, productID string) (bool, error) {
	// The product must exist and be open for subscription.
	product, err := s.productRepo.GetByID(ctx, productID)
	if err != nil {
		return false, fmt.Errorf("产品不存在: %w", err)
	}
	if !product.CanBeSubscribed() {
		return false, errors.New("产品不可订阅")
	}

	// The user must not already be subscribed. Only "record not found" means
	// "no subscription" (the same convention GetUserSubscribedProduct uses).
	// Any other lookup failure must not be mistaken for eligibility, otherwise
	// a duplicate subscription could be created.
	existingSubscription, err := s.subscriptionRepo.FindByUserAndProduct(ctx, userID, productID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return true, nil
		}
		return false, fmt.Errorf("检查用户订阅失败: %w", err)
	}
	if existingSubscription != nil {
		return false, errors.New("用户已有该产品的订阅")
	}
	return true, nil
}
// CreateSubscription subscribes userID to productID and returns the persisted
// subscription.
//
// Eligibility is checked with CanUserSubscribeProduct first, and its error is
// returned unchanged. The product is then read again so its current Price can
// be copied into the new subscription. Creation failures are logged and
// returned wrapped as "创建订阅失败".
func (s *ProductSubscriptionService) CreateSubscription(ctx context.Context, userID, productID string) (*entities.Subscription, error) {
	ok, err := s.CanUserSubscribeProduct(ctx, userID, productID)
	switch {
	case err != nil:
		return nil, err
	case !ok:
		return nil, errors.New("无法订阅该产品")
	}

	product, err := s.productRepo.GetByID(ctx, productID)
	if err != nil {
		return nil, fmt.Errorf("产品不存在: %w", err)
	}

	// The subscription stores a copy of the product's price at creation time.
	created, err := s.subscriptionRepo.Create(ctx, entities.Subscription{
		UserID:    userID,
		ProductID: productID,
		Price:     product.Price,
	})
	if err != nil {
		s.logger.Error("创建订阅失败", zap.Error(err))
		return nil, fmt.Errorf("创建订阅失败: %w", err)
	}

	s.logger.Info("订阅创建成功",
		zap.String("subscription_id", created.ID),
		zap.String("user_id", userID),
		zap.String("product_id", productID),
	)
	return &created, nil
}
// ListSubscriptions returns the subscriptions matching query along with the
// count reported by the repository (presumably the total number of matches —
// confirm against the repository implementation). Results and errors are
// passed through unchanged.
func (s *ProductSubscriptionService) ListSubscriptions(ctx context.Context, query *queries.ListSubscriptionsQuery) ([]*entities.Subscription, int64, error) {
	items, total, err := s.subscriptionRepo.ListSubscriptions(ctx, query)
	return items, total, err
}
// GetUserSubscriptions returns the subscriptions the repository finds for
// userID. Results and errors are passed through unchanged.
func (s *ProductSubscriptionService) GetUserSubscriptions(ctx context.Context, userID string) ([]*entities.Subscription, error) {
	subs, err := s.subscriptionRepo.FindByUserID(ctx, userID)
	return subs, err
}
// GetSubscriptionByID loads a subscription by its ID. Any repository error is
// wrapped as "订阅不存在"; on success a pointer to the loaded value is
// returned.
func (s *ProductSubscriptionService) GetSubscriptionByID(ctx context.Context, subscriptionID string) (*entities.Subscription, error) {
	found, err := s.subscriptionRepo.GetByID(ctx, subscriptionID)
	if err == nil {
		return &found, nil
	}
	return nil, fmt.Errorf("订阅不存在: %w", err)
}
// CancelSubscription cancels a subscription. The subscription entity has no
// status field, so cancellation is implemented as a repository Delete of the
// record. Failures are logged and returned wrapped as "取消订阅失败".
func (s *ProductSubscriptionService) CancelSubscription(ctx context.Context, subscriptionID string) error {
	err := s.subscriptionRepo.Delete(ctx, subscriptionID)
	if err != nil {
		s.logger.Error("取消订阅失败", zap.Error(err))
		return fmt.Errorf("取消订阅失败: %w", err)
	}

	s.logger.Info("订阅取消成功", zap.String("subscription_id", subscriptionID))
	return nil
}
// GetProductStats returns product counters under the keys "total", "enabled"
// and "visible".
//
// Collection is best-effort. When a repository count fails, its key is left
// out of the map and the failure is logged as a warning rather than dropped
// silently. The method itself always returns a nil error.
func (s *ProductSubscriptionService) GetProductStats(ctx context.Context) (map[string]int64, error) {
	stats := make(map[string]int64, 3)

	// NOTE(review): an empty category is assumed to mean "all categories" in
	// the repository; confirm.
	if total, err := s.productRepo.CountByCategory(ctx, ""); err != nil {
		s.logger.Warn("获取产品总数失败", zap.Error(err))
	} else {
		stats["total"] = total
	}

	if enabled, err := s.productRepo.CountEnabled(ctx); err != nil {
		s.logger.Warn("获取启用产品数失败", zap.Error(err))
	} else {
		stats["enabled"] = enabled
	}

	if visible, err := s.productRepo.CountVisible(ctx); err != nil {
		s.logger.Warn("获取可见产品数失败", zap.Error(err))
	} else {
		stats["visible"] = visible
	}

	return stats, nil
}
// SaveSubscription persists subscription. A record with the same ID is
// updated if it already exists; otherwise the subscription is created.
//
// A nil subscription is rejected with an error instead of panicking. Every
// repository error is wrapped with %w, so callers can still match the
// underlying cause with errors.Is / errors.As.
//
// NOTE(review): the value returned by Create (which may carry generated
// fields such as the ID) is discarded, so the caller's struct is not updated;
// confirm callers do not rely on it.
func (s *ProductSubscriptionService) SaveSubscription(ctx context.Context, subscription *entities.Subscription) error {
	if subscription == nil {
		return errors.New("订阅不能为空")
	}

	exists, err := s.subscriptionRepo.Exists(ctx, subscription.ID)
	if err != nil {
		return fmt.Errorf("检查订阅是否存在失败: %w", err)
	}

	if exists {
		if err = s.subscriptionRepo.Update(ctx, *subscription); err != nil {
			return fmt.Errorf("更新订阅失败: %w", err)
		}
		return nil
	}

	if _, err = s.subscriptionRepo.Create(ctx, *subscription); err != nil {
		return fmt.Errorf("创建订阅失败: %w", err)
	}
	return nil
}
// IncrementSubscriptionAPIUsage adds increment to the subscription's API usage
// counter. It uses the repository's optimistic-lock update and retries on
// conflicts.
//
// The repository reports both a version conflict and a missing subscription
// as gorm.ErrRecordNotFound, so both are retried. There are at most maxRetries
// attempts in total, with a wait of baseDelay*(attempt+1) (10ms, then 20ms)
// between them. Any other repository error is returned immediately without
// retrying.
//
// The wait between attempts honours ctx. If ctx is cancelled or its deadline
// expires while waiting, retrying stops and an error wrapping ctx.Err() is
// returned.
func (s *ProductSubscriptionService) IncrementSubscriptionAPIUsage(ctx context.Context, subscriptionID string, increment int64) error {
	const maxRetries = 3
	const baseDelay = 10 * time.Millisecond

	for attempt := 0; attempt < maxRetries; attempt++ {
		// The repository applies the increment with an optimistic-lock update.
		err := s.subscriptionRepo.IncrementAPIUsageWithOptimisticLock(ctx, subscriptionID, increment)
		if err == nil {
			if attempt > 0 {
				s.logger.Info("订阅API使用次数更新成功重试后",
					zap.String("subscription_id", subscriptionID),
					zap.Int64("increment", increment),
					zap.Int("retry_count", attempt))
			} else {
				s.logger.Info("订阅API使用次数更新成功",
					zap.String("subscription_id", subscriptionID),
					zap.Int64("increment", increment))
			}
			return nil
		}

		// Only "record not found" (version conflict or missing row) is
		// retryable; any other error is returned without retrying.
		if !errors.Is(err, gorm.ErrRecordNotFound) {
			s.logger.Error("更新订阅API使用次数失败",
				zap.String("subscription_id", subscriptionID),
				zap.Int64("increment", increment),
				zap.Error(err))
			return fmt.Errorf("更新订阅API使用次数失败: %w", err)
		}

		// The final attempt also conflicted: give up.
		if attempt == maxRetries-1 {
			s.logger.Error("订阅不存在或版本冲突,重试次数已用完",
				zap.String("subscription_id", subscriptionID),
				zap.Int("max_retries", maxRetries),
				zap.Error(err))
			return fmt.Errorf("订阅不存在或已被其他操作修改(重试%d次后失败): %w", maxRetries, err)
		}

		// Linear backoff before the next attempt. Unlike time.Sleep, waiting
		// in a select lets a cancelled request stop immediately.
		delay := time.Duration(attempt+1) * baseDelay
		s.logger.Debug("订阅版本冲突,准备重试",
			zap.String("subscription_id", subscriptionID),
			zap.Int("attempt", attempt+1),
			zap.Duration("delay", delay))

		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			s.logger.Warn("订阅API使用次数更新重试被取消",
				zap.String("subscription_id", subscriptionID),
				zap.Int("attempt", attempt+1),
				zap.Error(ctx.Err()))
			return fmt.Errorf("更新订阅API使用次数被取消: %w", ctx.Err())
		case <-timer.C:
		}
	}

	// Unreachable: every iteration returns or waits and continues, and the
	// last iteration always returns. Kept because Go requires a terminating
	// statement after a conditional for loop.
	return fmt.Errorf("更新失败,已重试%d次", maxRetries)
}
// GetSubscriptionStats returns platform-wide subscription statistics:
//   - "total_subscriptions": the repository Count with default options;
//   - "total_revenue": the value of the repository's GetTotalRevenue.
//
// If either lookup fails, the failure is logged and returned wrapped. No
// partial map is returned in that case.
func (s *ProductSubscriptionService) GetSubscriptionStats(ctx context.Context) (map[string]interface{}, error) {
	total, err := s.subscriptionRepo.Count(ctx, interfaces.CountOptions{})
	if err != nil {
		s.logger.Error("获取订阅总数失败", zap.Error(err))
		return nil, fmt.Errorf("获取订阅总数失败: %w", err)
	}

	revenue, err := s.subscriptionRepo.GetTotalRevenue(ctx)
	if err != nil {
		s.logger.Error("获取总收入失败", zap.Error(err))
		return nil, fmt.Errorf("获取总收入失败: %w", err)
	}

	return map[string]interface{}{
		"total_subscriptions": total,
		"total_revenue":       revenue,
	}, nil
}
// GetUserSubscriptionStats returns subscription statistics for one user:
//   - "total_subscriptions" (int64): number of subscriptions the user holds;
//   - "total_revenue" (float64): sum of those subscriptions' prices.
//
// Prices are summed as decimal.Decimal and converted to float64 only once at
// the end, so binary rounding error does not accumulate across subscriptions.
// A repository failure is logged and returned wrapped.
func (s *ProductSubscriptionService) GetUserSubscriptionStats(ctx context.Context, userID string) (map[string]interface{}, error) {
	userSubscriptions, err := s.subscriptionRepo.FindByUserID(ctx, userID)
	if err != nil {
		s.logger.Error("获取用户订阅失败", zap.Error(err))
		return nil, fmt.Errorf("获取用户订阅失败: %w", err)
	}

	totalRevenue := decimal.Zero
	for _, subscription := range userSubscriptions {
		totalRevenue = totalRevenue.Add(subscription.Price)
	}

	return map[string]interface{}{
		"total_subscriptions": int64(len(userSubscriptions)),
		// Converted once, keeping the float64 type callers already receive.
		"total_revenue": totalRevenue.InexactFloat64(),
	}, nil
}
// UpdateSubscriptionPrice sets the price of an existing subscription.
//
// newPrice must be finite and non-negative. Otherwise an error is returned
// before the repository is touched; decimal.NewFromFloat would panic on NaN
// or ±Inf. A missing subscription is wrapped as "订阅不存在", and update
// failures are logged and wrapped as "更新订阅价格失败".
//
// The subscription's Version is incremented before Update is called.
// NOTE(review): confirm this matches the repository's optimistic-lock contract
// for Update. If Update matches on the stored version, a pre-incremented
// Version may never match.
func (s *ProductSubscriptionService) UpdateSubscriptionPrice(ctx context.Context, subscriptionID string, newPrice float64) error {
	// newPrice*0 is 0 for every finite value but NaN for NaN and ±Inf, so the
	// second comparison rejects all non-finite inputs.
	if newPrice < 0 || newPrice*0 != 0 {
		return fmt.Errorf("无效的订阅价格: %v", newPrice)
	}

	subscription, err := s.subscriptionRepo.GetByID(ctx, subscriptionID)
	if err != nil {
		return fmt.Errorf("订阅不存在: %w", err)
	}

	subscription.Price = decimal.NewFromFloat(newPrice)
	subscription.Version++

	if err := s.subscriptionRepo.Update(ctx, subscription); err != nil {
		s.logger.Error("更新订阅价格失败", zap.Error(err))
		return fmt.Errorf("更新订阅价格失败: %w", err)
	}

	s.logger.Info("订阅价格更新成功",
		zap.String("subscription_id", subscriptionID),
		zap.Float64("new_price", newPrice))
	return nil
}