This commit is contained in:
2025-07-28 01:46:39 +08:00
parent b03129667a
commit 357639462a
219 changed files with 21634 additions and 8138 deletions

View File

@@ -0,0 +1,238 @@
package api
import (
"context"
"time"
"tyapi-server/internal/domains/api/entities"
"tyapi-server/internal/domains/api/repositories"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ApiCallsTable = "api_calls"
ApiCallCacheTTL = 10 * time.Minute
)
// ApiCallWithProduct 包含产品名称的API调用记录
type ApiCallWithProduct struct {
entities.ApiCall
ProductName string `json:"product_name" gorm:"column:product_name"`
}
type GormApiCallRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ApiCallRepository = (*GormApiCallRepository)(nil)
func NewGormApiCallRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiCallRepository {
return &GormApiCallRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiCallsTable),
}
}
func (r *GormApiCallRepository) Create(ctx context.Context, call *entities.ApiCall) error {
return r.CreateEntity(ctx, call)
}
func (r *GormApiCallRepository) Update(ctx context.Context, call *entities.ApiCall) error {
return r.UpdateEntity(ctx, call)
}
func (r *GormApiCallRepository) FindById(ctx context.Context, id string) (*entities.ApiCall, error) {
var call entities.ApiCall
err := r.SmartGetByID(ctx, id, &call)
if err != nil {
return nil, err
}
return &call, nil
}
func (r *GormApiCallRepository) FindByUserId(ctx context.Context, userId string, limit, offset int) ([]*entities.ApiCall, error) {
var calls []*entities.ApiCall
options := database.CacheListOptions{
Where: "user_id = ?",
Args: []interface{}{userId},
Order: "created_at DESC",
Limit: limit,
Offset: offset,
}
err := r.ListWithCache(ctx, &calls, ApiCallCacheTTL, options)
return calls, err
}
func (r *GormApiCallRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
var calls []*entities.ApiCall
var total int64
// 构建查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 获取总数
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
return calls, total, err
}
func (r *GormApiCallRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
var calls []*entities.ApiCall
var total int64
// 构建基础查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// TransactionID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品ID筛选
if productId, ok := filters["product_id"].(string); ok && productId != "" {
whereCondition += " AND product_id = ?"
whereArgs = append(whereArgs, productId)
}
// 状态筛选
if status, ok := filters["status"].(string); ok && status != "" {
whereCondition += " AND status = ?"
whereArgs = append(whereArgs, status)
}
}
// 获取总数
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
return calls, total, err
}
// ListByUserIdWithFiltersAndProductName 根据用户ID和筛选条件获取API调用记录包含产品名称
func (r *GormApiCallRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
var callsWithProduct []*ApiCallWithProduct
var total int64
// 构建基础查询条件
whereCondition := "ac.user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND ac.created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND ac.created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// TransactionID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND ac.transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereCondition += " AND p.name LIKE ?"
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 状态筛选
if status, ok := filters["status"].(string); ok && status != "" {
whereCondition += " AND ac.status = ?"
whereArgs = append(whereArgs, status)
}
}
// 构建JOIN查询
query := r.GetDB(ctx).Table("api_calls ac").
Select("ac.*, p.name as product_name").
Joins("LEFT JOIN product p ON ac.product_id = p.id").
Where(whereCondition, whereArgs...)
// 获取总数
var count int64
err := query.Count(&count).Error
if err != nil {
return nil, nil, 0, err
}
total = count
// 应用排序和分页
if options.Sort != "" {
query = query.Order("ac." + options.Sort + " " + options.Order)
} else {
query = query.Order("ac.created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
// 执行查询
err = query.Find(&callsWithProduct).Error
if err != nil {
return nil, nil, 0, err
}
// 转换为entities.ApiCall并构建产品名称映射
var calls []*entities.ApiCall
productNameMap := make(map[string]string)
for _, c := range callsWithProduct {
call := c.ApiCall
calls = append(calls, &call)
// 构建产品ID到产品名称的映射
if c.ProductName != "" {
productNameMap[call.ID] = c.ProductName
}
}
return productNameMap, calls, total, nil
}
func (r *GormApiCallRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ?", userId)
}
func (r *GormApiCallRepository) FindByTransactionId(ctx context.Context, transactionId string) (*entities.ApiCall, error) {
var call entities.ApiCall
err := r.FindOne(ctx, &call, "transaction_id = ?", transactionId)
if err != nil {
return nil, err
}
return &call, nil
}

View File

@@ -0,0 +1,56 @@
package api
import (
"context"
"tyapi-server/internal/domains/api/entities"
"tyapi-server/internal/domains/api/repositories"
"tyapi-server/internal/shared/database"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ApiUsersTable = "api_users"
ApiUserCacheTTL = 30 * time.Minute
)
type GormApiUserRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ApiUserRepository = (*GormApiUserRepository)(nil)
func NewGormApiUserRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiUserRepository {
return &GormApiUserRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiUsersTable),
}
}
func (r *GormApiUserRepository) Create(ctx context.Context, user *entities.ApiUser) error {
return r.CreateEntity(ctx, user)
}
func (r *GormApiUserRepository) Update(ctx context.Context, user *entities.ApiUser) error {
return r.UpdateEntity(ctx, user)
}
func (r *GormApiUserRepository) FindByAccessId(ctx context.Context, accessId string) (*entities.ApiUser, error) {
var user entities.ApiUser
err := r.SmartGetByField(ctx, &user, "access_id", accessId, ApiUserCacheTTL)
if err != nil {
return nil, err
}
return &user, nil
}
func (r *GormApiUserRepository) FindByUserId(ctx context.Context, userId string) (*entities.ApiUser, error) {
var user entities.ApiUser
err := r.SmartGetByField(ctx, &user, "user_id", userId, ApiUserCacheTTL)
if err != nil {
return nil, err
}
return &user, nil
}

View File

@@ -123,6 +123,15 @@ func (r *GormCertificationQueryRepository) Exists(ctx context.Context, id string
return r.ExistsEntity(ctx, id, &entities.Certification{})
}
func (r *GormCertificationQueryRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Certification{}).Where("user_id = ?", userID).Count(&count).Error
if err != nil {
return false, fmt.Errorf("查询用户认证是否存在失败: %w", err)
}
return count > 0, nil
}
// ================ 列表查询 ================
// List 分页列表查询
@@ -278,12 +287,12 @@ func (r *GormCertificationQueryRepository) FindByEsignFlowID(ctx context.Context
func (r *GormCertificationQueryRepository) ListPendingRetry(ctx context.Context, maxRetryCount int) ([]*entities.Certification, error) {
var certifications []*entities.Certification
err := r.WithoutCache().GetDB(ctx).
Where("status IN ? AND retry_count < ?",
Where("status IN ? AND retry_count < ?",
[]enums.CertificationStatus{
enums.StatusInfoRejected,
enums.StatusContractRejected,
enums.StatusContractExpired,
},
},
maxRetryCount).
Order("created_at ASC").
Find(&certifications).Error
@@ -367,61 +376,7 @@ func (r *GormCertificationQueryRepository) GetUserActiveCertification(ctx contex
// ================ 统计查询 ================
// GetStatistics 获取统计数据
func (r *GormCertificationQueryRepository) GetStatistics(ctx context.Context, period repositories.CertificationTimePeriod) (*repositories.CertificationStatistics, error) {
now := time.Now()
var startDate time.Time
switch period {
case repositories.PeriodDaily:
startDate = now.AddDate(0, 0, -1)
case repositories.PeriodWeekly:
startDate = now.AddDate(0, 0, -7)
case repositories.PeriodMonthly:
startDate = now.AddDate(0, -1, 0)
case repositories.PeriodYearly:
startDate = now.AddDate(-1, 0, 0)
default:
startDate = now.AddDate(0, 0, -7)
}
// 使用短期缓存进行统计查询
var totalCount int64
r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).
Where("created_at BETWEEN ? AND ?", startDate, now).
Count(&totalCount)
var completedCount int64
r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).
Where("created_at BETWEEN ? AND ? AND status = ?", startDate, now, enums.StatusContractSigned).
Count(&completedCount)
successRate := float64(0)
if totalCount > 0 {
successRate = float64(completedCount) / float64(totalCount)
}
return &repositories.CertificationStatistics{
Period: period,
StartDate: startDate,
EndDate: now,
TotalCertifications: totalCount,
CompletedCount: completedCount,
SuccessRate: successRate,
StatusDistribution: make(map[enums.CertificationStatus]int64),
FailureDistribution: make(map[enums.FailureReason]int64),
AvgProcessingTime: 24 * time.Hour, // 简化实现
}, nil
}
// CountByStatus 按状态统计
func (r *GormCertificationQueryRepository) CountByStatus(ctx context.Context, status enums.CertificationStatus) (int64, error) {
var count int64
if err := r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).Where("status = ?", status).Count(&count).Error; err != nil {
return 0, fmt.Errorf("按状态统计认证失败: %w", err)
}
return count, nil
}
// CountByFailureReason 按失败原因统计
func (r *GormCertificationQueryRepository) CountByFailureReason(ctx context.Context, reason enums.FailureReason) (int64, error) {
@@ -511,4 +466,4 @@ func (r *GormCertificationQueryRepository) GetCacheStats() map[string]interface{
QueryCachePatternUser,
}
return stats
}
}

View File

@@ -0,0 +1,64 @@
package certification
import (
"context"
"tyapi-server/internal/domains/certification/entities"
"tyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
EnterpriseInfoSubmitRecordsTable = "enterprise_info_submit_records"
)
type GormEnterpriseInfoSubmitRecordRepository struct {
*database.CachedBaseRepositoryImpl
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.EnterpriseInfoSubmitRecord{})
}
func NewGormEnterpriseInfoSubmitRecordRepository(db *gorm.DB, logger *zap.Logger) *GormEnterpriseInfoSubmitRecordRepository {
return &GormEnterpriseInfoSubmitRecordRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, EnterpriseInfoSubmitRecordsTable),
}
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Create(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
return r.CreateEntity(ctx, record)
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Update(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
return r.UpdateEntity(ctx, record)
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Exists(ctx context.Context, ID string) (bool, error) {
return r.ExistsEntity(ctx, ID, &entities.EnterpriseInfoSubmitRecord{})
}
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
var record entities.EnterpriseInfoSubmitRecord
err := r.GetDB(ctx).
Where("user_id = ?", userID).
Order("submit_at DESC").
First(&record).Error
if err != nil {
return nil, err
}
return &record, nil
}
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestVerifiedByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
var record entities.EnterpriseInfoSubmitRecord
err := r.GetDB(ctx).
Where("user_id = ? AND status = ?", userID, "verified").
Order("verified_at DESC").
First(&record).Error
if err != nil {
return nil, err
}
return &record, nil
}

View File

@@ -0,0 +1,98 @@
package repositories
import (
"context"
"errors"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
AlipayOrdersTable = "alipay_orders"
)
type GormAlipayOrderRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.AlipayOrderRepository = (*GormAlipayOrderRepository)(nil)
func NewGormAlipayOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.AlipayOrderRepository {
return &GormAlipayOrderRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, AlipayOrdersTable),
}
}
func (r *GormAlipayOrderRepository) Create(ctx context.Context, order entities.AlipayOrder) (entities.AlipayOrder, error) {
err := r.CreateEntity(ctx, &order)
return order, err
}
func (r *GormAlipayOrderRepository) GetByID(ctx context.Context, id string) (entities.AlipayOrder, error) {
var order entities.AlipayOrder
err := r.SmartGetByID(ctx, id, &order)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.AlipayOrder{}, gorm.ErrRecordNotFound
}
return entities.AlipayOrder{}, err
}
return order, nil
}
func (r *GormAlipayOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.AlipayOrder, error) {
var order entities.AlipayOrder
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &order, nil
}
func (r *GormAlipayOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.AlipayOrder, error) {
var order entities.AlipayOrder
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &order, nil
}
func (r *GormAlipayOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.AlipayOrder, error) {
var orders []entities.AlipayOrder
err := r.GetDB(ctx).
Joins("JOIN recharge_records ON alipay_orders.recharge_id = recharge_records.id").
Where("recharge_records.user_id = ?", userID).
Order("alipay_orders.created_at DESC").
Find(&orders).Error
return orders, err
}
func (r *GormAlipayOrderRepository) Update(ctx context.Context, order entities.AlipayOrder) error {
return r.UpdateEntity(ctx, &order)
}
func (r *GormAlipayOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.AlipayOrderStatus) error {
return r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Update("status", status).Error
}
func (r *GormAlipayOrderRepository) Delete(ctx context.Context, id string) error {
return r.GetDB(ctx).Delete(&entities.AlipayOrder{}, "id = ?", id).Error
}
func (r *GormAlipayOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}

View File

@@ -1,697 +0,0 @@
package repositories
import (
"context"
"errors"
"time"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/domains/finance/repositories/queries"
"tyapi-server/internal/shared/interfaces"
)
// GormWalletRepository 钱包GORM仓储实现
type GormWalletRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
// NewGormWalletRepository 创建钱包GORM仓储
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
return &GormWalletRepository{
db: db,
logger: logger,
}
}
// Create 创建钱包
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
r.logger.Info("创建钱包", zap.String("user_id", wallet.UserID))
err := r.db.WithContext(ctx).Create(&wallet).Error
return wallet, err
}
// GetByID 根据ID获取钱包
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("id = ?", id).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.Wallet{}, gorm.ErrRecordNotFound
}
return entities.Wallet{}, err
}
return wallet, err
}
// Update 更新钱包
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
r.logger.Info("更新钱包", zap.String("id", wallet.ID))
return r.db.WithContext(ctx).Save(&wallet).Error
}
// Delete 删除钱包
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除钱包", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
}
// SoftDelete 软删除钱包
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除钱包", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
}
// Restore 恢复钱包
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复钱包", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.Wallet{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计钱包数量
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查钱包是否存在
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建钱包
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
r.logger.Info("批量创建钱包", zap.Int("count", len(wallets)))
return r.db.WithContext(ctx).Create(&wallets).Error
}
// GetByIDs 根据ID列表获取钱包
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
var wallets []entities.Wallet
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&wallets).Error
return wallets, err
}
// UpdateBatch 批量更新钱包
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
r.logger.Info("批量更新钱包", zap.Int("count", len(wallets)))
return r.db.WithContext(ctx).Save(&wallets).Error
}
// DeleteBatch 批量删除钱包
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除钱包", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
}
// List 获取钱包列表
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
var wallets []entities.Wallet
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return wallets, query.Find(&wallets).Error
}
// WithTx 使用事务
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormWalletRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// FindByUserID 根据用户ID查找钱包
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
return nil, err
}
return &wallet, nil
}
// ExistsByUserID 检查用户钱包是否存在
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
return count > 0, err
}
// GetTotalBalance 获取总余额
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
var total decimal.Decimal
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
return total, err
}
// GetActiveWalletCount 获取激活钱包数量
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
return count, err
}
// ================ 接口要求的方法 ================
// GetByUserID 根据用户ID获取钱包
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &wallet, nil
}
// GetByWalletAddress 根据钱包地址获取钱包
func (r *GormWalletRepository) GetByWalletAddress(ctx context.Context, walletAddress string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("wallet_address = ?", walletAddress).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &wallet, nil
}
// GetByWalletType 根据钱包类型获取钱包
func (r *GormWalletRepository) GetByWalletType(ctx context.Context, userID string, walletType string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.db.WithContext(ctx).Where("user_id = ? AND wallet_type = ?", userID, walletType).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &wallet, nil
}
// ListWallets 获取钱包列表(带分页和筛选)
func (r *GormWalletRepository) ListWallets(ctx context.Context, query *queries.ListWalletsQuery) ([]*entities.Wallet, int64, error) {
var wallets []entities.Wallet
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Wallet{})
// 应用筛选条件
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.WalletType != "" {
dbQuery = dbQuery.Where("wallet_type = ?", query.WalletType)
}
if query.WalletAddress != "" {
dbQuery = dbQuery.Where("wallet_address LIKE ?", "%"+query.WalletAddress+"%")
}
if query.IsActive != nil {
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&wallets).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
walletPtrs := make([]*entities.Wallet, len(wallets))
for i := range wallets {
walletPtrs[i] = &wallets[i]
}
return walletPtrs, total, nil
}
// UpdateBalance 更新钱包余额
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
}
// AddBalance 增加钱包余额
func (r *GormWalletRepository) AddBalance(ctx context.Context, walletID string, amount string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance + ?", amount)).Error
}
// SubtractBalance 减少钱包余额
func (r *GormWalletRepository) SubtractBalance(ctx context.Context, walletID string, amount string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance - ?", amount)).Error
}
// ActivateWallet 激活钱包
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
}
// DeactivateWallet 停用钱包
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
}
// GetStats 获取财务统计信息
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 总钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 激活钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 总余额
var totalBalance decimal.Decimal
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}
// GetUserWalletStats 获取用户钱包统计信息
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 用户钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 用户激活钱包数
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 用户总余额
var totalBalance decimal.Decimal
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 用户今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}
// GormUserSecretsRepository 用户密钥GORM仓储实现
type GormUserSecretsRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ domain_finance_repo.UserSecretsRepository = (*GormUserSecretsRepository)(nil)
// NewGormUserSecretsRepository 创建用户密钥GORM仓储
func NewGormUserSecretsRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.UserSecretsRepository {
return &GormUserSecretsRepository{
db: db,
logger: logger,
}
}
// Create 创建用户密钥
func (r *GormUserSecretsRepository) Create(ctx context.Context, secrets entities.UserSecrets) (entities.UserSecrets, error) {
r.logger.Info("创建用户密钥", zap.String("user_id", secrets.UserID))
err := r.db.WithContext(ctx).Create(&secrets).Error
return secrets, err
}
// GetByID 根据ID获取用户密钥
func (r *GormUserSecretsRepository) GetByID(ctx context.Context, id string) (entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("id = ?", id).First(&secrets).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.UserSecrets{}, gorm.ErrRecordNotFound
}
return entities.UserSecrets{}, err
}
return secrets, err
}
// Update 更新用户密钥
func (r *GormUserSecretsRepository) Update(ctx context.Context, secrets entities.UserSecrets) error {
r.logger.Info("更新用户密钥", zap.String("id", secrets.ID))
return r.db.WithContext(ctx).Save(&secrets).Error
}
// Delete 删除用户密钥
func (r *GormUserSecretsRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除用户密钥", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
}
// SoftDelete 软删除用户密钥
func (r *GormUserSecretsRepository) SoftDelete(ctx context.Context, id string) error {
r.logger.Info("软删除用户密钥", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
}
// Restore 恢复用户密钥
func (r *GormUserSecretsRepository) Restore(ctx context.Context, id string) error {
r.logger.Info("恢复用户密钥", zap.String("id", id))
return r.db.WithContext(ctx).Unscoped().Model(&entities.UserSecrets{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// Count 统计用户密钥数量
func (r *GormUserSecretsRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
// Exists 检查用户密钥是否存在
func (r *GormUserSecretsRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建用户密钥
func (r *GormUserSecretsRepository) CreateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
r.logger.Info("批量创建用户密钥", zap.Int("count", len(secrets)))
return r.db.WithContext(ctx).Create(&secrets).Error
}
// GetByIDs 根据ID列表获取用户密钥
func (r *GormUserSecretsRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.UserSecrets, error) {
var secrets []entities.UserSecrets
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&secrets).Error
return secrets, err
}
// UpdateBatch 批量更新用户密钥
func (r *GormUserSecretsRepository) UpdateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
r.logger.Info("批量更新用户密钥", zap.Int("count", len(secrets)))
return r.db.WithContext(ctx).Save(&secrets).Error
}
// DeleteBatch 批量删除用户密钥
func (r *GormUserSecretsRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除用户密钥", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id IN ?", ids).Error
}
// List 获取用户密钥列表
func (r *GormUserSecretsRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.UserSecrets, error) {
var secrets []entities.UserSecrets
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return secrets, query.Find(&secrets).Error
}
// WithTx 使用事务
func (r *GormUserSecretsRepository) WithTx(tx interface{}) interfaces.Repository[entities.UserSecrets] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormUserSecretsRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}
// FindByUserID 根据用户ID查找用户密钥
func (r *GormUserSecretsRepository) FindByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &secrets, nil
}
// FindByAccessID 根据访问ID查找用户密钥
func (r *GormUserSecretsRepository) FindByAccessID(ctx context.Context, accessID string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("access_id = ?", accessID).First(&secrets).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &secrets, nil
}
// ExistsByUserID 检查用户密钥是否存在
func (r *GormUserSecretsRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Count(&count).Error
return count > 0, err
}
// ExistsByAccessID 检查访问ID是否存在
func (r *GormUserSecretsRepository) ExistsByAccessID(ctx context.Context, accessID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Count(&count).Error
return count > 0, err
}
// UpdateLastUsedAt 更新最后使用时间
func (r *GormUserSecretsRepository) UpdateLastUsedAt(ctx context.Context, accessID string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Update("last_used_at", time.Now()).Error
}
// DeactivateByUserID 停用用户密钥
func (r *GormUserSecretsRepository) DeactivateByUserID(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Update("is_active", false).Error
}
// RegenerateAccessKey 重新生成访问密钥
func (r *GormUserSecretsRepository) RegenerateAccessKey(ctx context.Context, userID string, accessID, accessKey string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
"access_id": accessID,
"access_key": accessKey,
"updated_at": time.Now(),
}).Error
}
// GetExpiredSecrets 获取过期的密钥
func (r *GormUserSecretsRepository) GetExpiredSecrets(ctx context.Context) ([]entities.UserSecrets, error) {
var secrets []entities.UserSecrets
err := r.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Find(&secrets).Error
return secrets, err
}
// DeleteExpiredSecrets 删除过期的密钥
func (r *GormUserSecretsRepository) DeleteExpiredSecrets(ctx context.Context) error {
return r.db.WithContext(ctx).Where("expires_at < ?", time.Now()).Delete(&entities.UserSecrets{}).Error
}
// ================ 接口要求的方法 ================
// GetByUserID 根据用户ID获取用户密钥
func (r *GormUserSecretsRepository) GetByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &secrets, nil
}
// GetBySecretType 根据用户ID和密钥类型获取用户密钥
func (r *GormUserSecretsRepository) GetBySecretType(ctx context.Context, userID string, secretType string) (*entities.UserSecrets, error) {
var secrets entities.UserSecrets
err := r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).First(&secrets).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &secrets, nil
}
// ListUserSecrets 获取用户密钥列表(带分页和筛选)
func (r *GormUserSecretsRepository) ListUserSecrets(ctx context.Context, query *queries.ListUserSecretsQuery) ([]*entities.UserSecrets, int64, error) {
var secrets []entities.UserSecrets
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
// 应用筛选条件
if query.UserID != "" {
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
}
if query.SecretType != "" {
dbQuery = dbQuery.Where("secret_type = ?", query.SecretType)
}
if query.IsActive != nil {
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
}
if query.StartDate != "" {
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
// 默认排序
dbQuery = dbQuery.Order("created_at DESC")
// 查询数据
if err := dbQuery.Find(&secrets).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
secretPtrs := make([]*entities.UserSecrets, len(secrets))
for i := range secrets {
secretPtrs[i] = &secrets[i]
}
return secretPtrs, total, nil
}
// UpdateSecret 更新密钥
func (r *GormUserSecretsRepository) UpdateSecret(ctx context.Context, userID string, secretType string, secretValue string) error {
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
Where("user_id = ? AND secret_type = ?", userID, secretType).
Update("secret_value", secretValue).Error
}
// DeleteSecret 删除密钥
func (r *GormUserSecretsRepository) DeleteSecret(ctx context.Context, userID string, secretType string) error {
return r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).
Delete(&entities.UserSecrets{}).Error
}
// ValidateSecret 验证密钥
func (r *GormUserSecretsRepository) ValidateSecret(ctx context.Context, userID string, secretType string, secretValue string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
Where("user_id = ? AND secret_type = ? AND secret_value = ?", userID, secretType, secretValue).
Count(&count).Error
return count > 0, err
}

View File

@@ -0,0 +1,178 @@
package repositories
import (
"context"
"errors"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
RechargeRecordsTable = "recharge_records"
)
type GormRechargeRecordRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.RechargeRecordRepository = (*GormRechargeRecordRepository)(nil)
func NewGormRechargeRecordRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.RechargeRecordRepository {
return &GormRechargeRecordRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, RechargeRecordsTable),
}
}
func (r *GormRechargeRecordRepository) Create(ctx context.Context, record entities.RechargeRecord) (entities.RechargeRecord, error) {
err := r.CreateEntity(ctx, &record)
return record, err
}
func (r *GormRechargeRecordRepository) GetByID(ctx context.Context, id string) (entities.RechargeRecord, error) {
var record entities.RechargeRecord
err := r.SmartGetByID(ctx, id, &record)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.RechargeRecord{}, gorm.ErrRecordNotFound
}
return entities.RechargeRecord{}, err
}
return record, nil
}
func (r *GormRechargeRecordRepository) GetByUserID(ctx context.Context, userID string) ([]entities.RechargeRecord, error) {
var records []entities.RechargeRecord
err := r.GetDB(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&records).Error
return records, err
}
func (r *GormRechargeRecordRepository) GetByAlipayOrderID(ctx context.Context, alipayOrderID string) (*entities.RechargeRecord, error) {
var record entities.RechargeRecord
err := r.GetDB(ctx).Where("alipay_order_id = ?", alipayOrderID).First(&record).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &record, nil
}
func (r *GormRechargeRecordRepository) GetByTransferOrderID(ctx context.Context, transferOrderID string) (*entities.RechargeRecord, error) {
var record entities.RechargeRecord
err := r.GetDB(ctx).Where("transfer_order_id = ?", transferOrderID).First(&record).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &record, nil
}
func (r *GormRechargeRecordRepository) Update(ctx context.Context, record entities.RechargeRecord) error {
return r.UpdateEntity(ctx, &record)
}
func (r *GormRechargeRecordRepository) UpdateStatus(ctx context.Context, id string, status entities.RechargeStatus) error {
return r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Update("status", status).Error
}
func (r *GormRechargeRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.RechargeRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
func (r *GormRechargeRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
func (r *GormRechargeRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.RechargeRecord, error) {
var records []entities.RechargeRecord
query := r.GetDB(ctx).Model(&entities.RechargeRecord{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
query = query.Order("created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&records).Error
return records, err
}
func (r *GormRechargeRecordRepository) CreateBatch(ctx context.Context, records []entities.RechargeRecord) error {
return r.GetDB(ctx).Create(&records).Error
}
func (r *GormRechargeRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.RechargeRecord, error) {
var records []entities.RechargeRecord
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&records).Error
return records, err
}
func (r *GormRechargeRecordRepository) UpdateBatch(ctx context.Context, records []entities.RechargeRecord) error {
return r.GetDB(ctx).Save(&records).Error
}
func (r *GormRechargeRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).Delete(&entities.RechargeRecord{}, "id IN ?", ids).Error
}
func (r *GormRechargeRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.RechargeRecord] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormRechargeRecordRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), RechargeRecordsTable),
}
}
return r
}
func (r *GormRechargeRecordRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.RechargeRecord{})
}
func (r *GormRechargeRecordRepository) SoftDelete(ctx context.Context, id string) error {
return r.SoftDeleteEntity(ctx, id, &entities.RechargeRecord{})
}
func (r *GormRechargeRecordRepository) Restore(ctx context.Context, id string) error {
return r.RestoreEntity(ctx, id, &entities.RechargeRecord{})
}

View File

@@ -0,0 +1,273 @@
package repositories
import (
"context"
"errors"
"fmt"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
WalletsTable = "wallets"
)
type GormWalletRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
return &GormWalletRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletsTable),
}
}
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
err := r.CreateEntity(ctx, &wallet)
return wallet, err
}
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
var wallet entities.Wallet
err := r.SmartGetByID(ctx, id, &wallet)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.Wallet{}, gorm.ErrRecordNotFound
}
return entities.Wallet{}, err
}
return wallet, nil
}
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
return r.UpdateEntity(ctx, &wallet)
}
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.Wallet{})
}
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
return r.SoftDeleteEntity(ctx, id, &entities.Wallet{})
}
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
return r.RestoreEntity(ctx, id, &entities.Wallet{})
}
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
return r.GetDB(ctx).Create(&wallets).Error
}
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
var wallets []entities.Wallet
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&wallets).Error
return wallets, err
}
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
return r.GetDB(ctx).Save(&wallets).Error
}
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
}
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
var wallets []entities.Wallet
query := r.GetDB(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
if options.Sort != "" {
order := options.Sort
if options.Order == "desc" {
order += " DESC"
} else {
order += " ASC"
}
query = query.Order(order)
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return wallets, query.Find(&wallets).Error
}
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormWalletRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletsTable),
}
}
return r
}
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &wallet, nil
}
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
return count > 0, err
}
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
var total decimal.Decimal
err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
return total, err
}
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
return count, err
}
// ================ 接口要求的方法 ================
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &wallet, nil
}
// UpdateBalanceWithVersionRetry 乐观锁自动重试最大重试maxRetry次
func (r *GormWalletRepository) UpdateBalanceWithVersion(ctx context.Context, walletID string, newBalance string, oldVersion int64) (bool, error) {
maxRetry := 10
for i := 0; i < maxRetry; i++ {
result := r.GetDB(ctx).Model(&entities.Wallet{}).
Where("id = ? AND version = ?", walletID, oldVersion).
Updates(map[string]interface{}{
"balance": newBalance,
"version": oldVersion + 1,
})
if result.Error != nil {
return false, result.Error
}
if result.RowsAffected == 1 {
return true, nil
}
// 并发冲突重试前重新查version
var wallet entities.Wallet
err := r.GetDB(ctx).Where("id = ?", walletID).First(&wallet).Error
if err != nil {
return false, err
}
oldVersion = wallet.Version
}
return false, fmt.Errorf("高并发下余额变动失败,请重试")
}
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
}
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
}
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
}
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 总钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 激活钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 总余额
var totalBalance decimal.Decimal
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 用户钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 用户激活钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 用户总余额
var totalBalance decimal.Decimal
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 用户今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}

View File

@@ -0,0 +1,296 @@
package repositories
import (
"context"
"time"
"tyapi-server/internal/domains/finance/entities"
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// WalletTransactionWithProduct 包含产品名称的钱包交易记录
type WalletTransactionWithProduct struct {
entities.WalletTransaction
ProductName string `json:"product_name" gorm:"column:product_name"`
}
const (
WalletTransactionsTable = "wallet_transactions"
)
type GormWalletTransactionRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.WalletTransactionRepository = (*GormWalletTransactionRepository)(nil)
func NewGormWalletTransactionRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletTransactionRepository {
return &GormWalletTransactionRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletTransactionsTable),
}
}
func (r *GormWalletTransactionRepository) Create(ctx context.Context, transaction entities.WalletTransaction) (entities.WalletTransaction, error) {
err := r.CreateEntity(ctx, &transaction)
return transaction, err
}
func (r *GormWalletTransactionRepository) Update(ctx context.Context, transaction entities.WalletTransaction) error {
return r.UpdateEntity(ctx, &transaction)
}
func (r *GormWalletTransactionRepository) GetByID(ctx context.Context, id string) (entities.WalletTransaction, error) {
var transaction entities.WalletTransaction
err := r.SmartGetByID(ctx, id, &transaction)
return transaction, err
}
func (r *GormWalletTransactionRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.WalletTransaction, error) {
var transactions []*entities.WalletTransaction
options := database.CacheListOptions{
Where: "user_id = ?",
Args: []interface{}{userID},
Order: "created_at DESC",
Limit: limit,
Offset: offset,
}
err := r.ListWithCache(ctx, &transactions, 10*time.Minute, options)
return transactions, err
}
func (r *GormWalletTransactionRepository) GetByApiCallID(ctx context.Context, apiCallID string) (*entities.WalletTransaction, error) {
var transaction entities.WalletTransaction
err := r.FindOne(ctx, &transaction, "api_call_id = ?", apiCallID)
if err != nil {
return nil, err
}
return &transaction, nil
}
func (r *GormWalletTransactionRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
var transactions []*entities.WalletTransaction
var total int64
// 构建查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 获取总数
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
return transactions, total, err
}
func (r *GormWalletTransactionRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
var transactions []*entities.WalletTransaction
var total int64
// 构建基础查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// 关键词筛选支持transaction_id和product_name
if keyword, ok := filters["keyword"].(string); ok && keyword != "" {
whereCondition += " AND (transaction_id LIKE ? OR product_id IN (SELECT id FROM product WHERE name LIKE ?))"
whereArgs = append(whereArgs, "%"+keyword+"%", "%"+keyword+"%")
}
// API调用ID筛选
if apiCallId, ok := filters["api_call_id"].(string); ok && apiCallId != "" {
whereCondition += " AND api_call_id LIKE ?"
whereArgs = append(whereArgs, "%"+apiCallId+"%")
}
// 金额范围筛选
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
whereCondition += " AND amount >= ?"
whereArgs = append(whereArgs, minAmount)
}
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
whereCondition += " AND amount <= ?"
whereArgs = append(whereArgs, maxAmount)
}
}
// 获取总数
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
return transactions, total, err
}
func (r *GormWalletTransactionRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ?", userId)
}
// 实现interfaces.Repository接口的其他方法
func (r *GormWalletTransactionRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) Exists(ctx context.Context, id string) (bool, error) {
return r.ExistsEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.WalletTransaction, error) {
var transactions []entities.WalletTransaction
err := r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
return transactions, err
}
func (r *GormWalletTransactionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
return r.CountWithOptions(ctx, &entities.WalletTransaction{}, options)
}
func (r *GormWalletTransactionRepository) CreateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
return r.CreateBatchEntity(ctx, &transactions)
}
func (r *GormWalletTransactionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.WalletTransaction, error) {
var transactions []entities.WalletTransaction
err := r.GetEntitiesByIDs(ctx, ids, &transactions)
return transactions, err
}
func (r *GormWalletTransactionRepository) UpdateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
return r.UpdateBatchEntity(ctx, &transactions)
}
func (r *GormWalletTransactionRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.DeleteBatchEntity(ctx, ids, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) WithTx(tx interface{}) interfaces.Repository[entities.WalletTransaction] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormWalletTransactionRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletTransactionsTable),
}
}
return r
}
func (r *GormWalletTransactionRepository) SoftDelete(ctx context.Context, id string) error {
return r.SoftDeleteEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) Restore(ctx context.Context, id string) error {
return r.RestoreEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
var transactionsWithProduct []*WalletTransactionWithProduct
var total int64
// 构建基础查询条件
whereCondition := "wt.user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND wt.created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND wt.created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// 交易ID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND wt.transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereCondition += " AND p.name LIKE ?"
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 金额范围筛选
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
whereCondition += " AND wt.amount >= ?"
whereArgs = append(whereArgs, minAmount)
}
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
whereCondition += " AND wt.amount <= ?"
whereArgs = append(whereArgs, maxAmount)
}
}
// 构建JOIN查询
query := r.GetDB(ctx).Table("wallet_transactions wt").
Select("wt.*, p.name as product_name").
Joins("LEFT JOIN product p ON wt.product_id = p.id").
Where(whereCondition, whereArgs...)
// 获取总数
var count int64
err := query.Count(&count).Error
if err != nil {
return nil, nil, 0, err
}
total = count
// 应用排序和分页
if options.Sort != "" {
query = query.Order("wt." + options.Sort + " " + options.Order)
} else {
query = query.Order("wt.created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
// 执行查询
err = query.Find(&transactionsWithProduct).Error
if err != nil {
return nil, nil, 0, err
}
// 转换为entities.WalletTransaction并构建产品名称映射
var transactions []*entities.WalletTransaction
productNameMap := make(map[string]string)
for _, t := range transactionsWithProduct {
transaction := t.WalletTransaction
transactions = append(transactions, &transaction)
// 构建产品ID到产品名称的映射
if t.ProductName != "" {
productNameMap[transaction.ProductID] = t.ProductName
}
}
return productNameMap, transactions, total, nil
}

View File

@@ -0,0 +1,92 @@
package repositories
import (
"context"
"errors"
"time"
"tyapi-server/internal/domains/product/entities"
"tyapi-server/internal/domains/product/repositories"
"tyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ProductApiConfigsTable = "product_api_configs"
ProductApiConfigCacheTTL = 30 * time.Minute
)
type GormProductApiConfigRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ProductApiConfigRepository = (*GormProductApiConfigRepository)(nil)
func NewGormProductApiConfigRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductApiConfigRepository {
return &GormProductApiConfigRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductApiConfigsTable),
}
}
func (r *GormProductApiConfigRepository) Create(ctx context.Context, config entities.ProductApiConfig) error {
return r.CreateEntity(ctx, &config)
}
func (r *GormProductApiConfigRepository) Update(ctx context.Context, config entities.ProductApiConfig) error {
return r.UpdateEntity(ctx, &config)
}
func (r *GormProductApiConfigRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.ProductApiConfig{})
}
func (r *GormProductApiConfigRepository) GetByID(ctx context.Context, id string) (*entities.ProductApiConfig, error) {
var config entities.ProductApiConfig
err := r.SmartGetByID(ctx, id, &config)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &config, nil
}
func (r *GormProductApiConfigRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductApiConfig, error) {
var config entities.ProductApiConfig
err := r.SmartGetByField(ctx, &config, "product_id", productID, ProductApiConfigCacheTTL)
if err != nil {
return nil, err
}
return &config, nil
}
func (r *GormProductApiConfigRepository) FindByProductCode(ctx context.Context, productCode string) (*entities.ProductApiConfig, error) {
var config entities.ProductApiConfig
err := r.GetDB(ctx).Joins("JOIN products ON products.id = product_api_configs.product_id").
Where("products.code = ?", productCode).
First(&config).Error
if err != nil {
return nil, err
}
return &config, nil
}
func (r *GormProductApiConfigRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductApiConfig, error) {
var configs []*entities.ProductApiConfig
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&configs).Error
if err != nil {
return nil, err
}
return configs, nil
}
func (r *GormProductApiConfigRepository) ExistsByProductID(ctx context.Context, productID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ProductApiConfig{}).Where("product_id = ?", productID).Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}

View File

@@ -3,44 +3,44 @@ package repositories
import (
"context"
"errors"
"go.uber.org/zap"
"gorm.io/gorm"
"tyapi-server/internal/domains/product/entities"
"tyapi-server/internal/domains/product/repositories"
"tyapi-server/internal/domains/product/repositories/queries"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ProductCategoriesTable = "product_categories"
)
// GormProductCategoryRepository GORM产品分类仓储实现
type GormProductCategoryRepository struct {
db *gorm.DB
logger *zap.Logger
*database.CachedBaseRepositoryImpl
}
func (r *GormProductCategoryRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.ProductCategory{})
}
// 编译时检查接口实现
var _ repositories.ProductCategoryRepository = (*GormProductCategoryRepository)(nil)
// NewGormProductCategoryRepository 创建GORM产品分类仓储
func NewGormProductCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductCategoryRepository {
return &GormProductCategoryRepository{
db: db,
logger: logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductCategoriesTable),
}
}
// Create 创建产品分类
func (r *GormProductCategoryRepository) Create(ctx context.Context, entity entities.ProductCategory) (entities.ProductCategory, error) {
r.logger.Info("创建产品分类", zap.String("id", entity.ID), zap.String("name", entity.Name))
err := r.db.WithContext(ctx).Create(&entity).Error
err := r.CreateEntity(ctx, &entity)
return entity, err
}
// GetByID 根据ID获取产品分类
func (r *GormProductCategoryRepository) GetByID(ctx context.Context, id string) (entities.ProductCategory, error) {
var entity entities.ProductCategory
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.ProductCategory{}, gorm.ErrRecordNotFound
@@ -50,22 +50,14 @@ func (r *GormProductCategoryRepository) GetByID(ctx context.Context, id string)
return entity, nil
}
// Update 更新产品分类
func (r *GormProductCategoryRepository) Update(ctx context.Context, entity entities.ProductCategory) error {
r.logger.Info("更新产品分类", zap.String("id", entity.ID))
return r.db.WithContext(ctx).Save(&entity).Error
}
// Delete 删除产品分类
func (r *GormProductCategoryRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除产品分类", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
return r.UpdateEntity(ctx, &entity)
}
// FindByCode 根据编号查找产品分类
func (r *GormProductCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductCategory, error) {
var entity entities.ProductCategory
err := r.db.WithContext(ctx).Where("code = ?", code).First(&entity).Error
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
@@ -78,7 +70,7 @@ func (r *GormProductCategoryRepository) FindByCode(ctx context.Context, code str
// FindVisible 查找可见分类
func (r *GormProductCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductCategory, error) {
var categories []entities.ProductCategory
err := r.db.WithContext(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Find(&categories).Error
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Find(&categories).Error
if err != nil {
return nil, err
}
@@ -94,7 +86,7 @@ func (r *GormProductCategoryRepository) FindVisible(ctx context.Context) ([]*ent
// FindEnabled 查找启用分类
func (r *GormProductCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductCategory, error) {
var categories []entities.ProductCategory
err := r.db.WithContext(ctx).Where("is_enabled = ?", true).Find(&categories).Error
err := r.GetDB(ctx).Where("is_enabled = ?", true).Find(&categories).Error
if err != nil {
return nil, err
}
@@ -112,7 +104,7 @@ func (r *GormProductCategoryRepository) ListCategories(ctx context.Context, quer
var categories []entities.ProductCategory
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.ProductCategory{})
dbQuery := r.GetDB(ctx).Model(&entities.ProductCategory{})
// 应用筛选条件
if query.IsEnabled != nil {
@@ -164,14 +156,14 @@ func (r *GormProductCategoryRepository) ListCategories(ctx context.Context, quer
// CountEnabled 统计启用分类数量
func (r *GormProductCategoryRepository) CountEnabled(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.ProductCategory{}).Where("is_enabled = ?", true).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_enabled = ?", true).Count(&count).Error
return count, err
}
// CountVisible 统计可见分类数量
func (r *GormProductCategoryRepository) CountVisible(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.ProductCategory{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
return count, err
}
@@ -180,7 +172,7 @@ func (r *GormProductCategoryRepository) CountVisible(ctx context.Context) (int64
// Count 返回分类总数
func (r *GormProductCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.ProductCategory{})
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
// 应用筛选条件
if options.Filters != nil {
@@ -201,29 +193,29 @@ func (r *GormProductCategoryRepository) Count(ctx context.Context, options inter
// GetByIDs 根据ID列表获取分类
func (r *GormProductCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.ProductCategory, error) {
var categories []entities.ProductCategory
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&categories).Error
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&categories).Error
return categories, err
}
// CreateBatch 批量创建分类
func (r *GormProductCategoryRepository) CreateBatch(ctx context.Context, categories []entities.ProductCategory) error {
return r.db.WithContext(ctx).Create(&categories).Error
return r.GetDB(ctx).Create(&categories).Error
}
// UpdateBatch 批量更新分类
func (r *GormProductCategoryRepository) UpdateBatch(ctx context.Context, categories []entities.ProductCategory) error {
return r.db.WithContext(ctx).Save(&categories).Error
return r.GetDB(ctx).Save(&categories).Error
}
// DeleteBatch 批量删除分类
func (r *GormProductCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.ProductCategory{}, "id IN ?", ids).Error
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id IN ?", ids).Error
}
// List 获取分类列表(基础方法)
func (r *GormProductCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.ProductCategory, error) {
var categories []entities.ProductCategory
query := r.db.WithContext(ctx).Model(&entities.ProductCategory{})
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
// 应用筛选条件
if options.Filters != nil {
@@ -261,26 +253,25 @@ func (r *GormProductCategoryRepository) List(ctx context.Context, options interf
// Exists 检查分类是否存在
func (r *GormProductCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.ProductCategory{}).Where("id = ?", id).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// SoftDelete 软删除分类
func (r *GormProductCategoryRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
}
// Restore 恢复软删除的分类
func (r *GormProductCategoryRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.ProductCategory{}).Where("id = ?", id).Update("deleted_at", nil).Error
return r.GetDB(ctx).Unscoped().Model(&entities.ProductCategory{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// WithTx 使用事务
func (r *GormProductCategoryRepository) WithTx(tx interface{}) interfaces.Repository[entities.ProductCategory] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormProductCategoryRepository{
db: gormTx,
logger: r.logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductCategoriesTable),
}
}
return r

View File

@@ -6,40 +6,41 @@ import (
"tyapi-server/internal/domains/product/entities"
"tyapi-server/internal/domains/product/repositories"
"tyapi-server/internal/domains/product/repositories/queries"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// GormProductRepository GORM产品仓储实现
const (
ProductsTable = "products"
)
type GormProductRepository struct {
db *gorm.DB
logger *zap.Logger
*database.CachedBaseRepositoryImpl
}
func (r *GormProductRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.Product{})
}
// 编译时检查接口实现
var _ repositories.ProductRepository = (*GormProductRepository)(nil)
// NewGormProductRepository 创建GORM产品仓储
func NewGormProductRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductRepository {
return &GormProductRepository{
db: db,
logger: logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductsTable),
}
}
// Create 创建产品
func (r *GormProductRepository) Create(ctx context.Context, entity entities.Product) (entities.Product, error) {
r.logger.Info("创建产品", zap.String("id", entity.ID), zap.String("name", entity.Name))
err := r.db.WithContext(ctx).Create(&entity).Error
err := r.CreateEntity(ctx, &entity)
return entity, err
}
// GetByID 根据ID获取产品
func (r *GormProductRepository) GetByID(ctx context.Context, id string) (entities.Product, error) {
var entity entities.Product
err := r.db.WithContext(ctx).Preload("Category").Where("id = ?", id).First(&entity).Error
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.Product{}, gorm.ErrRecordNotFound
@@ -49,26 +50,17 @@ func (r *GormProductRepository) GetByID(ctx context.Context, id string) (entitie
return entity, nil
}
// Update 更新产品
func (r *GormProductRepository) Update(ctx context.Context, entity entities.Product) error {
r.logger.Info("更新产品", zap.String("id", entity.ID))
return r.db.WithContext(ctx).Save(&entity).Error
return r.UpdateEntity(ctx, &entity)
}
// Delete 删除产品
func (r *GormProductRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除产品", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Product{}, "id = ?", id).Error
}
// 其它方法同理迁移全部用r.GetDB(ctx)
// FindByCode 根据编号查找产品
func (r *GormProductRepository) FindByCode(ctx context.Context, code string) (*entities.Product, error) {
var entity entities.Product
err := r.db.WithContext(ctx).Preload("Category").Where("code = ?", code).First(&entity).Error
err := r.SmartGetByField(ctx, &entity, "code", code) // 自动缓存
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
@@ -77,11 +69,11 @@ func (r *GormProductRepository) FindByCode(ctx context.Context, code string) (*e
// FindByCategoryID 根据分类ID查找产品
func (r *GormProductRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.db.WithContext(ctx).Preload("Category").Where("category_id = ?", categoryID).Find(&productEntities).Error
err := r.GetDB(ctx).Preload("Category").Where("category_id = ?", categoryID).Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
@@ -93,11 +85,11 @@ func (r *GormProductRepository) FindByCategoryID(ctx context.Context, categoryID
// FindVisible 查找可见产品
func (r *GormProductRepository) FindVisible(ctx context.Context) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.db.WithContext(ctx).Preload("Category").Where("is_visible = ? AND is_enabled = ?", true, true).Find(&productEntities).Error
err := r.GetDB(ctx).Preload("Category").Where("is_visible = ? AND is_enabled = ?", true, true).Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
@@ -109,11 +101,11 @@ func (r *GormProductRepository) FindVisible(ctx context.Context) ([]*entities.Pr
// FindEnabled 查找启用产品
func (r *GormProductRepository) FindEnabled(ctx context.Context) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.db.WithContext(ctx).Preload("Category").Where("is_enabled = ?", true).Find(&productEntities).Error
err := r.GetDB(ctx).Preload("Category").Where("is_enabled = ?", true).Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
@@ -126,12 +118,12 @@ func (r *GormProductRepository) FindEnabled(ctx context.Context) ([]*entities.Pr
func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, int64, error) {
var productEntities []entities.Product
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Product{})
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
// 应用筛选条件
if query.Keyword != "" {
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
}
if query.CategoryID != "" {
@@ -152,12 +144,12 @@ func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries
if query.IsPackage != nil {
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序
if query.SortBy != "" {
order := query.SortBy
@@ -170,32 +162,31 @@ func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 预加载分类信息并获取数据
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
return result, total, nil
}
// FindSubscribableProducts 查找可订阅产品
func (r *GormProductRepository) FindSubscribableProducts(ctx context.Context, userID string) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.db.WithContext(ctx).Where("is_enabled = ? AND is_visible = ?", true, true).Find(&productEntities).Error
err := r.GetDB(ctx).Where("is_enabled = ? AND is_visible = ?", true, true).Find(&productEntities).Error
if err != nil {
return nil, err
}
@@ -210,7 +201,7 @@ func (r *GormProductRepository) FindSubscribableProducts(ctx context.Context, us
// FindProductsByIDs 根据ID列表查找产品
func (r *GormProductRepository) FindProductsByIDs(ctx context.Context, ids []string) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&productEntities).Error
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&productEntities).Error
if err != nil {
return nil, err
}
@@ -225,7 +216,7 @@ func (r *GormProductRepository) FindProductsByIDs(ctx context.Context, ids []str
// CountByCategory 统计分类下的产品数量
func (r *GormProductRepository) CountByCategory(ctx context.Context, categoryID string) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.Product{})
query := r.GetDB(ctx).Model(&entities.Product{})
if categoryID != "" {
query = query.Where("category_id = ?", categoryID)
}
@@ -236,34 +227,34 @@ func (r *GormProductRepository) CountByCategory(ctx context.Context, categoryID
// CountEnabled 统计启用产品数量
func (r *GormProductRepository) CountEnabled(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Product{}).Where("is_enabled = ?", true).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_enabled = ?", true).Count(&count).Error
return count, err
}
// CountVisible 统计可见产品数量
func (r *GormProductRepository) CountVisible(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Product{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
return count, err
}
}
// Count 返回产品总数
func (r *GormProductRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.Product{})
query := r.GetDB(ctx).Model(&entities.Product{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
@@ -271,42 +262,42 @@ func (r *GormProductRepository) Count(ctx context.Context, options interfaces.Co
// GetByIDs 根据ID列表获取产品
func (r *GormProductRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Product, error) {
var products []entities.Product
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&products).Error
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&products).Error
return products, err
}
// CreateBatch 批量创建产品
func (r *GormProductRepository) CreateBatch(ctx context.Context, products []entities.Product) error {
return r.db.WithContext(ctx).Create(&products).Error
return r.GetDB(ctx).Create(&products).Error
}
// UpdateBatch 批量更新产品
func (r *GormProductRepository) UpdateBatch(ctx context.Context, products []entities.Product) error {
return r.db.WithContext(ctx).Save(&products).Error
return r.GetDB(ctx).Save(&products).Error
}
// DeleteBatch 批量删除产品
func (r *GormProductRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.Product{}, "id IN ?", ids).Error
return r.GetDB(ctx).Delete(&entities.Product{}, "id IN ?", ids).Error
}
// List 获取产品列表(基础方法)
func (r *GormProductRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Product, error) {
var products []entities.Product
query := r.db.WithContext(ctx).Model(&entities.Product{})
query := r.GetDB(ctx).Model(&entities.Product{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用排序
if options.Sort != "" {
order := options.Sort
@@ -317,13 +308,13 @@ func (r *GormProductRepository) List(ctx context.Context, options interfaces.Lis
}
query = query.Order(order)
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&products).Error
return products, err
}
@@ -331,27 +322,80 @@ func (r *GormProductRepository) List(ctx context.Context, options interfaces.Lis
// Exists 检查产品是否存在
func (r *GormProductRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Product{}).Where("id = ?", id).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.Product{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// SoftDelete 软删除产品
func (r *GormProductRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.Product{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.Product{}, "id = ?", id).Error
}
// Restore 恢复软删除的产品
func (r *GormProductRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.Product{}).Where("id = ?", id).Update("deleted_at", nil).Error
return r.GetDB(ctx).Unscoped().Model(&entities.Product{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// GetPackageItems 获取组合包项目
func (r *GormProductRepository) GetPackageItems(ctx context.Context, packageID string) ([]*entities.ProductPackageItem, error) {
var packageItems []entities.ProductPackageItem
err := r.GetDB(ctx).
Preload("Product").
Where("package_id = ?", packageID).
Order("sort_order ASC").
Find(&packageItems).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductPackageItem, len(packageItems))
for i := range packageItems {
result[i] = &packageItems[i]
}
return result, nil
}
// CreatePackageItem 创建组合包项目
func (r *GormProductRepository) CreatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
return r.GetDB(ctx).Create(packageItem).Error
}
// GetPackageItemByID 根据ID获取组合包项目
func (r *GormProductRepository) GetPackageItemByID(ctx context.Context, itemID string) (*entities.ProductPackageItem, error) {
var packageItem entities.ProductPackageItem
err := r.GetDB(ctx).
Preload("Product").
Preload("Package").
Where("id = ?", itemID).
First(&packageItem).Error
if err != nil {
return nil, err
}
return &packageItem, nil
}
// UpdatePackageItem 更新组合包项目
func (r *GormProductRepository) UpdatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
return r.GetDB(ctx).Save(packageItem).Error
}
// DeletePackageItem 删除组合包项目(硬删除)
func (r *GormProductRepository) DeletePackageItem(ctx context.Context, itemID string) error {
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "id = ?", itemID).Error
}
// DeletePackageItemsByPackageID 根据组合包ID删除所有子产品硬删除
func (r *GormProductRepository) DeletePackageItemsByPackageID(ctx context.Context, packageID string) error {
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "package_id = ?", packageID).Error
}
// WithTx 使用事务
func (r *GormProductRepository) WithTx(tx interface{}) interfaces.Repository[entities.Product] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormProductRepository{
db: gormTx,
logger: r.logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductsTable),
}
}
return r
}
}

View File

@@ -3,44 +3,46 @@ package repositories
import (
"context"
"errors"
"go.uber.org/zap"
"gorm.io/gorm"
"time"
"tyapi-server/internal/domains/product/entities"
"tyapi-server/internal/domains/product/repositories"
"tyapi-server/internal/domains/product/repositories/queries"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
SubscriptionsTable = "subscriptions"
SubscriptionCacheTTL = 60 * time.Minute
)
// GormSubscriptionRepository GORM订阅仓储实现
type GormSubscriptionRepository struct {
db *gorm.DB
logger *zap.Logger
*database.CachedBaseRepositoryImpl
}
func (r *GormSubscriptionRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.Subscription{})
}
// 编译时检查接口实现
var _ repositories.SubscriptionRepository = (*GormSubscriptionRepository)(nil)
// NewGormSubscriptionRepository 创建GORM订阅仓储
func NewGormSubscriptionRepository(db *gorm.DB, logger *zap.Logger) repositories.SubscriptionRepository {
return &GormSubscriptionRepository{
db: db,
logger: logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SubscriptionsTable),
}
}
// Create 创建订阅
func (r *GormSubscriptionRepository) Create(ctx context.Context, entity entities.Subscription) (entities.Subscription, error) {
r.logger.Info("创建订阅", zap.String("id", entity.ID), zap.String("user_id", entity.UserID))
err := r.db.WithContext(ctx).Create(&entity).Error
err := r.CreateEntity(ctx, &entity)
return entity, err
}
// GetByID 根据ID获取订阅
func (r *GormSubscriptionRepository) GetByID(ctx context.Context, id string) (entities.Subscription, error) {
var entity entities.Subscription
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.Subscription{}, gorm.ErrRecordNotFound
@@ -50,22 +52,14 @@ func (r *GormSubscriptionRepository) GetByID(ctx context.Context, id string) (en
return entity, nil
}
// Update 更新订阅
func (r *GormSubscriptionRepository) Update(ctx context.Context, entity entities.Subscription) error {
r.logger.Info("更新订阅", zap.String("id", entity.ID))
return r.db.WithContext(ctx).Save(&entity).Error
}
// Delete 删除订阅
func (r *GormSubscriptionRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除订阅", zap.String("id", id))
return r.db.WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
return r.UpdateEntity(ctx, &entity)
}
// FindByUserID 根据用户ID查找订阅
func (r *GormSubscriptionRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.Subscription, error) {
var subscriptions []entities.Subscription
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&subscriptions).Error
err := r.GetDB(ctx).WithContext(ctx).Where("user_id = ?", userID).Find(&subscriptions).Error
if err != nil {
return nil, err
}
@@ -81,7 +75,7 @@ func (r *GormSubscriptionRepository) FindByUserID(ctx context.Context, userID st
// FindByProductID 根据产品ID查找订阅
func (r *GormSubscriptionRepository) FindByProductID(ctx context.Context, productID string) ([]*entities.Subscription, error) {
var subscriptions []entities.Subscription
err := r.db.WithContext(ctx).Where("product_id = ?", productID).Find(&subscriptions).Error
err := r.GetDB(ctx).WithContext(ctx).Where("product_id = ?", productID).Find(&subscriptions).Error
if err != nil {
return nil, err
}
@@ -97,7 +91,10 @@ func (r *GormSubscriptionRepository) FindByProductID(ctx context.Context, produc
// FindByUserAndProduct 根据用户和产品查找订阅
func (r *GormSubscriptionRepository) FindByUserAndProduct(ctx context.Context, userID, productID string) (*entities.Subscription, error) {
var entity entities.Subscription
err := r.db.WithContext(ctx).Where("user_id = ? AND product_id = ?", userID, productID).First(&entity).Error
// 组合缓存key的条件
where := "user_id = ? AND product_id = ?"
ttl := SubscriptionCacheTTL // 缓存10分钟可根据业务调整
err := r.GetWithCache(ctx, &entity, ttl, where, userID, productID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
@@ -112,7 +109,7 @@ func (r *GormSubscriptionRepository) ListSubscriptions(ctx context.Context, quer
var subscriptions []entities.Subscription
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Subscription{})
dbQuery := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
// 应用筛选条件
if query.UserID != "" {
@@ -172,14 +169,14 @@ func (r *GormSubscriptionRepository) ListSubscriptions(ctx context.Context, quer
// CountByUser 统计用户订阅数量
func (r *GormSubscriptionRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Subscription{}).Where("user_id = ?", userID).Count(&count).Error
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
// CountByProduct 统计产品订阅数量
func (r *GormSubscriptionRepository) CountByProduct(ctx context.Context, productID string) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Subscription{}).Where("product_id = ?", productID).Count(&count).Error
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("product_id = ?", productID).Count(&count).Error
return count, err
}
@@ -188,7 +185,7 @@ func (r *GormSubscriptionRepository) CountByProduct(ctx context.Context, product
// Count 返回订阅总数
func (r *GormSubscriptionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.Subscription{})
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
// 应用筛选条件
if options.Filters != nil {
@@ -209,29 +206,29 @@ func (r *GormSubscriptionRepository) Count(ctx context.Context, options interfac
// GetByIDs 根据ID列表获取订阅
func (r *GormSubscriptionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Subscription, error) {
var subscriptions []entities.Subscription
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&subscriptions).Error
err := r.GetDB(ctx).WithContext(ctx).Where("id IN ?", ids).Find(&subscriptions).Error
return subscriptions, err
}
// CreateBatch 批量创建订阅
func (r *GormSubscriptionRepository) CreateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
return r.db.WithContext(ctx).Create(&subscriptions).Error
return r.GetDB(ctx).WithContext(ctx).Create(&subscriptions).Error
}
// UpdateBatch 批量更新订阅
func (r *GormSubscriptionRepository) UpdateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
return r.db.WithContext(ctx).Save(&subscriptions).Error
return r.GetDB(ctx).WithContext(ctx).Save(&subscriptions).Error
}
// DeleteBatch 批量删除订阅
func (r *GormSubscriptionRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.Subscription{}, "id IN ?", ids).Error
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id IN ?", ids).Error
}
// List 获取订阅列表(基础方法)
func (r *GormSubscriptionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Subscription, error) {
var subscriptions []entities.Subscription
query := r.db.WithContext(ctx).Model(&entities.Subscription{})
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
// 应用筛选条件
if options.Filters != nil {
@@ -269,26 +266,25 @@ func (r *GormSubscriptionRepository) List(ctx context.Context, options interface
// Exists 检查订阅是否存在
func (r *GormSubscriptionRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Subscription{}).Where("id = ?", id).Count(&count).Error
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// SoftDelete 软删除订阅
func (r *GormSubscriptionRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
}
// Restore 恢复软删除的订阅
func (r *GormSubscriptionRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.Subscription{}).Where("id = ?", id).Update("deleted_at", nil).Error
return r.GetDB(ctx).WithContext(ctx).Unscoped().Model(&entities.Subscription{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// WithTx 使用事务
func (r *GormSubscriptionRepository) WithTx(tx interface{}) interfaces.Repository[entities.Subscription] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormSubscriptionRepository{
db: gormTx,
logger: r.logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), SubscriptionsTable),
}
}
return r

View File

@@ -0,0 +1,101 @@
// internal/infrastructure/database/repositories/user/gorm_contract_info_repository.go
package repositories
import (
"context"
"errors"
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ContractInfosTable = "contract_infos"
)
type GormContractInfoRepository struct {
*database.CachedBaseRepositoryImpl
}
func NewGormContractInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractInfoRepository {
return &GormContractInfoRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ContractInfosTable),
}
}
func (r *GormContractInfoRepository) Save(ctx context.Context, contract *entities.ContractInfo) error {
return r.CreateEntity(ctx, contract)
}
func (r *GormContractInfoRepository) FindByID(ctx context.Context, contractID string) (*entities.ContractInfo, error) {
var contract entities.ContractInfo
err := r.SmartGetByID(ctx, contractID, &contract)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &contract, nil
}
func (r *GormContractInfoRepository) Delete(ctx context.Context, contractID string) error {
return r.DeleteEntity(ctx, contractID, &entities.ContractInfo{})
}
func (r *GormContractInfoRepository) FindByEnterpriseInfoID(ctx context.Context, enterpriseInfoID string) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("enterprise_info_id = ?", enterpriseInfoID).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("user_id = ?", userID).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) FindByContractType(ctx context.Context, enterpriseInfoID string, contractType entities.ContractType) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("enterprise_info_id = ? AND contract_type = ?", enterpriseInfoID, contractType).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) ExistsByContractFileID(ctx context.Context, contractFileID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ?", contractFileID).Count(&count).Error
return count > 0, err
}
func (r *GormContractInfoRepository) ExistsByContractFileIDExcludeID(ctx context.Context, contractFileID, excludeID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ? AND id != ?", contractFileID, excludeID).Count(&count).Error
return count > 0, err
}

View File

@@ -15,20 +15,23 @@ import (
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/domains/user/repositories/queries"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
)
const (
SMSCodesTable = "sms_codes"
)
// GormSMSCodeRepository 短信验证码GORM仓储实现无缓存确保安全性
type GormSMSCodeRepository struct {
db *gorm.DB
logger *zap.Logger
*database.CachedBaseRepositoryImpl
}
// NewGormSMSCodeRepository 创建短信验证码仓储
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
return &GormSMSCodeRepository{
db: db,
logger: logger,
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SMSCodesTable),
}
}
@@ -39,64 +42,54 @@ var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
// Create 创建短信验证码记录(不缓存,确保安全性)
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
return smsCode, nil
err := r.GetDB(ctx).Create(&smsCode).Error
return smsCode, err
}
// GetByID 根据ID获取短信验证码
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&smsCode).Error; err != nil {
err := r.GetDB(ctx).Where("id = ?", id).First(&smsCode).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
}
r.logger.Error("获取短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
return smsCode, nil
}
// Update 更新验证码记录
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
if err := r.db.WithContext(ctx).Save(&smsCode).Error; err != nil {
r.logger.Error("更新短信验证码失败", zap.Error(err))
return err
}
return nil
return r.GetDB(ctx).Save(&smsCode).Error
}
// CreateBatch 批量创建短信验证码
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
return r.db.WithContext(ctx).Create(&smsCodes).Error
return r.GetDB(ctx).Create(&smsCodes).Error
}
// GetByIDs 根据ID列表获取短信验证码
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
return smsCodes, err
}
// UpdateBatch 批量更新短信验证码
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
return r.db.WithContext(ctx).Save(&smsCodes).Error
return r.GetDB(ctx).Save(&smsCodes).Error
}
// DeleteBatch 批量删除短信验证码
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
}
// List 获取短信验证码列表
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
query := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
@@ -139,20 +132,20 @@ func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.Lis
// Delete 删除短信验证码
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// Exists 检查短信验证码是否存在
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
err := r.GetDB(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// Count 统计短信验证码数量
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
query := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
@@ -172,12 +165,12 @@ func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.Co
// SoftDelete 软删除短信验证码
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// Restore 恢复短信验证码
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
return r.GetDB(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
@@ -185,7 +178,7 @@ func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
// GetByPhone 根据手机号获取短信验证码
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
@@ -198,7 +191,7 @@ func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*
// GetLatestByPhone 根据手机号获取最新短信验证码
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
@@ -211,7 +204,7 @@ func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone stri
// GetValidByPhone 根据手机号获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
@@ -227,7 +220,7 @@ func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone strin
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
@@ -246,7 +239,7 @@ func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries
var total int64
// 构建查询条件
db := r.db.WithContext(ctx).Model(&entities.SMSCode{})
db := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if query.Phone != "" {
@@ -288,8 +281,8 @@ func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, co
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效期
}
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
r.logger.Error("创建短信验证码失败", zap.Error(err))
if err := r.GetDB(ctx).Create(&smsCode).Error; err != nil {
r.GetLogger().Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
@@ -299,7 +292,7 @@ func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, co
// ValidateCode 验证验证码
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
Count(&count).Error
@@ -309,7 +302,7 @@ func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string,
// InvalidateCode 使验证码失效
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&entities.SMSCode{}).
return r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND used_at IS NULL", phone).
Update("used_at", &now).Error
}
@@ -320,7 +313,7 @@ func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone st
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
Count(&count).Error
@@ -333,7 +326,7 @@ func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone str
today := time.Now().Truncate(24 * time.Hour)
var count int64
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&count).Error
@@ -348,7 +341,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
startDate := time.Now().AddDate(0, 0, -days)
// 总发送数
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, startDate).
Count(&stats.TotalSent).Error; err != nil {
@@ -356,7 +349,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
}
// 总验证数
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
Count(&stats.TotalValidated).Error; err != nil {
@@ -370,7 +363,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
// 今日发送数
today := time.Now().Truncate(24 * time.Hour)
if err := r.db.WithContext(ctx).
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&stats.TodaySent).Error; err != nil {

View File

@@ -14,242 +14,199 @@ import (
"tyapi-server/internal/domains/user/entities"
"tyapi-server/internal/domains/user/repositories"
"tyapi-server/internal/domains/user/repositories/queries"
"tyapi-server/internal/shared/database"
"tyapi-server/internal/shared/interfaces"
)
const (
UsersTable = "users"
UserCacheTTL = 30 * 60 // 30分钟
)
// 定义错误常量
var (
// ErrUserNotFound 用户不存在错误
ErrUserNotFound = errors.New("用户不存在")
)
// UserRepository 用户仓储实现(已移除手动缓存管理)
type GormUserRepository struct {
db *gorm.DB
logger *zap.Logger
*database.CachedBaseRepositoryImpl
}
// NewGormUserRepository 创建用户仓储
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
return &GormUserRepository{
db: db,
logger: logger,
}
}
// 确保 GormUserRepository 实现了 UserRepository 接口
var _ repositories.UserRepository = (*GormUserRepository)(nil)
// ================ Repository[T] 接口实现 ================
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
return &GormUserRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, UsersTable),
}
}
// Create 创建用户(自动缓存失效)
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
if err := r.db.WithContext(ctx).Create(&user).Error; err != nil {
r.logger.Error("创建用户失败", zap.Error(err))
err := r.CreateEntity(ctx, &user)
return user, err
}
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
var user entities.User
err := r.SmartGetByID(ctx, id, &user)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.User{}, errors.New("用户不存在")
}
return entities.User{}, err
}
r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
return user, nil
}
// GetByID 根据ID获取用户自动缓存
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id string) (entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.User{}, ErrUserNotFound
}
r.logger.Error("根据ID查询用户失败", zap.Error(err))
r.GetLogger().Error("根据ID查询用户失败", zap.Error(err))
return entities.User{}, err
}
return user, nil
}
// Update 更新用户(自动缓存失效)
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
if err := r.db.WithContext(ctx).Save(&user).Error; err != nil {
r.logger.Error("更新用户失败", zap.Error(err))
return err
}
r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
return nil
}
// CreateBatch 批量创建用户
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
r.logger.Info("批量创建用户", zap.Int("count", len(users)))
return r.db.WithContext(ctx).Create(&users).Error
}
// GetByIDs 根据ID列表获取用户
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
var users []entities.User
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error
return users, err
}
// UpdateBatch 批量更新用户
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
r.logger.Info("批量更新用户", zap.Int("count", len(users)))
return r.db.WithContext(ctx).Save(&users).Error
}
// DeleteBatch 批量删除用户
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.logger.Info("批量删除用户", zap.Strings("ids", ids))
return r.db.WithContext(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
}
// List 获取用户列表
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
var users []entities.User
query := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用预加载
for _, include := range options.Include {
query = query.Preload(include)
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return users, query.Find(&users).Error
}
// ================ BaseRepository 接口实现 ================
// Delete 删除用户
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除用户失败", zap.Error(err))
return err
}
r.logger.Info("用户删除成功", zap.String("user_id", id))
return nil
}
// Exists 检查用户是否存在
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
query := r.GetDB(ctx).Model(&entities.User{}).
Joins("JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
Where("enterprise_infos.unified_social_code = ?", unifiedSocialCode)
// Count 统计用户数量
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.User{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
// 如果指定了排除的用户ID则排除该用户的记录
if excludeUserID != "" {
query = query.Where("users.id != ?", excludeUserID)
}
err := query.Count(&count).Error
if err != nil {
r.GetLogger().Error("检查统一社会信用代码是否存在失败", zap.Error(err))
return false, err
}
return count > 0, nil
}
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
return r.UpdateEntity(ctx, &user)
}
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
r.GetLogger().Info("批量创建用户", zap.Int("count", len(users)))
return r.GetDB(ctx).Create(&users).Error
}
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
var users []entities.User
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&users).Error
return users, err
}
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
r.GetLogger().Info("批量更新用户", zap.Int("count", len(users)))
return r.GetDB(ctx).Save(&users).Error
}
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.GetLogger().Info("批量删除用户", zap.Strings("ids", ids))
return r.GetDB(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
}
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
var users []entities.User
err := r.SmartList(ctx, &users, options)
return users, err
}
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.User{})
}
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
return r.ExistsEntity(ctx, id, &entities.User{})
}
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.User{}).Count(&count).Error
return count, err
}
// SoftDelete 软删除用户
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error
return r.GetDB(ctx).Delete(&entities.User{}, "id = ?", id).Error
}
// Restore 恢复用户
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
return r.GetDB(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
// GetByPhone 根据手机号获取用户
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if err := r.GetDB(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据手机号查询用户失败", zap.Error(err))
r.GetLogger().Error("根据手机号查询用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
var user entities.User
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if err := r.GetDB(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.logger.Error("根据用户名查询用户失败", zap.Error(err))
r.GetLogger().Error("根据用户名查询用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
// GetByUserType 根据用户类型获取用户列表
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
var users []*entities.User
err := r.db.WithContext(ctx).Where("user_type = ?", userType).Find(&users).Error
err := r.GetDB(ctx).Where("user_type = ?", userType).Find(&users).Error
return users, err
}
// ListUsers 获取用户列表(带分页和筛选)
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
var users []*entities.User
var total int64
// 构建查询条件
db := r.db.WithContext(ctx).Model(&entities.User{})
// 构建查询条件,预加载企业信息
db := r.GetDB(ctx).Model(&entities.User{}).Preload("EnterpriseInfo")
// 应用筛选条件
if query.Phone != "" {
db = db.Where("phone LIKE ?", "%"+query.Phone+"%")
db = db.Where("users.phone LIKE ?", "%"+query.Phone+"%")
}
if query.UserType != "" {
db = db.Where("users.user_type = ?", query.UserType)
}
if query.IsActive != nil {
db = db.Where("users.active = ?", *query.IsActive)
}
if query.IsCertified != nil {
db = db.Where("users.is_certified = ?", *query.IsCertified)
}
if query.CompanyName != "" {
db = db.Joins("LEFT JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
}
if query.StartDate != "" {
db = db.Where("created_at >= ?", query.StartDate)
db = db.Where("users.created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
db = db.Where("created_at <= ?", query.EndDate)
db = db.Where("users.created_at <= ?", query.EndDate)
}
// 统计总数
@@ -266,10 +223,9 @@ func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListU
return users, total, nil
}
// ValidateUser 验证用户登录
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
var user entities.User
err := r.db.WithContext(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
err := r.GetDB(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
if err != nil {
return nil, err
}
@@ -277,10 +233,9 @@ func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password s
return &user, nil
}
// UpdateLastLogin 更新最后登录时间
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"last_login_at": &now,
@@ -288,40 +243,35 @@ func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string)
}).Error
}
// UpdatePassword 更新密码
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("password", newPassword).Error
}
// CheckPassword 检查密码
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.User{}).
err := r.GetDB(ctx).Model(&entities.User{}).
Where("id = ? AND password = ?", userID, password).
Count(&count).Error
return count > 0, err
}
// ActivateUser 激活用户
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("active", true).Error
}
// DeactivateUser 停用用户
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("active", false).Error
}
// UpdateLoginStats 更新登录统计
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Model(&entities.User{}).
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"login_count": gorm.Expr("login_count + 1"),
@@ -329,11 +279,10 @@ func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string
}).Error
}
// GetStats 获取用户统计信息
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.db.WithContext(ctx)
db := r.GetDB(ctx)
// 总用户数
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
@@ -345,6 +294,11 @@ func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserSt
return nil, err
}
// 已认证用户数
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
@@ -359,11 +313,10 @@ func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserSt
return &stats, nil
}
// GetStatsByDateRange 获取指定日期范围的用户统计
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.db.WithContext(ctx)
db := r.GetDB(ctx)
// 指定时间范围内的注册数
if err := db.Model(&entities.User{}).