v0.1
This commit is contained in:
124
internal/infrastructure/cache/redis_cache.go
vendored
124
internal/infrastructure/cache/redis_cache.go
vendored
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -85,6 +86,7 @@ func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
r.logger.Error("序列化缓存数据失败", zap.String("key", key), zap.Error(err))
|
||||
return fmt.Errorf("failed to marshal value: %w", err)
|
||||
}
|
||||
|
||||
@@ -106,10 +108,11 @@ func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl
|
||||
|
||||
err = r.client.Set(ctx, fullKey, data, expiration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
|
||||
r.logger.Error("设置缓存失败", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Debug("设置缓存成功", zap.String("key", key), zap.Duration("ttl", expiration))
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -158,6 +161,7 @@ func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string
|
||||
|
||||
values, err := r.client.MGet(ctx, fullKeys...).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("批量获取缓存失败", zap.Strings("keys", keys), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -165,9 +169,15 @@ func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string
|
||||
for i, val := range values {
|
||||
if val != nil {
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(val.(string)), &data); err == nil {
|
||||
result[keys[i]] = data
|
||||
// 修复:改进JSON反序列化错误处理
|
||||
if err := json.Unmarshal([]byte(val.(string)), &data); err != nil {
|
||||
r.logger.Warn("反序列化缓存数据失败",
|
||||
zap.String("key", keys[i]),
|
||||
zap.String("value", val.(string)),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
result[keys[i]] = data
|
||||
}
|
||||
}
|
||||
|
||||
@@ -210,17 +220,107 @@ func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{
|
||||
|
||||
// DeletePattern 按模式删除
|
||||
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
// 修复:避免重复添加前缀
|
||||
var fullPattern string
|
||||
if strings.HasPrefix(pattern, r.prefix+":") {
|
||||
fullPattern = pattern
|
||||
} else {
|
||||
fullPattern = r.getFullKey(pattern)
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
return r.client.Del(ctx, keys...).Err()
|
||||
|
||||
// 检查上下文是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
|
||||
var cursor uint64
|
||||
var totalDeleted int64
|
||||
maxIterations := 100 // 防止无限循环
|
||||
iteration := 0
|
||||
|
||||
for {
|
||||
// 检查迭代次数限制
|
||||
iteration++
|
||||
if iteration > maxIterations {
|
||||
r.logger.Warn("缓存删除操作达到最大迭代次数限制",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int("max_iterations", maxIterations),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
// 检查上下文是否已取消
|
||||
if ctx.Err() != nil {
|
||||
r.logger.Warn("缓存删除操作被取消",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
zap.Error(ctx.Err()),
|
||||
)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 执行SCAN操作
|
||||
keys, next, err := r.client.Scan(ctx, cursor, fullPattern, 1000).Result()
|
||||
if err != nil {
|
||||
// 如果是上下文取消错误,直接返回
|
||||
if err == context.Canceled || err == context.DeadlineExceeded {
|
||||
r.logger.Warn("缓存删除操作被取消",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Error("扫描缓存键失败",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 批量删除找到的键
|
||||
if len(keys) > 0 {
|
||||
// 使用pipeline批量删除,提高性能
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.Del(ctx, keys...)
|
||||
|
||||
cmds, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
r.logger.Error("批量删除缓存键失败",
|
||||
zap.Strings("keys", keys),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 统计删除的键数量
|
||||
for _, cmd := range cmds {
|
||||
if delCmd, ok := cmd.(*redis.IntCmd); ok {
|
||||
if deleted, err := delCmd.Result(); err == nil {
|
||||
totalDeleted += deleted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Debug("批量删除缓存键",
|
||||
zap.Strings("keys", keys),
|
||||
zap.Int("batch_size", len(keys)),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
)
|
||||
}
|
||||
|
||||
cursor = next
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Debug("缓存模式删除完成",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
zap.Int("iterations", iteration),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"tyapi-server/internal/domains/api/entities"
|
||||
"tyapi-server/internal/domains/api/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiCallsTable = "api_calls"
|
||||
ApiCallCacheTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// ApiCallWithProduct 包含产品名称的API调用记录
|
||||
type ApiCallWithProduct struct {
|
||||
entities.ApiCall
|
||||
ProductName string `json:"product_name" gorm:"column:product_name"`
|
||||
}
|
||||
|
||||
type GormApiCallRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ApiCallRepository = (*GormApiCallRepository)(nil)
|
||||
|
||||
func NewGormApiCallRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiCallRepository {
|
||||
return &GormApiCallRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiCallsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) Create(ctx context.Context, call *entities.ApiCall) error {
|
||||
return r.CreateEntity(ctx, call)
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) Update(ctx context.Context, call *entities.ApiCall) error {
|
||||
return r.UpdateEntity(ctx, call)
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindById(ctx context.Context, id string) (*entities.ApiCall, error) {
|
||||
var call entities.ApiCall
|
||||
err := r.SmartGetByID(ctx, id, &call)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &call, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindByUserId(ctx context.Context, userId string, limit, offset int) ([]*entities.ApiCall, error) {
|
||||
var calls []*entities.ApiCall
|
||||
options := database.CacheListOptions{
|
||||
Where: "user_id = ?",
|
||||
Args: []interface{}{userId},
|
||||
Order: "created_at DESC",
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
err := r.ListWithCache(ctx, &calls, ApiCallCacheTTL, options)
|
||||
return calls, err
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
|
||||
var calls []*entities.ApiCall
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
|
||||
return calls, total, err
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
|
||||
var calls []*entities.ApiCall
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品ID筛选
|
||||
if productId, ok := filters["product_id"].(string); ok && productId != "" {
|
||||
whereCondition += " AND product_id = ?"
|
||||
whereArgs = append(whereArgs, productId)
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
|
||||
return calls, total, err
|
||||
}
|
||||
|
||||
// ListByUserIdWithFiltersAndProductName 根据用户ID和筛选条件获取API调用记录(包含产品名称)
|
||||
func (r *GormApiCallRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
|
||||
var callsWithProduct []*ApiCallWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "ac.user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND ac.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND ac.status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
query := r.GetDB(ctx).Table("api_calls ac").
|
||||
Select("ac.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON ac.product_id = p.id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("ac." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("ac.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&callsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.ApiCall并构建产品名称映射
|
||||
var calls []*entities.ApiCall
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, c := range callsWithProduct {
|
||||
call := c.ApiCall
|
||||
calls = append(calls, &call)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if c.ProductName != "" {
|
||||
productNameMap[call.ID] = c.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, calls, total, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ?", userId)
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindByTransactionId(ctx context.Context, transactionId string) (*entities.ApiCall, error) {
|
||||
var call entities.ApiCall
|
||||
err := r.FindOne(ctx, &call, "transaction_id = ?", transactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &call, nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"tyapi-server/internal/domains/api/entities"
|
||||
"tyapi-server/internal/domains/api/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiUsersTable = "api_users"
|
||||
ApiUserCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
type GormApiUserRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ApiUserRepository = (*GormApiUserRepository)(nil)
|
||||
|
||||
func NewGormApiUserRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiUserRepository {
|
||||
return &GormApiUserRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiUsersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) Create(ctx context.Context, user *entities.ApiUser) error {
|
||||
return r.CreateEntity(ctx, user)
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) Update(ctx context.Context, user *entities.ApiUser) error {
|
||||
return r.UpdateEntity(ctx, user)
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) FindByAccessId(ctx context.Context, accessId string) (*entities.ApiUser, error) {
|
||||
var user entities.ApiUser
|
||||
err := r.SmartGetByField(ctx, &user, "access_id", accessId, ApiUserCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) FindByUserId(ctx context.Context, userId string) (*entities.ApiUser, error) {
|
||||
var user entities.ApiUser
|
||||
err := r.SmartGetByField(ctx, &user, "user_id", userId, ApiUserCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
@@ -123,6 +123,15 @@ func (r *GormCertificationQueryRepository) Exists(ctx context.Context, id string
|
||||
return r.ExistsEntity(ctx, id, &entities.Certification{})
|
||||
}
|
||||
|
||||
func (r *GormCertificationQueryRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Certification{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("查询用户认证是否存在失败: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ================ 列表查询 ================
|
||||
|
||||
// List 分页列表查询
|
||||
@@ -278,12 +287,12 @@ func (r *GormCertificationQueryRepository) FindByEsignFlowID(ctx context.Context
|
||||
func (r *GormCertificationQueryRepository) ListPendingRetry(ctx context.Context, maxRetryCount int) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.WithoutCache().GetDB(ctx).
|
||||
Where("status IN ? AND retry_count < ?",
|
||||
Where("status IN ? AND retry_count < ?",
|
||||
[]enums.CertificationStatus{
|
||||
enums.StatusInfoRejected,
|
||||
enums.StatusContractRejected,
|
||||
enums.StatusContractExpired,
|
||||
},
|
||||
},
|
||||
maxRetryCount).
|
||||
Order("created_at ASC").
|
||||
Find(&certifications).Error
|
||||
@@ -367,61 +376,7 @@ func (r *GormCertificationQueryRepository) GetUserActiveCertification(ctx contex
|
||||
|
||||
// ================ 统计查询 ================
|
||||
|
||||
// GetStatistics 获取统计数据
|
||||
func (r *GormCertificationQueryRepository) GetStatistics(ctx context.Context, period repositories.CertificationTimePeriod) (*repositories.CertificationStatistics, error) {
|
||||
now := time.Now()
|
||||
var startDate time.Time
|
||||
|
||||
switch period {
|
||||
case repositories.PeriodDaily:
|
||||
startDate = now.AddDate(0, 0, -1)
|
||||
case repositories.PeriodWeekly:
|
||||
startDate = now.AddDate(0, 0, -7)
|
||||
case repositories.PeriodMonthly:
|
||||
startDate = now.AddDate(0, -1, 0)
|
||||
case repositories.PeriodYearly:
|
||||
startDate = now.AddDate(-1, 0, 0)
|
||||
default:
|
||||
startDate = now.AddDate(0, 0, -7)
|
||||
}
|
||||
|
||||
// 使用短期缓存进行统计查询
|
||||
var totalCount int64
|
||||
r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("created_at BETWEEN ? AND ?", startDate, now).
|
||||
Count(&totalCount)
|
||||
|
||||
var completedCount int64
|
||||
r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("created_at BETWEEN ? AND ? AND status = ?", startDate, now, enums.StatusContractSigned).
|
||||
Count(&completedCount)
|
||||
|
||||
successRate := float64(0)
|
||||
if totalCount > 0 {
|
||||
successRate = float64(completedCount) / float64(totalCount)
|
||||
}
|
||||
|
||||
return &repositories.CertificationStatistics{
|
||||
Period: period,
|
||||
StartDate: startDate,
|
||||
EndDate: now,
|
||||
TotalCertifications: totalCount,
|
||||
CompletedCount: completedCount,
|
||||
SuccessRate: successRate,
|
||||
StatusDistribution: make(map[enums.CertificationStatus]int64),
|
||||
FailureDistribution: make(map[enums.FailureReason]int64),
|
||||
AvgProcessingTime: 24 * time.Hour, // 简化实现
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CountByStatus 按状态统计
|
||||
func (r *GormCertificationQueryRepository) CountByStatus(ctx context.Context, status enums.CertificationStatus) (int64, error) {
|
||||
var count int64
|
||||
if err := r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).Where("status = ?", status).Count(&count).Error; err != nil {
|
||||
return 0, fmt.Errorf("按状态统计认证失败: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByFailureReason 按失败原因统计
|
||||
func (r *GormCertificationQueryRepository) CountByFailureReason(ctx context.Context, reason enums.FailureReason) (int64, error) {
|
||||
@@ -511,4 +466,4 @@ func (r *GormCertificationQueryRepository) GetCacheStats() map[string]interface{
|
||||
QueryCachePatternUser,
|
||||
}
|
||||
return stats
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
package certification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"tyapi-server/internal/domains/certification/entities"
|
||||
"tyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
EnterpriseInfoSubmitRecordsTable = "enterprise_info_submit_records"
|
||||
)
|
||||
|
||||
type GormEnterpriseInfoSubmitRecordRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.EnterpriseInfoSubmitRecord{})
|
||||
}
|
||||
|
||||
func NewGormEnterpriseInfoSubmitRecordRepository(db *gorm.DB, logger *zap.Logger) *GormEnterpriseInfoSubmitRecordRepository {
|
||||
return &GormEnterpriseInfoSubmitRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, EnterpriseInfoSubmitRecordsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Create(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
|
||||
return r.CreateEntity(ctx, record)
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Update(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
|
||||
return r.UpdateEntity(ctx, record)
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Exists(ctx context.Context, ID string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, ID, &entities.EnterpriseInfoSubmitRecord{})
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Order("submit_at DESC").
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestVerifiedByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND status = ?", userID, "verified").
|
||||
Order("verified_at DESC").
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"tyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
AlipayOrdersTable = "alipay_orders"
|
||||
)
|
||||
|
||||
type GormAlipayOrderRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.AlipayOrderRepository = (*GormAlipayOrderRepository)(nil)
|
||||
|
||||
func NewGormAlipayOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.AlipayOrderRepository {
|
||||
return &GormAlipayOrderRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, AlipayOrdersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Create(ctx context.Context, order entities.AlipayOrder) (entities.AlipayOrder, error) {
|
||||
err := r.CreateEntity(ctx, &order)
|
||||
return order, err
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByID(ctx context.Context, id string) (entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.SmartGetByID(ctx, id, &order)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.AlipayOrder{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.AlipayOrder{}, err
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.AlipayOrder, error) {
|
||||
var orders []entities.AlipayOrder
|
||||
err := r.GetDB(ctx).
|
||||
Joins("JOIN recharge_records ON alipay_orders.recharge_id = recharge_records.id").
|
||||
Where("recharge_records.user_id = ?", userID).
|
||||
Order("alipay_orders.created_at DESC").
|
||||
Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Update(ctx context.Context, order entities.AlipayOrder) error {
|
||||
return r.UpdateEntity(ctx, &order)
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.AlipayOrderStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.AlipayOrder{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -1,697 +0,0 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
|
||||
"tyapi-server/internal/domains/finance/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormWalletRepository 钱包GORM仓储实现
|
||||
type GormWalletRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
|
||||
|
||||
// NewGormWalletRepository 创建钱包GORM仓储
|
||||
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
|
||||
return &GormWalletRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建钱包
|
||||
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
|
||||
r.logger.Info("创建钱包", zap.String("user_id", wallet.UserID))
|
||||
err := r.db.WithContext(ctx).Create(&wallet).Error
|
||||
return wallet, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取钱包
|
||||
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Wallet{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Wallet{}, err
|
||||
}
|
||||
return wallet, err
|
||||
}
|
||||
|
||||
// Update 更新钱包
|
||||
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
|
||||
r.logger.Info("更新钱包", zap.String("id", wallet.ID))
|
||||
return r.db.WithContext(ctx).Save(&wallet).Error
|
||||
}
|
||||
|
||||
// Delete 删除钱包
|
||||
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除钱包", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除钱包
|
||||
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除钱包", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复钱包
|
||||
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复钱包", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Wallet{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计钱包数量
|
||||
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查钱包是否存在
|
||||
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建钱包
|
||||
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
r.logger.Info("批量创建钱包", zap.Int("count", len(wallets)))
|
||||
return r.db.WithContext(ctx).Create(&wallets).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取钱包
|
||||
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&wallets).Error
|
||||
return wallets, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新钱包
|
||||
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
r.logger.Info("批量更新钱包", zap.Int("count", len(wallets)))
|
||||
return r.db.WithContext(ctx).Save(&wallets).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除钱包
|
||||
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除钱包", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取钱包列表
|
||||
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return wallets, query.Find(&wallets).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找钱包
|
||||
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// ExistsByUserID 检查用户钱包是否存在
|
||||
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GetTotalBalance 获取总余额
|
||||
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
|
||||
var total decimal.Decimal
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetActiveWalletCount 获取激活钱包数量
|
||||
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ================ 接口要求的方法 ================
|
||||
|
||||
// GetByUserID 根据用户ID获取钱包
|
||||
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// GetByWalletAddress 根据钱包地址获取钱包
|
||||
func (r *GormWalletRepository) GetByWalletAddress(ctx context.Context, walletAddress string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("wallet_address = ?", walletAddress).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// GetByWalletType 根据钱包类型获取钱包
|
||||
func (r *GormWalletRepository) GetByWalletType(ctx context.Context, userID string, walletType string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("user_id = ? AND wallet_type = ?", userID, walletType).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// ListWallets 获取钱包列表(带分页和筛选)
|
||||
func (r *GormWalletRepository) ListWallets(ctx context.Context, query *queries.ListWalletsQuery) ([]*entities.Wallet, int64, error) {
|
||||
var wallets []entities.Wallet
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Wallet{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.WalletType != "" {
|
||||
dbQuery = dbQuery.Where("wallet_type = ?", query.WalletType)
|
||||
}
|
||||
if query.WalletAddress != "" {
|
||||
dbQuery = dbQuery.Where("wallet_address LIKE ?", "%"+query.WalletAddress+"%")
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&wallets).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
walletPtrs := make([]*entities.Wallet, len(wallets))
|
||||
for i := range wallets {
|
||||
walletPtrs[i] = &wallets[i]
|
||||
}
|
||||
|
||||
return walletPtrs, total, nil
|
||||
}
|
||||
|
||||
// UpdateBalance 更新钱包余额
|
||||
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
|
||||
}
|
||||
|
||||
// AddBalance 增加钱包余额
|
||||
func (r *GormWalletRepository) AddBalance(ctx context.Context, walletID string, amount string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance + ?", amount)).Error
|
||||
}
|
||||
|
||||
// SubtractBalance 减少钱包余额
|
||||
func (r *GormWalletRepository) SubtractBalance(ctx context.Context, walletID string, amount string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance - ?", amount)).Error
|
||||
}
|
||||
|
||||
// ActivateWallet 激活钱包
|
||||
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
|
||||
}
|
||||
|
||||
// DeactivateWallet 停用钱包
|
||||
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// GetStats 获取财务统计信息
|
||||
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 总钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 激活钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetUserWalletStats 获取用户钱包统计信息
|
||||
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 用户钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户激活钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 用户今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GormUserSecretsRepository 用户密钥GORM仓储实现
|
||||
type GormUserSecretsRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ domain_finance_repo.UserSecretsRepository = (*GormUserSecretsRepository)(nil)
|
||||
|
||||
// NewGormUserSecretsRepository 创建用户密钥GORM仓储
|
||||
func NewGormUserSecretsRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.UserSecretsRepository {
|
||||
return &GormUserSecretsRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建用户密钥
|
||||
func (r *GormUserSecretsRepository) Create(ctx context.Context, secrets entities.UserSecrets) (entities.UserSecrets, error) {
|
||||
r.logger.Info("创建用户密钥", zap.String("user_id", secrets.UserID))
|
||||
err := r.db.WithContext(ctx).Create(&secrets).Error
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetByID(ctx context.Context, id string) (entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&secrets).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.UserSecrets{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.UserSecrets{}, err
|
||||
}
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// Update 更新用户密钥
|
||||
func (r *GormUserSecretsRepository) Update(ctx context.Context, secrets entities.UserSecrets) error {
|
||||
r.logger.Info("更新用户密钥", zap.String("id", secrets.ID))
|
||||
return r.db.WithContext(ctx).Save(&secrets).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户密钥
|
||||
func (r *GormUserSecretsRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除用户密钥", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除用户密钥
|
||||
func (r *GormUserSecretsRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除用户密钥", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复用户密钥
|
||||
func (r *GormUserSecretsRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复用户密钥", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.UserSecrets{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计用户密钥数量
|
||||
func (r *GormUserSecretsRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查用户密钥是否存在
|
||||
func (r *GormUserSecretsRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建用户密钥
|
||||
func (r *GormUserSecretsRepository) CreateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
|
||||
r.logger.Info("批量创建用户密钥", zap.Int("count", len(secrets)))
|
||||
return r.db.WithContext(ctx).Create(&secrets).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.UserSecrets, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&secrets).Error
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新用户密钥
|
||||
func (r *GormUserSecretsRepository) UpdateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
|
||||
r.logger.Info("批量更新用户密钥", zap.Int("count", len(secrets)))
|
||||
return r.db.WithContext(ctx).Save(&secrets).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除用户密钥
|
||||
func (r *GormUserSecretsRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除用户密钥", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取用户密钥列表
|
||||
func (r *GormUserSecretsRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.UserSecrets, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return secrets, query.Find(&secrets).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormUserSecretsRepository) WithTx(tx interface{}) interfaces.Repository[entities.UserSecrets] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormUserSecretsRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找用户密钥
|
||||
func (r *GormUserSecretsRepository) FindByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// FindByAccessID 根据访问ID查找用户密钥
|
||||
func (r *GormUserSecretsRepository) FindByAccessID(ctx context.Context, accessID string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("access_id = ?", accessID).First(&secrets).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// ExistsByUserID 检查用户密钥是否存在
|
||||
func (r *GormUserSecretsRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ExistsByAccessID 检查访问ID是否存在
|
||||
func (r *GormUserSecretsRepository) ExistsByAccessID(ctx context.Context, accessID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// UpdateLastUsedAt 更新最后使用时间
|
||||
func (r *GormUserSecretsRepository) UpdateLastUsedAt(ctx context.Context, accessID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Update("last_used_at", time.Now()).Error
|
||||
}
|
||||
|
||||
// DeactivateByUserID 停用用户密钥
|
||||
func (r *GormUserSecretsRepository) DeactivateByUserID(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// RegenerateAccessKey 重新生成访问密钥
|
||||
func (r *GormUserSecretsRepository) RegenerateAccessKey(ctx context.Context, userID string, accessID, accessKey string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"access_id": accessID,
|
||||
"access_key": accessKey,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// GetExpiredSecrets 获取过期的密钥
|
||||
func (r *GormUserSecretsRepository) GetExpiredSecrets(ctx context.Context) ([]entities.UserSecrets, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Find(&secrets).Error
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// DeleteExpiredSecrets 删除过期的密钥
|
||||
func (r *GormUserSecretsRepository) DeleteExpiredSecrets(ctx context.Context) error {
|
||||
return r.db.WithContext(ctx).Where("expires_at < ?", time.Now()).Delete(&entities.UserSecrets{}).Error
|
||||
}
|
||||
|
||||
// ================ 接口要求的方法 ================
|
||||
|
||||
// GetByUserID 根据用户ID获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// GetBySecretType 根据用户ID和密钥类型获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetBySecretType(ctx context.Context, userID string, secretType string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).First(&secrets).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// ListUserSecrets 获取用户密钥列表(带分页和筛选)
|
||||
func (r *GormUserSecretsRepository) ListUserSecrets(ctx context.Context, query *queries.ListUserSecretsQuery) ([]*entities.UserSecrets, int64, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.SecretType != "" {
|
||||
dbQuery = dbQuery.Where("secret_type = ?", query.SecretType)
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&secrets).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
secretPtrs := make([]*entities.UserSecrets, len(secrets))
|
||||
for i := range secrets {
|
||||
secretPtrs[i] = &secrets[i]
|
||||
}
|
||||
|
||||
return secretPtrs, total, nil
|
||||
}
|
||||
|
||||
// UpdateSecret 更新密钥
|
||||
func (r *GormUserSecretsRepository) UpdateSecret(ctx context.Context, userID string, secretType string, secretValue string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
|
||||
Where("user_id = ? AND secret_type = ?", userID, secretType).
|
||||
Update("secret_value", secretValue).Error
|
||||
}
|
||||
|
||||
// DeleteSecret 删除密钥
|
||||
func (r *GormUserSecretsRepository) DeleteSecret(ctx context.Context, userID string, secretType string) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).
|
||||
Delete(&entities.UserSecrets{}).Error
|
||||
}
|
||||
|
||||
// ValidateSecret 验证密钥
|
||||
func (r *GormUserSecretsRepository) ValidateSecret(ctx context.Context, userID string, secretType string, secretValue string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
|
||||
Where("user_id = ? AND secret_type = ? AND secret_value = ?", userID, secretType, secretValue).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,178 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"tyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
RechargeRecordsTable = "recharge_records"
|
||||
)
|
||||
|
||||
type GormRechargeRecordRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.RechargeRecordRepository = (*GormRechargeRecordRepository)(nil)
|
||||
|
||||
func NewGormRechargeRecordRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.RechargeRecordRepository {
|
||||
return &GormRechargeRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, RechargeRecordsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Create(ctx context.Context, record entities.RechargeRecord) (entities.RechargeRecord, error) {
|
||||
err := r.CreateEntity(ctx, &record)
|
||||
return record, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByID(ctx context.Context, id string) (entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.SmartGetByID(ctx, id, &record)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.RechargeRecord{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.RechargeRecord{}, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByUserID(ctx context.Context, userID string) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByAlipayOrderID(ctx context.Context, alipayOrderID string) (*entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("alipay_order_id = ?", alipayOrderID).First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByTransferOrderID(ctx context.Context, transferOrderID string) (*entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("transfer_order_id = ?", transferOrderID).First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Update(ctx context.Context, record entities.RechargeRecord) error {
|
||||
return r.UpdateEntity(ctx, &record)
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) UpdateStatus(ctx context.Context, id string, status entities.RechargeStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.RechargeRecord{})
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
query := r.GetDB(ctx).Model(&entities.RechargeRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) CreateBatch(ctx context.Context, records []entities.RechargeRecord) error {
|
||||
return r.GetDB(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) UpdateBatch(ctx context.Context, records []entities.RechargeRecord) error {
|
||||
return r.GetDB(ctx).Save(&records).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.RechargeRecord{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.RechargeRecord] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormRechargeRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), RechargeRecordsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"tyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
WalletsTable = "wallets"
|
||||
)
|
||||
|
||||
type GormWalletRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
|
||||
|
||||
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
|
||||
return &GormWalletRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
|
||||
err := r.CreateEntity(ctx, &wallet)
|
||||
return wallet, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.SmartGetByID(ctx, id, &wallet)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Wallet{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Wallet{}, err
|
||||
}
|
||||
return wallet, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
|
||||
return r.UpdateEntity(ctx, &wallet)
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.Wallet{})
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
return r.GetDB(ctx).Create(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&wallets).Error
|
||||
return wallets, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
return r.GetDB(ctx).Save(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
query := r.GetDB(ctx).Model(&entities.Wallet{})
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
}
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
return wallets, query.Find(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
|
||||
var total decimal.Decimal
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ================ 接口要求的方法 ================
|
||||
|
||||
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// UpdateBalanceWithVersionRetry 乐观锁自动重试,最大重试maxRetry次
|
||||
func (r *GormWalletRepository) UpdateBalanceWithVersion(ctx context.Context, walletID string, newBalance string, oldVersion int64) (bool, error) {
|
||||
maxRetry := 10
|
||||
for i := 0; i < maxRetry; i++ {
|
||||
result := r.GetDB(ctx).Model(&entities.Wallet{}).
|
||||
Where("id = ? AND version = ?", walletID, oldVersion).
|
||||
Updates(map[string]interface{}{
|
||||
"balance": newBalance,
|
||||
"version": oldVersion + 1,
|
||||
})
|
||||
if result.Error != nil {
|
||||
return false, result.Error
|
||||
}
|
||||
if result.RowsAffected == 1 {
|
||||
return true, nil
|
||||
}
|
||||
// 并发冲突,重试前重新查version
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("id = ?", walletID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
oldVersion = wallet.Version
|
||||
}
|
||||
return false, fmt.Errorf("高并发下余额变动失败,请重试")
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 总钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 激活钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 用户钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户激活钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 用户今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
"tyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WalletTransactionWithProduct 包含产品名称的钱包交易记录
|
||||
type WalletTransactionWithProduct struct {
|
||||
entities.WalletTransaction
|
||||
ProductName string `json:"product_name" gorm:"column:product_name"`
|
||||
}
|
||||
|
||||
const (
|
||||
WalletTransactionsTable = "wallet_transactions"
|
||||
)
|
||||
|
||||
type GormWalletTransactionRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WalletTransactionRepository = (*GormWalletTransactionRepository)(nil)
|
||||
|
||||
func NewGormWalletTransactionRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletTransactionRepository {
|
||||
return &GormWalletTransactionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletTransactionsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Create(ctx context.Context, transaction entities.WalletTransaction) (entities.WalletTransaction, error) {
|
||||
err := r.CreateEntity(ctx, &transaction)
|
||||
return transaction, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Update(ctx context.Context, transaction entities.WalletTransaction) error {
|
||||
return r.UpdateEntity(ctx, &transaction)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByID(ctx context.Context, id string) (entities.WalletTransaction, error) {
|
||||
var transaction entities.WalletTransaction
|
||||
err := r.SmartGetByID(ctx, id, &transaction)
|
||||
return transaction, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.WalletTransaction, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
options := database.CacheListOptions{
|
||||
Where: "user_id = ?",
|
||||
Args: []interface{}{userID},
|
||||
Order: "created_at DESC",
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
err := r.ListWithCache(ctx, &transactions, 10*time.Minute, options)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByApiCallID(ctx context.Context, apiCallID string) (*entities.WalletTransaction, error) {
|
||||
var transaction entities.WalletTransaction
|
||||
err := r.FindOne(ctx, &transaction, "api_call_id = ?", apiCallID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &transaction, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 关键词筛选(支持transaction_id和product_name)
|
||||
if keyword, ok := filters["keyword"].(string); ok && keyword != "" {
|
||||
whereCondition += " AND (transaction_id LIKE ? OR product_id IN (SELECT id FROM product WHERE name LIKE ?))"
|
||||
whereArgs = append(whereArgs, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
// API调用ID筛选
|
||||
if apiCallId, ok := filters["api_call_id"].(string); ok && apiCallId != "" {
|
||||
whereCondition += " AND api_call_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+apiCallId+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ?", userId)
|
||||
}
|
||||
|
||||
// 实现interfaces.Repository接口的其他方法
|
||||
func (r *GormWalletTransactionRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.WalletTransaction, error) {
|
||||
var transactions []entities.WalletTransaction
|
||||
err := r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
return r.CountWithOptions(ctx, &entities.WalletTransaction{}, options)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) CreateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
|
||||
return r.CreateBatchEntity(ctx, &transactions)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.WalletTransaction, error) {
|
||||
var transactions []entities.WalletTransaction
|
||||
err := r.GetEntitiesByIDs(ctx, ids, &transactions)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) UpdateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
|
||||
return r.UpdateBatchEntity(ctx, &transactions)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.DeleteBatchEntity(ctx, ids, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) WithTx(tx interface{}) interfaces.Repository[entities.WalletTransaction] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletTransactionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletTransactionsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
|
||||
var transactionsWithProduct []*WalletTransactionWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "wt.user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND wt.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND wt.amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND wt.amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
query := r.GetDB(ctx).Table("wallet_transactions wt").
|
||||
Select("wt.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON wt.product_id = p.id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("wt." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("wt.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&transactionsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.WalletTransaction并构建产品名称映射
|
||||
var transactions []*entities.WalletTransaction
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, t := range transactionsWithProduct {
|
||||
transaction := t.WalletTransaction
|
||||
transactions = append(transactions, &transaction)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if t.ProductName != "" {
|
||||
productNameMap[transaction.ProductID] = t.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, transactions, total, nil
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
"tyapi-server/internal/domains/product/entities"
|
||||
"tyapi-server/internal/domains/product/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductApiConfigsTable = "product_api_configs"
|
||||
ProductApiConfigCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
type GormProductApiConfigRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ProductApiConfigRepository = (*GormProductApiConfigRepository)(nil)
|
||||
|
||||
func NewGormProductApiConfigRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductApiConfigRepository {
|
||||
return &GormProductApiConfigRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductApiConfigsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Create(ctx context.Context, config entities.ProductApiConfig) error {
|
||||
return r.CreateEntity(ctx, &config)
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Update(ctx context.Context, config entities.ProductApiConfig) error {
|
||||
return r.UpdateEntity(ctx, &config)
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductApiConfig{})
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) GetByID(ctx context.Context, id string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.SmartGetByID(ctx, id, &config)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.SmartGetByField(ctx, &config, "product_id", productID, ProductApiConfigCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductCode(ctx context.Context, productCode string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.GetDB(ctx).Joins("JOIN products ON products.id = product_api_configs.product_id").
|
||||
Where("products.code = ?", productCode).
|
||||
First(&config).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductApiConfig, error) {
|
||||
var configs []*entities.ProductApiConfig
|
||||
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&configs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) ExistsByProductID(ctx context.Context, productID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductApiConfig{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
@@ -3,44 +3,44 @@ package repositories
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/product/entities"
|
||||
"tyapi-server/internal/domains/product/repositories"
|
||||
"tyapi-server/internal/domains/product/repositories/queries"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductCategoriesTable = "product_categories"
|
||||
)
|
||||
|
||||
// GormProductCategoryRepository GORM产品分类仓储实现
|
||||
type GormProductCategoryRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductCategory{})
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.ProductCategoryRepository = (*GormProductCategoryRepository)(nil)
|
||||
|
||||
// NewGormProductCategoryRepository 创建GORM产品分类仓储
|
||||
func NewGormProductCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductCategoryRepository {
|
||||
return &GormProductCategoryRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductCategoriesTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建产品分类
|
||||
func (r *GormProductCategoryRepository) Create(ctx context.Context, entity entities.ProductCategory) (entities.ProductCategory, error) {
|
||||
r.logger.Info("创建产品分类", zap.String("id", entity.ID), zap.String("name", entity.Name))
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取产品分类
|
||||
func (r *GormProductCategoryRepository) GetByID(ctx context.Context, id string) (entities.ProductCategory, error) {
|
||||
var entity entities.ProductCategory
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.ProductCategory{}, gorm.ErrRecordNotFound
|
||||
@@ -50,22 +50,14 @@ func (r *GormProductCategoryRepository) GetByID(ctx context.Context, id string)
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新产品分类
|
||||
func (r *GormProductCategoryRepository) Update(ctx context.Context, entity entities.ProductCategory) error {
|
||||
r.logger.Info("更新产品分类", zap.String("id", entity.ID))
|
||||
return r.db.WithContext(ctx).Save(&entity).Error
|
||||
}
|
||||
|
||||
// Delete 删除产品分类
|
||||
func (r *GormProductCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除产品分类", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// FindByCode 根据编号查找产品分类
|
||||
func (r *GormProductCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductCategory, error) {
|
||||
var entity entities.ProductCategory
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&entity).Error
|
||||
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
@@ -78,7 +70,7 @@ func (r *GormProductCategoryRepository) FindByCode(ctx context.Context, code str
|
||||
// FindVisible 查找可见分类
|
||||
func (r *GormProductCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.db.WithContext(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Find(&categories).Error
|
||||
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -94,7 +86,7 @@ func (r *GormProductCategoryRepository) FindVisible(ctx context.Context) ([]*ent
|
||||
// FindEnabled 查找启用分类
|
||||
func (r *GormProductCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.db.WithContext(ctx).Where("is_enabled = ?", true).Find(&categories).Error
|
||||
err := r.GetDB(ctx).Where("is_enabled = ?", true).Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -112,7 +104,7 @@ func (r *GormProductCategoryRepository) ListCategories(ctx context.Context, quer
|
||||
var categories []entities.ProductCategory
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.ProductCategory{})
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.IsEnabled != nil {
|
||||
@@ -164,14 +156,14 @@ func (r *GormProductCategoryRepository) ListCategories(ctx context.Context, quer
|
||||
// CountEnabled 统计启用分类数量
|
||||
func (r *GormProductCategoryRepository) CountEnabled(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.ProductCategory{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountVisible 统计可见分类数量
|
||||
func (r *GormProductCategoryRepository) CountVisible(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.ProductCategory{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -180,7 +172,7 @@ func (r *GormProductCategoryRepository) CountVisible(ctx context.Context) (int64
|
||||
// Count 返回分类总数
|
||||
func (r *GormProductCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.ProductCategory{})
|
||||
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
@@ -201,29 +193,29 @@ func (r *GormProductCategoryRepository) Count(ctx context.Context, options inter
|
||||
// GetByIDs 根据ID列表获取分类
|
||||
func (r *GormProductCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&categories).Error
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建分类
|
||||
func (r *GormProductCategoryRepository) CreateBatch(ctx context.Context, categories []entities.ProductCategory) error {
|
||||
return r.db.WithContext(ctx).Create(&categories).Error
|
||||
return r.GetDB(ctx).Create(&categories).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新分类
|
||||
func (r *GormProductCategoryRepository) UpdateBatch(ctx context.Context, categories []entities.ProductCategory) error {
|
||||
return r.db.WithContext(ctx).Save(&categories).Error
|
||||
return r.GetDB(ctx).Save(&categories).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除分类
|
||||
func (r *GormProductCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.ProductCategory{}, "id IN ?", ids).Error
|
||||
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取分类列表(基础方法)
|
||||
func (r *GormProductCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
query := r.db.WithContext(ctx).Model(&entities.ProductCategory{})
|
||||
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
@@ -261,26 +253,25 @@ func (r *GormProductCategoryRepository) List(ctx context.Context, options interf
|
||||
// Exists 检查分类是否存在
|
||||
func (r *GormProductCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.ProductCategory{}).Where("id = ?", id).Count(&count).Error
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除分类
|
||||
func (r *GormProductCategoryRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
|
||||
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的分类
|
||||
func (r *GormProductCategoryRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.ProductCategory{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.ProductCategory{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormProductCategoryRepository) WithTx(tx interface{}) interfaces.Repository[entities.ProductCategory] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormProductCategoryRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductCategoriesTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
|
||||
@@ -6,40 +6,41 @@ import (
|
||||
"tyapi-server/internal/domains/product/entities"
|
||||
"tyapi-server/internal/domains/product/repositories"
|
||||
"tyapi-server/internal/domains/product/repositories/queries"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormProductRepository GORM产品仓储实现
|
||||
const (
|
||||
ProductsTable = "products"
|
||||
)
|
||||
|
||||
type GormProductRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Product{})
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.ProductRepository = (*GormProductRepository)(nil)
|
||||
|
||||
// NewGormProductRepository 创建GORM产品仓储
|
||||
func NewGormProductRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductRepository {
|
||||
return &GormProductRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建产品
|
||||
func (r *GormProductRepository) Create(ctx context.Context, entity entities.Product) (entities.Product, error) {
|
||||
r.logger.Info("创建产品", zap.String("id", entity.ID), zap.String("name", entity.Name))
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取产品
|
||||
func (r *GormProductRepository) GetByID(ctx context.Context, id string) (entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.db.WithContext(ctx).Preload("Category").Where("id = ?", id).First(&entity).Error
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Product{}, gorm.ErrRecordNotFound
|
||||
@@ -49,26 +50,17 @@ func (r *GormProductRepository) GetByID(ctx context.Context, id string) (entitie
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新产品
|
||||
func (r *GormProductRepository) Update(ctx context.Context, entity entities.Product) error {
|
||||
r.logger.Info("更新产品", zap.String("id", entity.ID))
|
||||
return r.db.WithContext(ctx).Save(&entity).Error
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// Delete 删除产品
|
||||
func (r *GormProductRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除产品", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Product{}, "id = ?", id).Error
|
||||
}
|
||||
// 其它方法同理迁移,全部用r.GetDB(ctx)
|
||||
|
||||
// FindByCode 根据编号查找产品
|
||||
func (r *GormProductRepository) FindByCode(ctx context.Context, code string) (*entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.db.WithContext(ctx).Preload("Category").Where("code = ?", code).First(&entity).Error
|
||||
err := r.SmartGetByField(ctx, &entity, "code", code) // 自动缓存
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
@@ -77,11 +69,11 @@ func (r *GormProductRepository) FindByCode(ctx context.Context, code string) (*e
|
||||
// FindByCategoryID 根据分类ID查找产品
|
||||
func (r *GormProductRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.db.WithContext(ctx).Preload("Category").Where("category_id = ?", categoryID).Find(&productEntities).Error
|
||||
err := r.GetDB(ctx).Preload("Category").Where("category_id = ?", categoryID).Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
@@ -93,11 +85,11 @@ func (r *GormProductRepository) FindByCategoryID(ctx context.Context, categoryID
|
||||
// FindVisible 查找可见产品
|
||||
func (r *GormProductRepository) FindVisible(ctx context.Context) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.db.WithContext(ctx).Preload("Category").Where("is_visible = ? AND is_enabled = ?", true, true).Find(&productEntities).Error
|
||||
err := r.GetDB(ctx).Preload("Category").Where("is_visible = ? AND is_enabled = ?", true, true).Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
@@ -109,11 +101,11 @@ func (r *GormProductRepository) FindVisible(ctx context.Context) ([]*entities.Pr
|
||||
// FindEnabled 查找启用产品
|
||||
func (r *GormProductRepository) FindEnabled(ctx context.Context) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.db.WithContext(ctx).Preload("Category").Where("is_enabled = ?", true).Find(&productEntities).Error
|
||||
err := r.GetDB(ctx).Preload("Category").Where("is_enabled = ?", true).Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
@@ -126,12 +118,12 @@ func (r *GormProductRepository) FindEnabled(ctx context.Context) ([]*entities.Pr
|
||||
func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, int64, error) {
|
||||
var productEntities []entities.Product
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Product{})
|
||||
|
||||
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Keyword != "" {
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
|
||||
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
if query.CategoryID != "" {
|
||||
@@ -152,12 +144,12 @@ func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries
|
||||
if query.IsPackage != nil {
|
||||
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
|
||||
}
|
||||
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
@@ -170,32 +162,31 @@ func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
|
||||
// 预加载分类信息并获取数据
|
||||
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
|
||||
// FindSubscribableProducts 查找可订阅产品
|
||||
func (r *GormProductRepository) FindSubscribableProducts(ctx context.Context, userID string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.db.WithContext(ctx).Where("is_enabled = ? AND is_visible = ?", true, true).Find(&productEntities).Error
|
||||
err := r.GetDB(ctx).Where("is_enabled = ? AND is_visible = ?", true, true).Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -210,7 +201,7 @@ func (r *GormProductRepository) FindSubscribableProducts(ctx context.Context, us
|
||||
// FindProductsByIDs 根据ID列表查找产品
|
||||
func (r *GormProductRepository) FindProductsByIDs(ctx context.Context, ids []string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&productEntities).Error
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -225,7 +216,7 @@ func (r *GormProductRepository) FindProductsByIDs(ctx context.Context, ids []str
|
||||
// CountByCategory 统计分类下的产品数量
|
||||
func (r *GormProductRepository) CountByCategory(ctx context.Context, categoryID string) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.Product{})
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
if categoryID != "" {
|
||||
query = query.Where("category_id = ?", categoryID)
|
||||
}
|
||||
@@ -236,34 +227,34 @@ func (r *GormProductRepository) CountByCategory(ctx context.Context, categoryID
|
||||
// CountEnabled 统计启用产品数量
|
||||
func (r *GormProductRepository) CountEnabled(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Product{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountVisible 统计可见产品数量
|
||||
func (r *GormProductRepository) CountVisible(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Product{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
}
|
||||
|
||||
// Count 返回产品总数
|
||||
func (r *GormProductRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.Product{})
|
||||
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
@@ -271,42 +262,42 @@ func (r *GormProductRepository) Count(ctx context.Context, options interfaces.Co
|
||||
// GetByIDs 根据ID列表获取产品
|
||||
func (r *GormProductRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Product, error) {
|
||||
var products []entities.Product
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&products).Error
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建产品
|
||||
func (r *GormProductRepository) CreateBatch(ctx context.Context, products []entities.Product) error {
|
||||
return r.db.WithContext(ctx).Create(&products).Error
|
||||
return r.GetDB(ctx).Create(&products).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新产品
|
||||
func (r *GormProductRepository) UpdateBatch(ctx context.Context, products []entities.Product) error {
|
||||
return r.db.WithContext(ctx).Save(&products).Error
|
||||
return r.GetDB(ctx).Save(&products).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除产品
|
||||
func (r *GormProductRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Product{}, "id IN ?", ids).Error
|
||||
return r.GetDB(ctx).Delete(&entities.Product{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取产品列表(基础方法)
|
||||
func (r *GormProductRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Product, error) {
|
||||
var products []entities.Product
|
||||
query := r.db.WithContext(ctx).Model(&entities.Product{})
|
||||
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
@@ -317,13 +308,13 @@ func (r *GormProductRepository) List(ctx context.Context, options interfaces.Lis
|
||||
}
|
||||
query = query.Order(order)
|
||||
}
|
||||
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
|
||||
err := query.Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
@@ -331,27 +322,80 @@ func (r *GormProductRepository) List(ctx context.Context, options interfaces.Lis
|
||||
// Exists 检查产品是否存在
|
||||
func (r *GormProductRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Product{}).Where("id = ?", id).Count(&count).Error
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除产品
|
||||
func (r *GormProductRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Product{}, "id = ?", id).Error
|
||||
return r.GetDB(ctx).Delete(&entities.Product{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的产品
|
||||
func (r *GormProductRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Product{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.Product{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// GetPackageItems 获取组合包项目
|
||||
func (r *GormProductRepository) GetPackageItems(ctx context.Context, packageID string) ([]*entities.ProductPackageItem, error) {
|
||||
var packageItems []entities.ProductPackageItem
|
||||
err := r.GetDB(ctx).
|
||||
Preload("Product").
|
||||
Where("package_id = ?", packageID).
|
||||
Order("sort_order ASC").
|
||||
Find(&packageItems).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductPackageItem, len(packageItems))
|
||||
for i := range packageItems {
|
||||
result[i] = &packageItems[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CreatePackageItem 创建组合包项目
|
||||
func (r *GormProductRepository) CreatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
|
||||
return r.GetDB(ctx).Create(packageItem).Error
|
||||
}
|
||||
|
||||
// GetPackageItemByID 根据ID获取组合包项目
|
||||
func (r *GormProductRepository) GetPackageItemByID(ctx context.Context, itemID string) (*entities.ProductPackageItem, error) {
|
||||
var packageItem entities.ProductPackageItem
|
||||
err := r.GetDB(ctx).
|
||||
Preload("Product").
|
||||
Preload("Package").
|
||||
Where("id = ?", itemID).
|
||||
First(&packageItem).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &packageItem, nil
|
||||
}
|
||||
|
||||
// UpdatePackageItem 更新组合包项目
|
||||
func (r *GormProductRepository) UpdatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
|
||||
return r.GetDB(ctx).Save(packageItem).Error
|
||||
}
|
||||
|
||||
// DeletePackageItem 删除组合包项目(硬删除)
|
||||
func (r *GormProductRepository) DeletePackageItem(ctx context.Context, itemID string) error {
|
||||
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "id = ?", itemID).Error
|
||||
}
|
||||
|
||||
// DeletePackageItemsByPackageID 根据组合包ID删除所有子产品(硬删除)
|
||||
func (r *GormProductRepository) DeletePackageItemsByPackageID(ctx context.Context, packageID string) error {
|
||||
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "package_id = ?", packageID).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormProductRepository) WithTx(tx interface{}) interfaces.Repository[entities.Product] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormProductRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,44 +3,46 @@ package repositories
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"time"
|
||||
"tyapi-server/internal/domains/product/entities"
|
||||
"tyapi-server/internal/domains/product/repositories"
|
||||
"tyapi-server/internal/domains/product/repositories/queries"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
SubscriptionsTable = "subscriptions"
|
||||
SubscriptionCacheTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
// GormSubscriptionRepository GORM订阅仓储实现
|
||||
type GormSubscriptionRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Subscription{})
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.SubscriptionRepository = (*GormSubscriptionRepository)(nil)
|
||||
|
||||
// NewGormSubscriptionRepository 创建GORM订阅仓储
|
||||
func NewGormSubscriptionRepository(db *gorm.DB, logger *zap.Logger) repositories.SubscriptionRepository {
|
||||
return &GormSubscriptionRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SubscriptionsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建订阅
|
||||
func (r *GormSubscriptionRepository) Create(ctx context.Context, entity entities.Subscription) (entities.Subscription, error) {
|
||||
r.logger.Info("创建订阅", zap.String("id", entity.ID), zap.String("user_id", entity.UserID))
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (r *GormSubscriptionRepository) GetByID(ctx context.Context, id string) (entities.Subscription, error) {
|
||||
var entity entities.Subscription
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Subscription{}, gorm.ErrRecordNotFound
|
||||
@@ -50,22 +52,14 @@ func (r *GormSubscriptionRepository) GetByID(ctx context.Context, id string) (en
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新订阅
|
||||
func (r *GormSubscriptionRepository) Update(ctx context.Context, entity entities.Subscription) error {
|
||||
r.logger.Info("更新订阅", zap.String("id", entity.ID))
|
||||
return r.db.WithContext(ctx).Save(&entity).Error
|
||||
}
|
||||
|
||||
// Delete 删除订阅
|
||||
func (r *GormSubscriptionRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除订阅", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&subscriptions).Error
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("user_id = ?", userID).Find(&subscriptions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,7 +75,7 @@ func (r *GormSubscriptionRepository) FindByUserID(ctx context.Context, userID st
|
||||
// FindByProductID 根据产品ID查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByProductID(ctx context.Context, productID string) ([]*entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.db.WithContext(ctx).Where("product_id = ?", productID).Find(&subscriptions).Error
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("product_id = ?", productID).Find(&subscriptions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -97,7 +91,10 @@ func (r *GormSubscriptionRepository) FindByProductID(ctx context.Context, produc
|
||||
// FindByUserAndProduct 根据用户和产品查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByUserAndProduct(ctx context.Context, userID, productID string) (*entities.Subscription, error) {
|
||||
var entity entities.Subscription
|
||||
err := r.db.WithContext(ctx).Where("user_id = ? AND product_id = ?", userID, productID).First(&entity).Error
|
||||
// 组合缓存key的条件
|
||||
where := "user_id = ? AND product_id = ?"
|
||||
ttl := SubscriptionCacheTTL // 缓存10分钟,可根据业务调整
|
||||
err := r.GetWithCache(ctx, &entity, ttl, where, userID, productID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
@@ -112,7 +109,7 @@ func (r *GormSubscriptionRepository) ListSubscriptions(ctx context.Context, quer
|
||||
var subscriptions []entities.Subscription
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Subscription{})
|
||||
dbQuery := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
@@ -172,14 +169,14 @@ func (r *GormSubscriptionRepository) ListSubscriptions(ctx context.Context, quer
|
||||
// CountByUser 统计用户订阅数量
|
||||
func (r *GormSubscriptionRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Subscription{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByProduct 统计产品订阅数量
|
||||
func (r *GormSubscriptionRepository) CountByProduct(ctx context.Context, productID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Subscription{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
@@ -188,7 +185,7 @@ func (r *GormSubscriptionRepository) CountByProduct(ctx context.Context, product
|
||||
// Count 返回订阅总数
|
||||
func (r *GormSubscriptionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.Subscription{})
|
||||
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
@@ -209,29 +206,29 @@ func (r *GormSubscriptionRepository) Count(ctx context.Context, options interfac
|
||||
// GetByIDs 根据ID列表获取订阅
|
||||
func (r *GormSubscriptionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&subscriptions).Error
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("id IN ?", ids).Find(&subscriptions).Error
|
||||
return subscriptions, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建订阅
|
||||
func (r *GormSubscriptionRepository) CreateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
|
||||
return r.db.WithContext(ctx).Create(&subscriptions).Error
|
||||
return r.GetDB(ctx).WithContext(ctx).Create(&subscriptions).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新订阅
|
||||
func (r *GormSubscriptionRepository) UpdateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
|
||||
return r.db.WithContext(ctx).Save(&subscriptions).Error
|
||||
return r.GetDB(ctx).WithContext(ctx).Save(&subscriptions).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除订阅
|
||||
func (r *GormSubscriptionRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Subscription{}, "id IN ?", ids).Error
|
||||
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取订阅列表(基础方法)
|
||||
func (r *GormSubscriptionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
query := r.db.WithContext(ctx).Model(&entities.Subscription{})
|
||||
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
@@ -269,26 +266,25 @@ func (r *GormSubscriptionRepository) List(ctx context.Context, options interface
|
||||
// Exists 检查订阅是否存在
|
||||
func (r *GormSubscriptionRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Subscription{}).Where("id = ?", id).Count(&count).Error
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除订阅
|
||||
func (r *GormSubscriptionRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
|
||||
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的订阅
|
||||
func (r *GormSubscriptionRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Subscription{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
return r.GetDB(ctx).WithContext(ctx).Unscoped().Model(&entities.Subscription{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormSubscriptionRepository) WithTx(tx interface{}) interfaces.Repository[entities.Subscription] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormSubscriptionRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), SubscriptionsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
|
||||
@@ -0,0 +1,101 @@
|
||||
// internal/infrastructure/database/repositories/user/gorm_contract_info_repository.go
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ContractInfosTable = "contract_infos"
|
||||
)
|
||||
|
||||
type GormContractInfoRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func NewGormContractInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractInfoRepository {
|
||||
return &GormContractInfoRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ContractInfosTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Save(ctx context.Context, contract *entities.ContractInfo) error {
|
||||
return r.CreateEntity(ctx, contract)
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByID(ctx context.Context, contractID string) (*entities.ContractInfo, error) {
|
||||
var contract entities.ContractInfo
|
||||
err := r.SmartGetByID(ctx, contractID, &contract)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &contract, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Delete(ctx context.Context, contractID string) error {
|
||||
return r.DeleteEntity(ctx, contractID, &entities.ContractInfo{})
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByEnterpriseInfoID(ctx context.Context, enterpriseInfoID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ?", enterpriseInfoID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByContractType(ctx context.Context, enterpriseInfoID string, contractType entities.ContractType) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ? AND contract_type = ?", enterpriseInfoID, contractType).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileID(ctx context.Context, contractFileID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ?", contractFileID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileIDExcludeID(ctx context.Context, contractFileID, excludeID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ? AND id != ?", contractFileID, excludeID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -15,20 +15,23 @@ import (
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/domains/user/repositories/queries"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
SMSCodesTable = "sms_codes"
|
||||
)
|
||||
|
||||
// GormSMSCodeRepository 短信验证码GORM仓储实现(无缓存,确保安全性)
|
||||
type GormSMSCodeRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// NewGormSMSCodeRepository 创建短信验证码仓储
|
||||
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
|
||||
return &GormSMSCodeRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SMSCodesTable),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,64 +42,54 @@ var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
|
||||
|
||||
// Create 创建短信验证码记录(不缓存,确保安全性)
|
||||
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
|
||||
r.logger.Error("创建短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
return smsCode, nil
|
||||
err := r.GetDB(ctx).Create(&smsCode).Error
|
||||
return smsCode, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&smsCode).Error; err != nil {
|
||||
err := r.GetDB(ctx).Where("id = ?", id).First(&smsCode).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
r.logger.Error("获取短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// Update 更新验证码记录
|
||||
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
|
||||
if err := r.db.WithContext(ctx).Save(&smsCode).Error; err != nil {
|
||||
r.logger.Error("更新短信验证码失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return r.GetDB(ctx).Save(&smsCode).Error
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建短信验证码
|
||||
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.db.WithContext(ctx).Create(&smsCodes).Error
|
||||
return r.GetDB(ctx).Create(&smsCodes).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
|
||||
return smsCodes, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新短信验证码
|
||||
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.db.WithContext(ctx).Save(&smsCodes).Error
|
||||
return r.GetDB(ctx).Save(&smsCodes).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除短信验证码
|
||||
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取短信验证码列表
|
||||
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
@@ -139,20 +132,20 @@ func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.Lis
|
||||
|
||||
// Delete 删除短信验证码
|
||||
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Exists 检查短信验证码是否存在
|
||||
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// Count 统计短信验证码数量
|
||||
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
@@ -172,12 +165,12 @@ func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.Co
|
||||
|
||||
// SoftDelete 软删除短信验证码
|
||||
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复短信验证码
|
||||
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
@@ -185,7 +178,7 @@ func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
|
||||
// GetByPhone 根据手机号获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
@@ -198,7 +191,7 @@ func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*
|
||||
// GetLatestByPhone 根据手机号获取最新短信验证码
|
||||
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
@@ -211,7 +204,7 @@ func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone stri
|
||||
// GetValidByPhone 根据手机号获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
@@ -227,7 +220,7 @@ func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone strin
|
||||
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
@@ -246,7 +239,7 @@ func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
db := r.db.WithContext(ctx).Model(&entities.SMSCode{})
|
||||
db := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
@@ -288,8 +281,8 @@ func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, co
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效期
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
|
||||
r.logger.Error("创建短信验证码失败", zap.Error(err))
|
||||
if err := r.GetDB(ctx).Create(&smsCode).Error; err != nil {
|
||||
r.GetLogger().Error("创建短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
@@ -299,7 +292,7 @@ func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, co
|
||||
// ValidateCode 验证验证码
|
||||
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
|
||||
Count(&count).Error
|
||||
|
||||
@@ -309,7 +302,7 @@ func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string,
|
||||
// InvalidateCode 使验证码失效
|
||||
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&entities.SMSCode{}).
|
||||
return r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND used_at IS NULL", phone).
|
||||
Update("used_at", &now).Error
|
||||
}
|
||||
@@ -320,7 +313,7 @@ func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone st
|
||||
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
|
||||
Count(&count).Error
|
||||
|
||||
@@ -333,7 +326,7 @@ func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone str
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&count).Error
|
||||
|
||||
@@ -348,7 +341,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
|
||||
startDate := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
// 总发送数
|
||||
if err := r.db.WithContext(ctx).
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, startDate).
|
||||
Count(&stats.TotalSent).Error; err != nil {
|
||||
@@ -356,7 +349,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
|
||||
}
|
||||
|
||||
// 总验证数
|
||||
if err := r.db.WithContext(ctx).
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
|
||||
Count(&stats.TotalValidated).Error; err != nil {
|
||||
@@ -370,7 +363,7 @@ func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string,
|
||||
|
||||
// 今日发送数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.db.WithContext(ctx).
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&stats.TodaySent).Error; err != nil {
|
||||
|
||||
@@ -14,242 +14,199 @@ import (
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/domains/user/repositories/queries"
|
||||
"tyapi-server/internal/shared/database"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
UsersTable = "users"
|
||||
UserCacheTTL = 30 * 60 // 30分钟
|
||||
)
|
||||
|
||||
// 定义错误常量
|
||||
var (
|
||||
// ErrUserNotFound 用户不存在错误
|
||||
ErrUserNotFound = errors.New("用户不存在")
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储实现(已移除手动缓存管理)
|
||||
type GormUserRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// NewGormUserRepository 创建用户仓储
|
||||
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
|
||||
return &GormUserRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 GormUserRepository 实现了 UserRepository 接口
|
||||
var _ repositories.UserRepository = (*GormUserRepository)(nil)
|
||||
|
||||
// ================ Repository[T] 接口实现 ================
|
||||
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
|
||||
return &GormUserRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, UsersTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建用户(自动缓存失效)
|
||||
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&user).Error; err != nil {
|
||||
r.logger.Error("创建用户失败", zap.Error(err))
|
||||
err := r.CreateEntity(ctx, &user)
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.SmartGetByID(ctx, id, &user)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, errors.New("用户不存在")
|
||||
}
|
||||
return entities.User{}, err
|
||||
}
|
||||
|
||||
r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户(自动缓存)
|
||||
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
|
||||
func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, ErrUserNotFound
|
||||
}
|
||||
r.logger.Error("根据ID查询用户失败", zap.Error(err))
|
||||
r.GetLogger().Error("根据ID查询用户失败", zap.Error(err))
|
||||
return entities.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Update 更新用户(自动缓存失效)
|
||||
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
|
||||
if err := r.db.WithContext(ctx).Save(&user).Error; err != nil {
|
||||
r.logger.Error("更新用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建用户
|
||||
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.logger.Info("批量创建用户", zap.Int("count", len(users)))
|
||||
return r.db.WithContext(ctx).Create(&users).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取用户
|
||||
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新用户
|
||||
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.logger.Info("批量更新用户", zap.Int("count", len(users)))
|
||||
return r.db.WithContext(ctx).Save(&users).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除用户
|
||||
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除用户", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
query := r.db.WithContext(ctx).Model(&entities.User{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用预加载
|
||||
for _, include := range options.Include {
|
||||
query = query.Preload(include)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return users, query.Find(&users).Error
|
||||
}
|
||||
|
||||
// ================ BaseRepository 接口实现 ================
|
||||
|
||||
// Delete 删除用户
|
||||
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("用户删除成功", zap.String("user_id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查用户是否存在
|
||||
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.User{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
query := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Joins("JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
// Count 统计用户数量
|
||||
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.User{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("username LIKE ? OR phone LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
// 如果指定了排除的用户ID,则排除该用户的记录
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("users.id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
r.GetLogger().Error("检查统一社会信用代码是否存在失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
|
||||
return r.UpdateEntity(ctx, &user)
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量创建用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Create(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量更新用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Save(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.GetLogger().Info("批量删除用户", zap.Strings("ids", ids))
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.SmartList(ctx, &users, options)
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除用户
|
||||
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复用户
|
||||
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
|
||||
// GetByPhone 根据手机号获取用户
|
||||
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.logger.Error("根据手机号查询用户失败", zap.Error(err))
|
||||
r.GetLogger().Error("根据手机号查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByUsername 根据用户名获取用户
|
||||
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if err := r.GetDB(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.logger.Error("根据用户名查询用户失败", zap.Error(err))
|
||||
r.GetLogger().Error("根据用户名查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByUserType 根据用户类型获取用户列表
|
||||
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
|
||||
var users []*entities.User
|
||||
err := r.db.WithContext(ctx).Where("user_type = ?", userType).Find(&users).Error
|
||||
err := r.GetDB(ctx).Where("user_type = ?", userType).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表(带分页和筛选)
|
||||
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
|
||||
var users []*entities.User
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
db := r.db.WithContext(ctx).Model(&entities.User{})
|
||||
// 构建查询条件,预加载企业信息
|
||||
db := r.GetDB(ctx).Model(&entities.User{}).Preload("EnterpriseInfo")
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
db = db.Where("phone LIKE ?", "%"+query.Phone+"%")
|
||||
db = db.Where("users.phone LIKE ?", "%"+query.Phone+"%")
|
||||
}
|
||||
if query.UserType != "" {
|
||||
db = db.Where("users.user_type = ?", query.UserType)
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
db = db.Where("users.active = ?", *query.IsActive)
|
||||
}
|
||||
if query.IsCertified != nil {
|
||||
db = db.Where("users.is_certified = ?", *query.IsCertified)
|
||||
}
|
||||
if query.CompanyName != "" {
|
||||
db = db.Joins("LEFT JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
db = db.Where("created_at >= ?", query.StartDate)
|
||||
db = db.Where("users.created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
db = db.Where("created_at <= ?", query.EndDate)
|
||||
db = db.Where("users.created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
@@ -266,10 +223,9 @@ func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListU
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// ValidateUser 验证用户登录
|
||||
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.db.WithContext(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
|
||||
err := r.GetDB(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -277,10 +233,9 @@ func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password s
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateLastLogin 更新最后登录时间
|
||||
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&entities.User{}).
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_login_at": &now,
|
||||
@@ -288,40 +243,35 @@ func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string)
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdatePassword 更新密码
|
||||
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.User{}).
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("password", newPassword).Error
|
||||
}
|
||||
|
||||
// CheckPassword 检查密码
|
||||
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.User{}).
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ? AND password = ?", userID, password).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ActivateUser 激活用户
|
||||
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.User{}).
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", true).Error
|
||||
}
|
||||
|
||||
// DeactivateUser 停用用户
|
||||
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.User{}).
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", false).Error
|
||||
}
|
||||
|
||||
// UpdateLoginStats 更新登录统计
|
||||
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.User{}).
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"login_count": gorm.Expr("login_count + 1"),
|
||||
@@ -329,11 +279,10 @@ func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string
|
||||
}).Error
|
||||
}
|
||||
|
||||
// GetStats 获取用户统计信息
|
||||
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.db.WithContext(ctx)
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 总用户数
|
||||
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
@@ -345,6 +294,11 @@ func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserSt
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已认证用户数
|
||||
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
@@ -359,11 +313,10 @@ func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserSt
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetStatsByDateRange 获取指定日期范围的用户统计
|
||||
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.db.WithContext(ctx)
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 指定时间范围内的注册数
|
||||
if err := db.Model(&entities.User{}).
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/domains/certification/entities/value_objects"
|
||||
"tyapi-server/internal/domains/certification/enums"
|
||||
"tyapi-server/internal/domains/certification/repositories"
|
||||
"tyapi-server/internal/domains/certification/value_objects"
|
||||
"tyapi-server/internal/shared/esign"
|
||||
)
|
||||
|
||||
@@ -250,13 +250,19 @@ func (s *CertificationEsignService) HandleContractSignCallback(
|
||||
}
|
||||
|
||||
if success {
|
||||
// 合同签署成功,认证完成
|
||||
// 合同签署成功,更新合同URL
|
||||
if err := s.commandRepo.UpdateContractInfo(ctx, cert.ID, cert.ContractFileID, cert.EsignFlowID, signedFileURL, cert.ContractSignURL); err != nil {
|
||||
s.logger.Error("更新合同URL失败", zap.Error(err))
|
||||
return fmt.Errorf("更新合同URL失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新状态到合同已签署
|
||||
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusContractSigned); err != nil {
|
||||
s.logger.Error("更新认证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新认证状态失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("认证流程完成", zap.String("certification_id", cert.ID))
|
||||
s.logger.Info("合同签署成功", zap.String("certification_id", cert.ID))
|
||||
} else {
|
||||
// 合同签署失败
|
||||
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusContractRejected); err != nil {
|
||||
|
||||
160
internal/infrastructure/external/westdex/crypto.go
vendored
Normal file
160
internal/infrastructure/external/westdex/crypto.go
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
package westdex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
const (
|
||||
KEY_SIZE = 16 // AES-128, 16 bytes
|
||||
)
|
||||
|
||||
// Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding
|
||||
func Encrypt(data, secretKey string) (string, error) {
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
ciphertext, err := aesEncrypt([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding
|
||||
func Decrypt(encodedData, secretKey string) ([]byte, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
plaintext, err := aesDecrypt(ciphertext, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// generateAESKey generates a key for AES encryption using a SHA-1 based PRNG
|
||||
func generateAESKey(length int, password []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write(password)
|
||||
state := h.Sum(nil)
|
||||
|
||||
keyBytes := make([]byte, 0, length/8)
|
||||
for len(keyBytes) < length/8 {
|
||||
h := sha1.New()
|
||||
h.Write(state)
|
||||
state = h.Sum(nil)
|
||||
keyBytes = append(keyBytes, state...)
|
||||
}
|
||||
|
||||
return keyBytes[:length/8]
|
||||
}
|
||||
|
||||
// aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding
|
||||
func aesEncrypt(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
|
||||
ciphertext := make([]byte, len(paddedPlaintext))
|
||||
mode := newECBEncrypter(block)
|
||||
mode.CryptBlocks(ciphertext, paddedPlaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding
|
||||
func aesDecrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode := newECBDecrypter(block)
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs5Unpadding(plaintext), nil
|
||||
}
|
||||
|
||||
// pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding
|
||||
func pkcs5Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
// pkcs5Unpadding removes PKCS5 padding from the input
|
||||
func pkcs5Unpadding(src []byte) []byte {
|
||||
length := len(src)
|
||||
unpadding := int(src[length-1])
|
||||
return src[:(length - unpadding)]
|
||||
}
|
||||
|
||||
// ECB mode encryption/decryption
|
||||
type ecb struct {
|
||||
b cipher.Block
|
||||
blockSize int
|
||||
}
|
||||
|
||||
func newECB(b cipher.Block) *ecb {
|
||||
return &ecb{
|
||||
b: b,
|
||||
blockSize: b.BlockSize(),
|
||||
}
|
||||
}
|
||||
|
||||
type ecbEncrypter ecb
|
||||
|
||||
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Encrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
|
||||
type ecbDecrypter ecb
|
||||
|
||||
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Decrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Md5Encrypt 用于对传入的message进行MD5加密
|
||||
func Md5Encrypt(message string) string {
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(message)) // 将字符串转换为字节切片并写入
|
||||
return hex.EncodeToString(hash.Sum(nil)) // 将哈希值转换为16进制字符串并返回
|
||||
}
|
||||
206
internal/infrastructure/external/westdex/westdex_service.go
vendored
Normal file
206
internal/infrastructure/external/westdex/westdex_service.go
vendored
Normal file
@@ -0,0 +1,206 @@
|
||||
package westdex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
"tyapi-server/internal/shared/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
)
|
||||
|
||||
type WestResp struct {
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code"`
|
||||
Data string `json:"data"`
|
||||
ID string `json:"id"`
|
||||
ErrorCode *int `json:"error_code"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
type G05HZ01WestResp struct {
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
ID string `json:"id"`
|
||||
ErrorCode *int `json:"error_code"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type WestConfig struct {
|
||||
Url string
|
||||
Key string
|
||||
SecretId string
|
||||
SecretSecondId string
|
||||
}
|
||||
|
||||
type WestDexService struct {
|
||||
config WestConfig
|
||||
}
|
||||
|
||||
// NewWestDexService 是一个构造函数,用于初始化 WestDexService
|
||||
func NewWestDexService(url, key, secretId, secretSecondId string) *WestDexService {
|
||||
return &WestDexService{
|
||||
config: WestConfig{
|
||||
Url: url,
|
||||
Key: key,
|
||||
SecretId: secretId,
|
||||
SecretSecondId: secretSecondId,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CallAPI 调用西部数据的 API
|
||||
func (w *WestDexService) CallAPI(code string, reqData map[string]interface{}) (resp []byte, err error) {
|
||||
// 生成当前的13位时间戳
|
||||
timestamp := strconv.FormatInt(time.Now().UnixNano()/int64(time.Millisecond), 10)
|
||||
|
||||
// 构造请求URL
|
||||
reqUrl := fmt.Sprintf("%s/%s/%s?timestamp=%s", w.config.Url, w.config.SecretId, code, timestamp)
|
||||
|
||||
jsonData, marshalErr := json.Marshal(reqData)
|
||||
if marshalErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, marshalErr.Error())
|
||||
}
|
||||
|
||||
// 创建HTTP POST请求
|
||||
req, newRequestErr := http.NewRequest("POST", reqUrl, bytes.NewBuffer(jsonData))
|
||||
if newRequestErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, newRequestErr.Error())
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{}
|
||||
httpResp, clientDoErr := client.Do(req)
|
||||
if clientDoErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, clientDoErr.Error())
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
closeErr := Body.Close()
|
||||
if closeErr != nil {
|
||||
}
|
||||
}(httpResp.Body)
|
||||
|
||||
// 检查请求是否成功
|
||||
if httpResp.StatusCode == 200 {
|
||||
// 读取响应体
|
||||
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
|
||||
if ReadErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, ReadErr.Error())
|
||||
}
|
||||
|
||||
// 手动调用 json.Unmarshal 触发自定义的 UnmarshalJSON 方法
|
||||
var westDexResp WestResp
|
||||
UnmarshalErr := json.Unmarshal(bodyBytes, &westDexResp)
|
||||
if UnmarshalErr != nil {
|
||||
return nil, UnmarshalErr
|
||||
}
|
||||
if westDexResp.Code != "00000" && westDexResp.Code != "200" && westDexResp.Code != "0" {
|
||||
if westDexResp.Data == "" {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, westDexResp.Message)
|
||||
}
|
||||
decryptedData, DecryptErr := crypto.WestDexDecrypt(westDexResp.Data, w.config.Key)
|
||||
if DecryptErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, DecryptErr.Error())
|
||||
}
|
||||
return decryptedData, fmt.Errorf("%w: %s", ErrDatasource, westDexResp.Message)
|
||||
}
|
||||
if westDexResp.Data == "" {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, westDexResp.Message)
|
||||
}
|
||||
decryptedData, DecryptErr := crypto.WestDexDecrypt(westDexResp.Data, w.config.Key)
|
||||
if DecryptErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, DecryptErr.Error())
|
||||
}
|
||||
|
||||
return decryptedData, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: 西部请求失败Code: %d", ErrSystem, httpResp.StatusCode)
|
||||
}
|
||||
|
||||
// G05HZ01CallAPI 调用西部数据的 G05HZ01 API
|
||||
func (w *WestDexService) G05HZ01CallAPI(code string, reqData map[string]interface{}) (resp []byte, err error) {
|
||||
// 生成当前的13位时间戳
|
||||
timestamp := strconv.FormatInt(time.Now().UnixNano()/int64(time.Millisecond), 10)
|
||||
|
||||
// 构造请求URL
|
||||
reqUrl := fmt.Sprintf("%s/%s/%s?timestamp=%s", w.config.Url, w.config.SecretSecondId, code, timestamp)
|
||||
|
||||
jsonData, marshalErr := json.Marshal(reqData)
|
||||
if marshalErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, marshalErr.Error())
|
||||
}
|
||||
|
||||
// 创建HTTP POST请求
|
||||
req, newRequestErr := http.NewRequest("POST", reqUrl, bytes.NewBuffer(jsonData))
|
||||
if newRequestErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, newRequestErr.Error())
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{}
|
||||
httpResp, clientDoErr := client.Do(req)
|
||||
if clientDoErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, clientDoErr.Error())
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
closeErr := Body.Close()
|
||||
if closeErr != nil {
|
||||
// 忽略
|
||||
}
|
||||
}(httpResp.Body)
|
||||
|
||||
if httpResp.StatusCode == 200 {
|
||||
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
|
||||
if ReadErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, ReadErr.Error())
|
||||
}
|
||||
var westDexResp G05HZ01WestResp
|
||||
UnmarshalErr := json.Unmarshal(bodyBytes, &westDexResp)
|
||||
if UnmarshalErr != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, UnmarshalErr.Error())
|
||||
}
|
||||
if westDexResp.Code != "0000" {
|
||||
if westDexResp.Data == nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, westDexResp.Message)
|
||||
} else {
|
||||
return westDexResp.Data, fmt.Errorf("%w: %s", ErrSystem, string(westDexResp.Data))
|
||||
}
|
||||
}
|
||||
if westDexResp.Data == nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, westDexResp.Message)
|
||||
}
|
||||
return westDexResp.Data, nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("%w: 西部请求失败Code: %d", ErrSystem, httpResp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WestDexService) Encrypt(data string) (string, error) {
|
||||
encryptedValue, err := crypto.WestDexEncrypt(data, w.config.Key)
|
||||
if err != nil {
|
||||
return "", ErrSystem
|
||||
}
|
||||
return encryptedValue, nil
|
||||
}
|
||||
func (w *WestDexService) Md5Encrypt(data string) string {
|
||||
return Md5Encrypt(data)
|
||||
}
|
||||
|
||||
func (w *WestDexService) GetConfig() WestConfig {
|
||||
return w.config
|
||||
}
|
||||
205
internal/infrastructure/external/yushan/yushan_service.go
vendored
Normal file
205
internal/infrastructure/external/yushan/yushan_service.go
vendored
Normal file
@@ -0,0 +1,205 @@
|
||||
package yushan
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrNotFound = errors.New("查询为空")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
)
|
||||
|
||||
type YushanConfig struct {
|
||||
URL string
|
||||
ApiKey string
|
||||
AcctID string
|
||||
}
|
||||
|
||||
type YushanService struct {
|
||||
config YushanConfig
|
||||
}
|
||||
|
||||
// NewWestDexService 是一个构造函数,用于初始化 WestDexService
|
||||
func NewYushanService(url, apiKey, acctID string) *YushanService {
|
||||
return &YushanService{
|
||||
config: YushanConfig{
|
||||
URL: url,
|
||||
ApiKey: apiKey,
|
||||
AcctID: acctID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CallAPI 调用西部数据的 API
|
||||
func (y *YushanService) CallAPI(code string, params map[string]interface{}) (respBytes []byte, err error) {
|
||||
// 获取当前时间戳
|
||||
unixMilliseconds := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
|
||||
// 生成请求序列号
|
||||
requestSN, _ := y.GenerateRandomString()
|
||||
|
||||
// 构建请求数据
|
||||
reqData := map[string]interface{}{
|
||||
"prod_id": code,
|
||||
"req_time": unixMilliseconds,
|
||||
"request_sn": requestSN,
|
||||
"req_data": params,
|
||||
}
|
||||
|
||||
// 将请求数据转换为 JSON 字节数组
|
||||
messageBytes, err := json.Marshal(reqData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
|
||||
// 获取 API 密钥
|
||||
key, err := hex.DecodeString(y.config.ApiKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
|
||||
// 使用 AES CBC 加密请求数据
|
||||
cipherText := y.AES_CBC_Encrypt(messageBytes, key)
|
||||
|
||||
// 将加密后的数据编码为 Base64 字符串
|
||||
content := base64.StdEncoding.EncodeToString(cipherText)
|
||||
|
||||
// 发起 HTTP 请求
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("POST", y.config.URL, strings.NewReader(content))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("ACCT_ID", y.config.AcctID)
|
||||
|
||||
// 执行请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var respData []byte
|
||||
|
||||
if IsJSON(string(body)) {
|
||||
respData = body
|
||||
} else {
|
||||
sDec, err := base64.StdEncoding.DecodeString(string(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
respData = y.AES_CBC_Decrypt(sDec, key)
|
||||
}
|
||||
retCode := gjson.GetBytes(respData, "retcode").String()
|
||||
|
||||
if retCode == "100000" {
|
||||
// retcode 为 100000,表示查询为空
|
||||
return nil, ErrNotFound
|
||||
} else if retCode == "000000" {
|
||||
// retcode 为 000000,表示有数据,返回 retdata
|
||||
retData := gjson.GetBytes(respData, "retdata")
|
||||
if !retData.Exists() {
|
||||
return nil, fmt.Errorf("%w: %s", ErrDatasource, "羽山请求retdata为空")
|
||||
}
|
||||
return []byte(retData.Raw), nil
|
||||
} else {
|
||||
return nil, fmt.Errorf("%w: %s", ErrDatasource, "羽山请求未知的状态码")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// GenerateRandomString 生成一个32位的随机字符串订单号
|
||||
func (y *YushanService) GenerateRandomString() (string, error) {
|
||||
// 创建一个16字节的数组
|
||||
bytes := make([]byte, 16)
|
||||
// 读取随机字节到数组中
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 将字节数组编码为16进制字符串
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// AEC加密(CBC模式)
|
||||
func (y *YushanService) AES_CBC_Encrypt(plainText []byte, key []byte) []byte {
|
||||
//指定加密算法,返回一个AES算法的Block接口对象
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
//进行填充
|
||||
plainText = Padding(plainText, block.BlockSize())
|
||||
//指定初始向量vi,长度和block的块尺寸一致
|
||||
iv := []byte("0000000000000000")
|
||||
//指定分组模式,返回一个BlockMode接口对象
|
||||
blockMode := cipher.NewCBCEncrypter(block, iv)
|
||||
//加密连续数据库
|
||||
cipherText := make([]byte, len(plainText))
|
||||
blockMode.CryptBlocks(cipherText, plainText)
|
||||
//返回base64密文
|
||||
return cipherText
|
||||
}
|
||||
|
||||
// AEC解密(CBC模式)
|
||||
func (y *YushanService) AES_CBC_Decrypt(cipherText []byte, key []byte) []byte {
|
||||
//指定解密算法,返回一个AES算法的Block接口对象
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
//指定初始化向量IV,和加密的一致
|
||||
iv := []byte("0000000000000000")
|
||||
//指定分组模式,返回一个BlockMode接口对象
|
||||
blockMode := cipher.NewCBCDecrypter(block, iv)
|
||||
//解密
|
||||
plainText := make([]byte, len(cipherText))
|
||||
blockMode.CryptBlocks(plainText, cipherText)
|
||||
//删除填充
|
||||
plainText = UnPadding(plainText)
|
||||
return plainText
|
||||
} // 对明文进行填充
|
||||
func Padding(plainText []byte, blockSize int) []byte {
|
||||
//计算要填充的长度
|
||||
n := blockSize - len(plainText)%blockSize
|
||||
//对原来的明文填充n个n
|
||||
temp := bytes.Repeat([]byte{byte(n)}, n)
|
||||
plainText = append(plainText, temp...)
|
||||
return plainText
|
||||
}
|
||||
|
||||
// 对密文删除填充
|
||||
func UnPadding(cipherText []byte) []byte {
|
||||
//取出密文最后一个字节end
|
||||
end := cipherText[len(cipherText)-1]
|
||||
//删除填充
|
||||
cipherText = cipherText[:len(cipherText)-int(end)]
|
||||
return cipherText
|
||||
}
|
||||
|
||||
// 判断字符串是否为 JSON 格式
|
||||
func IsJSON(s string) bool {
|
||||
var js interface{}
|
||||
return json.Unmarshal([]byte(s), &js) == nil
|
||||
}
|
||||
322
internal/infrastructure/http/handlers/api_handler.go
Normal file
322
internal/infrastructure/http/handlers/api_handler.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
"tyapi-server/internal/application/api"
|
||||
"tyapi-server/internal/application/api/commands"
|
||||
"tyapi-server/internal/application/api/dto"
|
||||
"tyapi-server/internal/shared/crypto"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ApiHandler API调用HTTP处理器
|
||||
type ApiHandler struct {
|
||||
appService api.ApiApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
validator interfaces.RequestValidator
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewApiHandler 创建API调用HTTP处理器
|
||||
func NewApiHandler(
|
||||
appService api.ApiApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
) *ApiHandler {
|
||||
return &ApiHandler{
|
||||
appService: appService,
|
||||
responseBuilder: responseBuilder,
|
||||
validator: validator,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleApiCall 统一API调用入口
|
||||
// @Summary API调用
|
||||
// @Description 统一API调用入口,参数加密传输
|
||||
// @Tags API调用
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.ApiCallCommand true "API调用请求"
|
||||
// @Success 200 {object} dto.ApiCallResponse "调用成功"
|
||||
// @Failure 400 {object} dto.ApiCallResponse "请求参数错误"
|
||||
// @Failure 401 {object} dto.ApiCallResponse "未授权"
|
||||
// @Failure 429 {object} dto.ApiCallResponse "请求过于频繁"
|
||||
// @Failure 500 {object} dto.ApiCallResponse "服务器内部错误"
|
||||
// @Router /api/v1/:api_name [post]
|
||||
func (h *ApiHandler) HandleApiCall(c *gin.Context) {
|
||||
// 1. 基础参数校验
|
||||
accessId := c.GetHeader("Access-Id")
|
||||
if accessId == "" {
|
||||
response := dto.NewErrorResponse(1005, "缺少Access-Id", "")
|
||||
c.JSON(200, response)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 绑定和校验请求参数
|
||||
var cmd commands.ApiCallCommand
|
||||
cmd.ClientIP = c.ClientIP()
|
||||
cmd.AccessId = accessId
|
||||
cmd.ApiName = c.Param("api_name")
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
response := dto.NewErrorResponse(1003, "请求参数结构不正确", "")
|
||||
c.JSON(200, response)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 调用应用服务
|
||||
transactionId, encryptedResp, err := h.appService.CallApi(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
// 根据错误类型返回对应的错误码
|
||||
errorCode := api.GetErrorCode(err)
|
||||
response := dto.NewErrorResponse(errorCode, err.Error(), transactionId)
|
||||
c.JSON(200, response) // API调用接口统一返回200状态码
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 返回成功响应
|
||||
response := dto.NewSuccessResponse(transactionId, encryptedResp)
|
||||
c.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetUserApiKeys 获取用户API密钥
|
||||
func (h *ApiHandler) GetUserApiKeys(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserApiKeys(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户API密钥失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取API密钥成功")
|
||||
}
|
||||
|
||||
// GetUserWhiteList 获取用户白名单列表
|
||||
func (h *ApiHandler) GetUserWhiteList(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserWhiteList(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户白名单失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取白名单成功")
|
||||
}
|
||||
|
||||
// AddWhiteListIP 添加白名单IP
|
||||
func (h *ApiHandler) AddWhiteListIP(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.WhiteListRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.appService.AddWhiteListIP(c.Request.Context(), userID, req.IPAddress)
|
||||
if err != nil {
|
||||
h.logger.Error("添加白名单IP失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "添加白名单IP成功")
|
||||
}
|
||||
|
||||
// DeleteWhiteListIP 删除白名单IP
|
||||
func (h *ApiHandler) DeleteWhiteListIP(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := c.Param("ip")
|
||||
if ipAddress == "" {
|
||||
h.responseBuilder.BadRequest(c, "IP地址不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.appService.DeleteWhiteListIP(c.Request.Context(), userID, ipAddress)
|
||||
if err != nil {
|
||||
h.logger.Error("删除白名单IP失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "删除白名单IP成功")
|
||||
}
|
||||
|
||||
// EncryptParams 加密参数接口(用于前端调试)
|
||||
// @Summary 加密参数
|
||||
// @Description 用于前端调试时加密API调用参数
|
||||
// @Tags API调试
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.EncryptCommand true "加密请求"
|
||||
// @Success 200 {object} dto.EncryptResponse "加密成功"
|
||||
// @Failure 400 {object} dto.EncryptResponse "请求参数错误"
|
||||
// @Failure 401 {object} dto.EncryptResponse "未授权"
|
||||
// @Router /api/v1/encrypt [post]
|
||||
func (h *ApiHandler) EncryptParams(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.EncryptCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取用户的SecretKey
|
||||
apiKeys, err := h.appService.GetUserApiKeys(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户API密钥失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取API密钥失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 将JSON对象转换为字节数组
|
||||
jsonData, err := json.Marshal(cmd.Data)
|
||||
if err != nil {
|
||||
h.logger.Error("序列化参数失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "参数序列化失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 加密参数
|
||||
encryptedData, err := crypto.AesEncrypt(jsonData, apiKeys.SecretKey)
|
||||
if err != nil {
|
||||
h.logger.Error("加密参数失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "加密参数失败")
|
||||
return
|
||||
}
|
||||
|
||||
response := dto.EncryptResponse{
|
||||
EncryptedData: encryptedData,
|
||||
}
|
||||
h.responseBuilder.Success(c, response, "加密成功")
|
||||
}
|
||||
|
||||
// getCurrentUserID 获取当前用户ID
|
||||
func (h *ApiHandler) getCurrentUserID(c *gin.Context) string {
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetUserApiCalls 获取用户API调用记录
|
||||
// @Summary 获取用户API调用记录
|
||||
// @Description 获取当前用户的API调用记录列表,支持分页和筛选
|
||||
// @Tags API管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param start_time query string false "开始时间 (格式: 2006-01-02 15:04:05)"
|
||||
// @Param end_time query string false "结束时间 (格式: 2006-01-02 15:04:05)"
|
||||
// @Param transaction_id query string false "交易ID"
|
||||
// @Param product_name query string false "产品名称"
|
||||
// @Param status query string false "状态 (pending/success/failed)"
|
||||
// @Success 200 {object} dto.ApiCallListResponse "获取成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/my/api-calls [get]
|
||||
func (h *ApiHandler) GetUserApiCalls(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
|
||||
filters["start_time"] = t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
|
||||
filters["end_time"] = t
|
||||
}
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId := c.Query("transaction_id"); transactionId != "" {
|
||||
filters["transaction_id"] = transactionId
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName := c.Query("product_name"); productName != "" {
|
||||
filters["product_name"] = productName
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status := c.Query("status"); status != "" {
|
||||
filters["status"] = status
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: "created_at",
|
||||
Order: "desc",
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserApiCalls(c.Request.Context(), userID, filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户API调用记录失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取API调用记录失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取API调用记录成功")
|
||||
}
|
||||
|
||||
// getIntQuery 获取整数查询参数
|
||||
func (h *ApiHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
if value := c.Query(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
@@ -38,38 +43,6 @@ func NewCertificationHandler(
|
||||
}
|
||||
|
||||
// ================ 认证申请管理 ================
|
||||
|
||||
// CreateCertification 创建认证申请
|
||||
// @Summary 创建认证申请
|
||||
// @Description 为用户创建企业认证申请
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateCertificationCommand true "创建认证申请请求"
|
||||
// @Success 201 {object} responses.CertificationResponse "认证申请创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications [post]
|
||||
func (h *CertificationHandler) CreateCertification(c *gin.Context) {
|
||||
var cmd commands.CreateCertificationCommand
|
||||
cmd.UserID = h.getCurrentUserID(c)
|
||||
if cmd.UserID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.CreateCertification(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("创建认证申请失败", zap.Error(err), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Created(c, result, "认证申请创建成功")
|
||||
}
|
||||
|
||||
// GetCertification 获取认证详情
|
||||
// @Summary 获取认证详情
|
||||
// @Description 根据认证ID获取认证详情
|
||||
@@ -77,13 +50,12 @@ func (h *CertificationHandler) CreateCertification(c *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "认证ID"
|
||||
// @Success 200 {object} responses.CertificationResponse "获取认证详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/{id} [get]
|
||||
// @Router /api/v1/certifications/details [get]
|
||||
func (h *CertificationHandler) GetCertification(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
@@ -91,21 +63,14 @@ func (h *CertificationHandler) GetCertification(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
certificationID := c.Param("id")
|
||||
if certificationID == "" {
|
||||
h.response.BadRequest(c, "认证ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
query := &queries.GetCertificationQuery{
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
result, err := h.appService.GetCertification(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证详情失败", zap.Error(err), zap.String("certification_id", certificationID))
|
||||
h.response.NotFound(c, "认证记录不存在")
|
||||
h.logger.Error("获取认证详情失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
@@ -121,14 +86,13 @@ func (h *CertificationHandler) GetCertification(c *gin.Context) {
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "认证ID"
|
||||
// @Param request body commands.SubmitEnterpriseInfoCommand true "提交企业信息请求"
|
||||
// @Success 200 {object} responses.CertificationResponse "企业信息提交成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/{id}/enterprise-info [post]
|
||||
// @Router /api/v1/certifications/enterprise-info [post]
|
||||
func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
@@ -136,22 +100,15 @@ func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
certificationID := c.Param("id")
|
||||
if certificationID == "" {
|
||||
h.response.BadRequest(c, "认证ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.SubmitEnterpriseInfoCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
cmd.CertificationID = certificationID
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.SubmitEnterpriseInfo(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("提交企业信息失败", zap.Error(err), zap.String("certification_id", certificationID))
|
||||
h.logger.Error("提交企业信息失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -159,6 +116,69 @@ func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
|
||||
h.response.Success(c, result, "企业信息提交成功")
|
||||
}
|
||||
|
||||
// ConfirmAuth 前端确认是否完成认证
|
||||
// @Summary 前端确认认证状态
|
||||
// @Description 前端轮询确认企业认证是否完成
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.ConfirmAuthCommand true "确认状态请求"
|
||||
// @Success 200 {object} responses.ConfirmStatusResponse "状态确认成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/confirm-auth [post]
|
||||
func (h *CertificationHandler) ConfirmAuth(c *gin.Context) {
|
||||
var cmd queries.ConfirmAuthCommand
|
||||
cmd.UserID = h.getCurrentUserID(c)
|
||||
if cmd.UserID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.ConfirmAuth(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("确认认证/签署状态失败", zap.Error(err), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "状态确认成功")
|
||||
}
|
||||
|
||||
// ConfirmSign 前端确认是否完成签署
|
||||
// @Summary 前端确认签署状态
|
||||
// @Description 前端轮询确认合同签署是否完成
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.ConfirmSignCommand true "确认状态请求"
|
||||
// @Success 200 {object} responses.ConfirmStatusResponse "状态确认成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/confirm-sign [post]
|
||||
func (h *CertificationHandler) ConfirmSign(c *gin.Context) {
|
||||
var cmd queries.ConfirmSignCommand
|
||||
cmd.UserID = h.getCurrentUserID(c)
|
||||
if cmd.UserID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
result, err := h.appService.ConfirmSign(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("确认认证/签署状态失败", zap.Error(err), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "状态确认成功")
|
||||
}
|
||||
|
||||
// ================ 合同管理 ================
|
||||
|
||||
// ApplyContract 申请合同签署
|
||||
@@ -176,21 +196,16 @@ func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/apply-contract [post]
|
||||
func (h *CertificationHandler) ApplyContract(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
var cmd commands.ApplyContractCommand
|
||||
cmd.UserID = h.getCurrentUserID(c)
|
||||
if cmd.UserID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.ApplyContractCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.ApplyContract(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("申请合同失败", zap.Error(err), zap.String("certification_id", cmd.CertificationID))
|
||||
h.logger.Error("申请合同失败", zap.Error(err), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
@@ -198,86 +213,6 @@ func (h *CertificationHandler) ApplyContract(c *gin.Context) {
|
||||
h.response.Success(c, result, "合同申请成功")
|
||||
}
|
||||
|
||||
// ================ 重试操作 ================
|
||||
|
||||
// RetryOperation 重试操作
|
||||
// @Summary 重试操作
|
||||
// @Description 重试失败的企业认证或合同申请操作
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.RetryOperationCommand true "重试操作请求"
|
||||
// @Success 200 {object} responses.CertificationResponse "重试操作成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/retry [post]
|
||||
func (h *CertificationHandler) RetryOperation(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.RetryOperationCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.RetryOperation(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("重试操作失败", zap.Error(err), zap.String("certification_id", cmd.CertificationID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "重试操作成功")
|
||||
}
|
||||
|
||||
// ================ 查询操作 ================
|
||||
|
||||
// GetUserCertifications 获取用户认证列表
|
||||
// @Summary 获取用户认证列表
|
||||
// @Description 获取当前用户的认证申请列表
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param status query string false "认证状态"
|
||||
// @Param include_completed query bool false "是否包含已完成"
|
||||
// @Param include_failed query bool false "是否包含失败"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Success 200 {object} responses.CertificationListResponse "获取用户认证列表成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/user [get]
|
||||
func (h *CertificationHandler) GetUserCertifications(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var query queries.GetUserCertificationsQuery
|
||||
if err := h.validator.BindAndValidate(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
query.UserID = userID
|
||||
|
||||
result, err := h.appService.GetUserCertifications(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户认证列表失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取用户认证列表成功")
|
||||
}
|
||||
|
||||
// ListCertifications 获取认证列表(管理员)
|
||||
// @Summary 获取认证列表
|
||||
// @Description 管理员获取认证申请列表
|
||||
@@ -321,46 +256,6 @@ func (h *CertificationHandler) ListCertifications(c *gin.Context) {
|
||||
h.response.Success(c, result, "获取认证列表成功")
|
||||
}
|
||||
|
||||
// GetCertificationStatistics 获取认证统计
|
||||
// @Summary 获取认证统计
|
||||
// @Description 获取认证相关的统计数据
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param start_date query string true "开始日期" format(date)
|
||||
// @Param end_date query string true "结束日期" format(date)
|
||||
// @Param period query string false "统计周期" Enums(daily, weekly, monthly, yearly) default(daily)
|
||||
// @Param group_by query []string false "分组字段"
|
||||
// @Param user_ids query []string false "用户ID列表"
|
||||
// @Param statuses query []string false "状态列表"
|
||||
// @Success 200 {object} responses.CertificationStatisticsResponse "获取认证统计成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/statistics [get]
|
||||
func (h *CertificationHandler) GetCertificationStatistics(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var query queries.GetCertificationStatisticsQuery
|
||||
if err := h.validator.BindAndValidate(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetCertificationStatistics(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证统计失败", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取认证统计成功")
|
||||
}
|
||||
|
||||
// ================ 回调处理 ================
|
||||
|
||||
// HandleEsignCallback 处理e签宝回调
|
||||
@@ -373,97 +268,118 @@ func (h *CertificationHandler) GetCertificationStatistics(c *gin.Context) {
|
||||
// @Success 200 {object} responses.CallbackResponse "回调处理成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/callbacks [post]
|
||||
// @Router /api/v1/certifications/callbacks/esign [post]
|
||||
func (h *CertificationHandler) HandleEsignCallback(c *gin.Context) {
|
||||
var cmd commands.EsignCallbackCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
// 记录请求基本信息
|
||||
h.logger.Info("收到e签宝回调请求",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("url", c.Request.URL.String()),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
zap.String("user_agent", c.GetHeader("User-Agent")),
|
||||
)
|
||||
|
||||
// 记录所有请求头
|
||||
headers := make(map[string]string)
|
||||
for key, values := range c.Request.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
h.logger.Info("回调请求头信息", zap.Any("headers", headers))
|
||||
|
||||
// 记录URL查询参数
|
||||
queryParams := make(map[string]string)
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
queryParams[key] = values[0]
|
||||
}
|
||||
}
|
||||
if len(queryParams) > 0 {
|
||||
h.logger.Info("回调URL查询参数", zap.Any("query_params", queryParams))
|
||||
}
|
||||
|
||||
result, err := h.appService.HandleEsignCallback(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("处理e签宝回调失败", zap.Error(err), zap.String("certification_id", cmd.CertificationID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
// 读取并记录请求体
|
||||
var callbackData *commands.EsignCallbackData
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
h.logger.Error("读取回调请求体失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "读取请求体失败")
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bodyBytes, &callbackData); err != nil {
|
||||
h.logger.Error("回调请求体不是有效的JSON格式", zap.Error(err))
|
||||
h.response.BadRequest(c, "请求体格式错误")
|
||||
return
|
||||
}
|
||||
h.logger.Info("回调请求体内容", zap.Any("body", callbackData))
|
||||
|
||||
// 如果后续还需要用 c.Request.Body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "回调处理成功")
|
||||
}
|
||||
// 记录Content-Type
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
h.logger.Info("回调请求Content-Type", zap.String("content_type", contentType))
|
||||
|
||||
// ================ 管理员操作 ================
|
||||
|
||||
// ForceTransitionStatus 强制状态转换(管理员)
|
||||
// @Summary 强制状态转换
|
||||
// @Description 管理员强制转换认证状态
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.ForceTransitionStatusCommand true "强制状态转换请求"
|
||||
// @Success 200 {object} responses.CertificationResponse "状态转换成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/force-transition [post]
|
||||
func (h *CertificationHandler) ForceTransitionStatus(c *gin.Context) {
|
||||
adminID := h.getCurrentUserID(c)
|
||||
if adminID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
// 记录Content-Length
|
||||
contentLength := c.GetHeader("Content-Length")
|
||||
if contentLength != "" {
|
||||
h.logger.Info("回调请求Content-Length", zap.String("content_length", contentLength))
|
||||
}
|
||||
|
||||
var cmd commands.ForceTransitionStatusCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
cmd.AdminID = adminID
|
||||
// 记录时间戳
|
||||
h.logger.Info("回调请求时间",
|
||||
zap.Time("request_time", time.Now()),
|
||||
zap.String("request_id", c.GetHeader("X-Request-ID")),
|
||||
)
|
||||
|
||||
result, err := h.appService.ForceTransitionStatus(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("强制状态转换失败", zap.Error(err), zap.String("certification_id", cmd.CertificationID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
// 记录完整的请求信息摘要
|
||||
h.logger.Info("e签宝回调完整信息摘要",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("url", c.Request.URL.String()),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.String("content_type", contentType),
|
||||
zap.Any("headers", headers),
|
||||
zap.Any("query_params", queryParams),
|
||||
zap.Any("body", callbackData),
|
||||
)
|
||||
|
||||
// 处理回调数据
|
||||
if callbackData != nil {
|
||||
// 构建请求头映射
|
||||
headers := make(map[string]string)
|
||||
for key, values := range c.Request.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
// 构建查询参数映射
|
||||
queryParams := make(map[string]string)
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
queryParams[key] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.appService.HandleEsignCallback(c.Request.Context(), &commands.EsignCallbackCommand{
|
||||
Data: callbackData,
|
||||
Headers: headers,
|
||||
QueryParams: queryParams,
|
||||
}); err != nil {
|
||||
h.logger.Error("处理e签宝回调失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "回调处理失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "状态转换成功")
|
||||
}
|
||||
|
||||
// GetSystemMonitoring 获取系统监控数据
|
||||
// @Summary 获取系统监控数据
|
||||
// @Description 获取认证系统的监控数据
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param time_range query string false "时间范围" Enums(1h, 6h, 24h, 7d, 30d) default(24h)
|
||||
// @Param metrics query []string false "监控指标"
|
||||
// @Success 200 {object} responses.SystemMonitoringResponse "获取系统监控数据成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/monitoring [get]
|
||||
func (h *CertificationHandler) GetSystemMonitoring(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var query queries.GetSystemMonitoringQuery
|
||||
if err := h.validator.BindAndValidate(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetSystemMonitoring(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取系统监控数据失败", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取系统监控数据成功")
|
||||
// 返回成功响应
|
||||
c.JSON(200, map[string]interface{}{
|
||||
"code": "200",
|
||||
"msg": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// ================ 辅助方法 ================
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
@@ -33,34 +38,6 @@ func NewFinanceHandler(
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWallet 创建钱包
|
||||
// @Summary 创建钱包
|
||||
// @Description 为用户创建新的钱包账户
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.CreateWalletCommand true "创建钱包请求"
|
||||
// @Success 201 {object} responses.WalletResponse "钱包创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 409 {object} map[string]interface{} "钱包已存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet [post]
|
||||
func (h *FinanceHandler) CreateWallet(c *gin.Context) {
|
||||
var cmd commands.CreateWalletCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.CreateWallet(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("创建钱包失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, response, "钱包创建成功")
|
||||
}
|
||||
|
||||
// GetWallet 获取钱包信息
|
||||
// @Summary 获取钱包信息
|
||||
// @Description 获取当前用户的钱包详细信息
|
||||
@@ -94,333 +71,486 @@ func (h *FinanceHandler) GetWallet(c *gin.Context) {
|
||||
h.responseBuilder.Success(c, result, "获取钱包信息成功")
|
||||
}
|
||||
|
||||
// UpdateWallet 更新钱包
|
||||
// @Summary 更新钱包信息
|
||||
// @Description 更新当前用户的钱包基本信息
|
||||
// GetUserWalletTransactions 获取用户钱包交易记录
|
||||
// @Summary 获取用户钱包交易记录
|
||||
// @Description 获取当前用户的钱包交易记录列表,支持分页和筛选
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.UpdateWalletCommand true "更新钱包请求"
|
||||
// @Success 200 {object} map[string]interface{} "钱包更新成功"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param start_time query string false "开始时间 (格式: 2006-01-02 15:04:05)"
|
||||
// @Param end_time query string false "结束时间 (格式: 2006-01-02 15:04:05)"
|
||||
// @Param transaction_id query string false "交易ID"
|
||||
// @Param product_name query string false "产品名称"
|
||||
// @Param min_amount query string false "最小金额"
|
||||
// @Param max_amount query string false "最大金额"
|
||||
// @Success 200 {object} responses.WalletTransactionListResponse "获取成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet [put]
|
||||
func (h *FinanceHandler) UpdateWallet(c *gin.Context) {
|
||||
// @Router /api/v1/finance/wallet/transactions [get]
|
||||
func (h *FinanceHandler) GetUserWalletTransactions(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.UpdateWalletCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
|
||||
filters["start_time"] = t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
|
||||
filters["end_time"] = t
|
||||
}
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
// 交易ID筛选
|
||||
if transactionId := c.Query("transaction_id"); transactionId != "" {
|
||||
filters["transaction_id"] = transactionId
|
||||
}
|
||||
|
||||
err := h.appService.UpdateWallet(c.Request.Context(), &cmd)
|
||||
// 产品名称筛选
|
||||
if productName := c.Query("product_name"); productName != "" {
|
||||
filters["product_name"] = productName
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount := c.Query("min_amount"); minAmount != "" {
|
||||
filters["min_amount"] = minAmount
|
||||
}
|
||||
if maxAmount := c.Query("max_amount"); maxAmount != "" {
|
||||
filters["max_amount"] = maxAmount
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: "created_at",
|
||||
Order: "desc",
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserWalletTransactions(c.Request.Context(), userID, filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("更新钱包失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
h.logger.Error("获取用户钱包交易记录失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取钱包交易记录失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "钱包更新成功")
|
||||
h.responseBuilder.Success(c, result, "获取钱包交易记录成功")
|
||||
}
|
||||
|
||||
// Recharge 充值
|
||||
// @Summary 钱包充值
|
||||
// @Description 为钱包进行充值操作
|
||||
// getIntQuery 获取整数查询参数
|
||||
func (h *FinanceHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
if value := c.Query(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// HandleAlipayCallback 处理支付宝支付回调
|
||||
// @Summary 支付宝支付回调
|
||||
// @Description 处理支付宝异步支付通知
|
||||
// @Tags 支付管理
|
||||
// @Accept application/x-www-form-urlencoded
|
||||
// @Produce text/plain
|
||||
// @Success 200 {string} string "success"
|
||||
// @Failure 400 {string} string "fail"
|
||||
// @Router /api/v1/finance/alipay/callback [post]
|
||||
func (h *FinanceHandler) HandleAlipayCallback(c *gin.Context) {
|
||||
// 记录回调请求信息
|
||||
h.logger.Info("收到支付宝回调请求",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("url", c.Request.URL.String()),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
zap.String("user_agent", c.GetHeader("User-Agent")),
|
||||
)
|
||||
|
||||
// 通过应用服务处理支付宝回调
|
||||
err := h.appService.HandleAlipayCallback(c.Request.Context(), c.Request)
|
||||
if err != nil {
|
||||
h.logger.Error("支付宝回调处理失败", zap.Error(err))
|
||||
c.String(400, "fail")
|
||||
return
|
||||
}
|
||||
|
||||
// 返回成功响应(支付宝要求返回success)
|
||||
c.String(200, "success")
|
||||
}
|
||||
|
||||
// HandleAlipayReturn 处理支付宝同步回调
|
||||
// @Summary 支付宝同步回调
|
||||
// @Description 处理支付宝同步支付通知,跳转到前端成功页面
|
||||
// @Tags 支付管理
|
||||
// @Accept application/x-www-form-urlencoded
|
||||
// @Produce text/html
|
||||
// @Success 200 {string} string "支付成功页面"
|
||||
// @Failure 400 {string} string "支付失败页面"
|
||||
// @Router /api/v1/finance/alipay/return [get]
|
||||
func (h *FinanceHandler) HandleAlipayReturn(c *gin.Context) {
|
||||
// 记录同步回调请求信息
|
||||
h.logger.Info("收到支付宝同步回调请求",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("url", c.Request.URL.String()),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
zap.String("user_agent", c.GetHeader("User-Agent")),
|
||||
)
|
||||
|
||||
// 获取查询参数
|
||||
outTradeNo := c.Query("out_trade_no")
|
||||
tradeNo := c.Query("trade_no")
|
||||
totalAmount := c.Query("total_amount")
|
||||
|
||||
h.logger.Info("支付宝同步回调参数",
|
||||
zap.String("out_trade_no", outTradeNo),
|
||||
zap.String("trade_no", tradeNo),
|
||||
zap.String("total_amount", totalAmount),
|
||||
)
|
||||
|
||||
// 验证必要参数
|
||||
if outTradeNo == "" {
|
||||
h.logger.Error("支付宝同步回调缺少商户订单号")
|
||||
h.redirectToFailPage(c, "", "缺少商户订单号")
|
||||
return
|
||||
}
|
||||
|
||||
// 通过应用服务处理同步回调,查询订单状态
|
||||
orderStatus, err := h.appService.HandleAlipayReturn(c.Request.Context(), outTradeNo)
|
||||
if err != nil {
|
||||
h.logger.Error("支付宝同步回调处理失败",
|
||||
zap.String("out_trade_no", outTradeNo),
|
||||
zap.Error(err))
|
||||
h.redirectToFailPage(c, outTradeNo, "订单处理失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 根据环境确定前端域名
|
||||
frontendDomain := "https://www.tianyuanapi.com"
|
||||
if gin.Mode() == gin.DebugMode {
|
||||
frontendDomain = "http://localhost:5173"
|
||||
}
|
||||
|
||||
// 根据订单状态跳转到相应页面
|
||||
switch orderStatus {
|
||||
case "TRADE_SUCCESS":
|
||||
// 支付成功,跳转到前端成功页面
|
||||
successURL := fmt.Sprintf("%s/finance/wallet/success?out_trade_no=%s&trade_no=%s&amount=%s",
|
||||
frontendDomain, outTradeNo, tradeNo, totalAmount)
|
||||
c.Redirect(http.StatusFound, successURL)
|
||||
case "WAIT_BUYER_PAY":
|
||||
// 支付处理中,跳转到处理中页面
|
||||
h.redirectToProcessingPage(c, outTradeNo, totalAmount)
|
||||
default:
|
||||
// 支付失败或取消,跳转到前端失败页面
|
||||
h.redirectToFailPage(c, outTradeNo, orderStatus)
|
||||
}
|
||||
}
|
||||
|
||||
// redirectToFailPage 跳转到失败页面
|
||||
func (h *FinanceHandler) redirectToFailPage(c *gin.Context, outTradeNo, reason string) {
|
||||
frontendDomain := "https://www.tianyuanapi.com"
|
||||
if gin.Mode() == gin.DebugMode {
|
||||
frontendDomain = "http://localhost:5173"
|
||||
}
|
||||
|
||||
failURL := fmt.Sprintf("%s/finance/wallet/fail?out_trade_no=%s&reason=%s",
|
||||
frontendDomain, outTradeNo, reason)
|
||||
c.Redirect(http.StatusFound, failURL)
|
||||
}
|
||||
|
||||
// redirectToProcessingPage 跳转到处理中页面
|
||||
func (h *FinanceHandler) redirectToProcessingPage(c *gin.Context, outTradeNo, amount string) {
|
||||
frontendDomain := "https://www.tianyuanapi.com"
|
||||
if gin.Mode() == gin.DebugMode {
|
||||
frontendDomain = "http://localhost:5173"
|
||||
}
|
||||
|
||||
processingURL := fmt.Sprintf("%s/finance/wallet/processing?out_trade_no=%s&amount=%s",
|
||||
frontendDomain, outTradeNo, amount)
|
||||
c.Redirect(http.StatusFound, processingURL)
|
||||
}
|
||||
|
||||
// CreateAlipayRecharge 创建支付宝充值订单
|
||||
// @Summary 创建支付宝充值订单
|
||||
// @Description 创建支付宝充值订单并返回支付链接
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.RechargeWalletCommand true "充值请求"
|
||||
// @Success 200 {object} responses.TransactionResponse "充值成功"
|
||||
// @Param request body commands.CreateAlipayRechargeCommand true "充值请求"
|
||||
// @Success 200 {object} responses.AlipayRechargeOrderResponse "创建充值订单成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/recharge [post]
|
||||
func (h *FinanceHandler) Recharge(c *gin.Context) {
|
||||
// @Router /api/v1/finance/wallet/alipay-recharge [post]
|
||||
func (h *FinanceHandler) CreateAlipayRecharge(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.RechargeWalletCommand
|
||||
var cmd commands.CreateAlipayRechargeCommand
|
||||
cmd.UserID = userID
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.Recharge(c.Request.Context(), &cmd)
|
||||
// 调用应用服务进行完整的业务流程编排
|
||||
result, err := h.appService.CreateAlipayRechargeOrder(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("充值失败",
|
||||
h.logger.Error("创建支付宝充值订单失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("amount", cmd.Amount),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, "创建支付宝充值订单失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("支付宝充值订单创建成功",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("out_trade_no", result.OutTradeNo),
|
||||
zap.String("amount", cmd.Amount),
|
||||
zap.String("platform", cmd.Platform),
|
||||
)
|
||||
|
||||
// 返回支付链接和订单信息
|
||||
h.responseBuilder.Success(c, result, "支付宝充值订单创建成功")
|
||||
}
|
||||
|
||||
// TransferRecharge 管理员对公转账充值
|
||||
func (h *FinanceHandler) TransferRecharge(c *gin.Context) {
|
||||
var cmd commands.TransferRechargeCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cmd.UserID == "" {
|
||||
h.responseBuilder.BadRequest(c, "缺少用户ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.TransferRecharge(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("对公转账充值失败",
|
||||
zap.String("user_id", cmd.UserID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "充值成功")
|
||||
h.responseBuilder.Success(c, result, "对公转账充值成功")
|
||||
}
|
||||
|
||||
// Withdraw 提现
|
||||
// @Summary 钱包提现
|
||||
// @Description 从钱包进行提现操作
|
||||
// GiftRecharge 管理员赠送充值
|
||||
func (h *FinanceHandler) GiftRecharge(c *gin.Context) {
|
||||
var cmd commands.GiftRechargeCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if cmd.UserID == "" {
|
||||
h.responseBuilder.BadRequest(c, "缺少用户ID")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GiftRecharge(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("赠送充值失败",
|
||||
zap.String("user_id", cmd.UserID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "赠送充值成功")
|
||||
}
|
||||
|
||||
// GetUserRechargeRecords 用户获取自己充值记录分页
|
||||
func (h *FinanceHandler) GetUserRechargeRecords(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
|
||||
filters["start_time"] = t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
|
||||
filters["end_time"] = t
|
||||
}
|
||||
}
|
||||
|
||||
// 充值类型筛选
|
||||
if rechargeType := c.Query("recharge_type"); rechargeType != "" {
|
||||
filters["recharge_type"] = rechargeType
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status := c.Query("status"); status != "" {
|
||||
filters["status"] = status
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: "created_at",
|
||||
Order: "desc",
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserRechargeRecords(c.Request.Context(), userID, filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户充值记录失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取充值记录失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取充值记录成功")
|
||||
}
|
||||
|
||||
// GetAdminRechargeRecords 管理员获取充值记录分页
|
||||
func (h *FinanceHandler) GetAdminRechargeRecords(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 用户ID筛选
|
||||
if userID := c.Query("user_id"); userID != "" {
|
||||
filters["user_id"] = userID
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
|
||||
filters["start_time"] = t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
|
||||
filters["end_time"] = t
|
||||
}
|
||||
}
|
||||
|
||||
// 充值类型筛选
|
||||
if rechargeType := c.Query("recharge_type"); rechargeType != "" {
|
||||
filters["recharge_type"] = rechargeType
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status := c.Query("status"); status != "" {
|
||||
filters["status"] = status
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: "created_at",
|
||||
Order: "desc",
|
||||
}
|
||||
|
||||
result, err := h.appService.GetAdminRechargeRecords(c.Request.Context(), filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("获取充值记录失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取充值记录失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取充值记录成功")
|
||||
}
|
||||
|
||||
// GetRechargeConfig 获取充值配置
|
||||
// @Summary 获取充值配置
|
||||
// @Description 获取当前环境的充值配置信息(最低充值金额、最高充值金额等)
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} responses.RechargeConfigResponse "获取充值配置成功"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/recharge-config [get]
|
||||
func (h *FinanceHandler) GetRechargeConfig(c *gin.Context) {
|
||||
result, err := h.appService.GetRechargeConfig(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取充值配置失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取充值配置失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取充值配置成功")
|
||||
}
|
||||
|
||||
// GetAlipayOrderStatus 获取支付宝订单状态
|
||||
// @Summary 获取支付宝订单状态
|
||||
// @Description 获取支付宝订单的当前状态,用于轮询查询
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.WithdrawWalletCommand true "提现请求"
|
||||
// @Success 200 {object} responses.TransactionResponse "提现申请已提交"
|
||||
// @Param out_trade_no query string true "商户订单号"
|
||||
// @Success 200 {object} responses.AlipayOrderStatusResponse "获取订单状态成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "订单不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/withdraw [post]
|
||||
func (h *FinanceHandler) Withdraw(c *gin.Context) {
|
||||
// @Router /api/v1/finance/wallet/alipay-order-status [get]
|
||||
func (h *FinanceHandler) GetAlipayOrderStatus(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.WithdrawWalletCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
outTradeNo := c.Query("out_trade_no")
|
||||
if outTradeNo == "" {
|
||||
h.responseBuilder.BadRequest(c, "缺少商户订单号")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.Withdraw(c.Request.Context(), &cmd)
|
||||
result, err := h.appService.GetAlipayOrderStatus(c.Request.Context(), outTradeNo)
|
||||
if err != nil {
|
||||
h.logger.Error("提现失败",
|
||||
h.logger.Error("获取支付宝订单状态失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("out_trade_no", outTradeNo),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
h.responseBuilder.BadRequest(c, "获取订单状态失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "提现申请已提交")
|
||||
}
|
||||
|
||||
// WalletTransaction 钱包交易
|
||||
// @Summary 钱包交易
|
||||
// @Description 执行钱包内部交易操作
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.WalletTransactionCommand true "交易请求"
|
||||
// @Success 200 {object} responses.TransactionResponse "交易成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/transaction [post]
|
||||
func (h *FinanceHandler) WalletTransaction(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.WalletTransactionCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.WalletTransaction(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("钱包交易失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "交易成功")
|
||||
}
|
||||
|
||||
// GetWalletStats 获取钱包统计
|
||||
// @Summary 获取钱包统计
|
||||
// @Description 获取钱包相关的统计数据
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.WalletStatsResponse "获取钱包统计成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/stats [get]
|
||||
func (h *FinanceHandler) GetWalletStats(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetWalletStats(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取钱包统计失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.InternalError(c, "获取钱包统计失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取钱包统计成功")
|
||||
}
|
||||
|
||||
// CreateUserSecrets 创建用户密钥
|
||||
// @Summary 创建用户密钥
|
||||
// @Description 为用户创建API访问密钥
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateUserSecretsCommand true "创建密钥请求"
|
||||
// @Success 201 {object} responses.UserSecretsResponse "用户密钥创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 409 {object} map[string]interface{} "密钥已存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets [post]
|
||||
func (h *FinanceHandler) CreateUserSecrets(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.CreateUserSecretsCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.CreateUserSecrets(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("创建用户密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, result, "用户密钥创建成功")
|
||||
}
|
||||
|
||||
// GetUserSecrets 获取用户密钥
|
||||
// @Summary 获取用户密钥
|
||||
// @Description 获取当前用户的API访问密钥信息
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.UserSecretsResponse "获取用户密钥成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "密钥不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets [get]
|
||||
func (h *FinanceHandler) GetUserSecrets(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
query := &queries.GetUserSecretsQuery{UserID: userID}
|
||||
result, err := h.appService.GetUserSecrets(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取用户密钥成功")
|
||||
}
|
||||
|
||||
// RegenerateAccessKey 重新生成访问密钥
|
||||
// @Summary 重新生成访问密钥
|
||||
// @Description 重新生成用户的API访问密钥
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.UserSecretsResponse "访问密钥重新生成成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "密钥不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets/regenerate [post]
|
||||
func (h *FinanceHandler) RegenerateAccessKey(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
cmd := &commands.RegenerateAccessKeyCommand{UserID: userID}
|
||||
result, err := h.appService.RegenerateAccessKey(c.Request.Context(), cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("重新生成访问密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "访问密钥重新生成成功")
|
||||
}
|
||||
|
||||
// DeactivateUserSecrets 停用用户密钥
|
||||
// @Summary 停用用户密钥
|
||||
// @Description 停用用户的API访问密钥
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} map[string]interface{} "用户密钥停用成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "密钥不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets/deactivate [post]
|
||||
func (h *FinanceHandler) DeactivateUserSecrets(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
cmd := &commands.DeactivateUserSecretsCommand{UserID: userID}
|
||||
err := h.appService.DeactivateUserSecrets(c.Request.Context(), cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("停用用户密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "用户密钥停用成功")
|
||||
h.responseBuilder.Success(c, result, "获取订单状态成功")
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"tyapi-server/internal/application/product"
|
||||
"tyapi-server/internal/application/product/dto/commands"
|
||||
"tyapi-server/internal/application/product/dto/queries"
|
||||
"tyapi-server/internal/application/product/dto/responses"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -84,7 +86,9 @@ func (h *ProductAdminHandler) CreateProduct(c *gin.Context) {
|
||||
// @Router /api/v1/admin/products/{id} [put]
|
||||
func (h *ProductAdminHandler) UpdateProduct(c *gin.Context) {
|
||||
var cmd commands.UpdateProductCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
cmd.ID = c.Param("id")
|
||||
if cmd.ID == "" {
|
||||
h.responseBuilder.BadRequest(c, "产品ID不能为空")
|
||||
return
|
||||
}
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
@@ -254,7 +258,7 @@ func (h *ProductAdminHandler) UpdateSubscriptionPrice(c *gin.Context) {
|
||||
|
||||
// ListProducts 获取产品列表(管理员)
|
||||
// @Summary 获取产品列表
|
||||
// @Description 管理员获取产品列表,支持筛选
|
||||
// @Description 管理员获取产品列表,支持筛选和分页
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
@@ -263,30 +267,76 @@ func (h *ProductAdminHandler) UpdateSubscriptionPrice(c *gin.Context) {
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param keyword query string false "搜索关键词"
|
||||
// @Param category_id query string false "分类ID"
|
||||
// @Param status query string false "产品状态"
|
||||
// @Param is_enabled query bool false "是否启用"
|
||||
// @Param is_visible query bool false "是否可见"
|
||||
// @Param is_package query bool false "是否组合包"
|
||||
// @Param sort_by query string false "排序字段"
|
||||
// @Param sort_order query string false "排序方向" Enums(asc, desc)
|
||||
// @Success 200 {object} responses.ProductListResponse "获取产品列表成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products [get]
|
||||
func (h *ProductAdminHandler) ListProducts(c *gin.Context) {
|
||||
var query queries.ListProductsQuery
|
||||
if err := h.validator.ValidateQuery(c, &query); err != nil {
|
||||
return
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 搜索关键词筛选
|
||||
if keyword := c.Query("keyword"); keyword != "" {
|
||||
filters["keyword"] = keyword
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
if query.PageSize > 100 {
|
||||
query.PageSize = 100
|
||||
// 分类ID筛选
|
||||
if categoryID := c.Query("category_id"); categoryID != "" {
|
||||
filters["category_id"] = categoryID
|
||||
}
|
||||
|
||||
result, err := h.productAppService.ListProducts(c.Request.Context(), &query)
|
||||
// 启用状态筛选
|
||||
if isEnabled := c.Query("is_enabled"); isEnabled != "" {
|
||||
if enabled, err := strconv.ParseBool(isEnabled); err == nil {
|
||||
filters["is_enabled"] = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// 可见状态筛选
|
||||
if isVisible := c.Query("is_visible"); isVisible != "" {
|
||||
if visible, err := strconv.ParseBool(isVisible); err == nil {
|
||||
filters["is_visible"] = visible
|
||||
}
|
||||
}
|
||||
|
||||
// 产品类型筛选
|
||||
if isPackage := c.Query("is_package"); isPackage != "" {
|
||||
if pkg, err := strconv.ParseBool(isPackage); err == nil {
|
||||
filters["is_package"] = pkg
|
||||
}
|
||||
}
|
||||
|
||||
// 排序字段
|
||||
sortBy := c.Query("sort_by")
|
||||
if sortBy == "" {
|
||||
sortBy = "created_at"
|
||||
}
|
||||
|
||||
// 排序方向
|
||||
sortOrder := c.Query("sort_order")
|
||||
if sortOrder == "" {
|
||||
sortOrder = "desc"
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: sortBy,
|
||||
Order: sortOrder,
|
||||
}
|
||||
|
||||
result, err := h.productAppService.ListProducts(c.Request.Context(), filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("获取产品列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取产品列表失败")
|
||||
@@ -296,6 +346,16 @@ func (h *ProductAdminHandler) ListProducts(c *gin.Context) {
|
||||
h.responseBuilder.Success(c, result, "获取产品列表成功")
|
||||
}
|
||||
|
||||
// getIntQuery 获取整数查询参数
|
||||
func (h *ProductAdminHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
if value := c.Query(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetProductDetail 获取产品详情(管理员)
|
||||
// @Summary 获取产品详情
|
||||
// @Description 管理员获取产品详细信息
|
||||
@@ -329,6 +389,233 @@ func (h *ProductAdminHandler) GetProductDetail(c *gin.Context) {
|
||||
h.responseBuilder.Success(c, result, "获取产品详情成功")
|
||||
}
|
||||
|
||||
// GetAvailableProducts 获取可选子产品列表
|
||||
// @Summary 获取可选子产品列表
|
||||
// @Description 管理员获取可选作组合包子产品的产品列表
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param exclude_package_id query string false "排除的组合包ID"
|
||||
// @Param keyword query string false "搜索关键词"
|
||||
// @Param category_id query string false "分类ID"
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(20)
|
||||
// @Success 200 {object} responses.ProductListResponse "获取可选产品列表成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/available [get]
|
||||
func (h *ProductAdminHandler) GetAvailableProducts(c *gin.Context) {
|
||||
var query queries.GetAvailableProductsQuery
|
||||
if err := h.validator.ValidateQuery(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 20
|
||||
}
|
||||
if query.PageSize > 100 {
|
||||
query.PageSize = 100
|
||||
}
|
||||
|
||||
result, err := h.productAppService.GetAvailableProducts(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取可选产品列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取可选产品列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取可选产品列表成功")
|
||||
}
|
||||
|
||||
// AddPackageItem 添加组合包子产品
|
||||
// @Summary 添加组合包子产品
|
||||
// @Description 管理员向组合包添加子产品
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "组合包ID"
|
||||
// @Param command body commands.AddPackageItemCommand true "添加子产品命令"
|
||||
// @Success 200 {object} map[string]interface{} "添加成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/package-items [post]
|
||||
func (h *ProductAdminHandler) AddPackageItem(c *gin.Context) {
|
||||
packageID := c.Param("id")
|
||||
if packageID == "" {
|
||||
h.responseBuilder.BadRequest(c, "组合包ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.AddPackageItemCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.productAppService.AddPackageItem(c.Request.Context(), packageID, &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("添加组合包子产品失败", zap.Error(err), zap.String("package_id", packageID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "添加组合包子产品成功")
|
||||
}
|
||||
|
||||
// UpdatePackageItem 更新组合包子产品
|
||||
// @Summary 更新组合包子产品
|
||||
// @Description 管理员更新组合包子产品信息
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "组合包ID"
|
||||
// @Param item_id path string true "子产品项目ID"
|
||||
// @Param command body commands.UpdatePackageItemCommand true "更新子产品命令"
|
||||
// @Success 200 {object} map[string]interface{} "更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/package-items/{item_id} [put]
|
||||
func (h *ProductAdminHandler) UpdatePackageItem(c *gin.Context) {
|
||||
packageID := c.Param("id")
|
||||
itemID := c.Param("item_id")
|
||||
if packageID == "" || itemID == "" {
|
||||
h.responseBuilder.BadRequest(c, "参数不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.UpdatePackageItemCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.productAppService.UpdatePackageItem(c.Request.Context(), packageID, itemID, &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("更新组合包子产品失败", zap.Error(err), zap.String("package_id", packageID), zap.String("item_id", itemID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "更新组合包子产品成功")
|
||||
}
|
||||
|
||||
// RemovePackageItem 移除组合包子产品
|
||||
// @Summary 移除组合包子产品
|
||||
// @Description 管理员从组合包移除子产品
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "组合包ID"
|
||||
// @Param item_id path string true "子产品项目ID"
|
||||
// @Success 200 {object} map[string]interface{} "移除成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/package-items/{item_id} [delete]
|
||||
func (h *ProductAdminHandler) RemovePackageItem(c *gin.Context) {
|
||||
packageID := c.Param("id")
|
||||
itemID := c.Param("item_id")
|
||||
if packageID == "" || itemID == "" {
|
||||
h.responseBuilder.BadRequest(c, "参数不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.productAppService.RemovePackageItem(c.Request.Context(), packageID, itemID)
|
||||
if err != nil {
|
||||
h.logger.Error("移除组合包子产品失败", zap.Error(err), zap.String("package_id", packageID), zap.String("item_id", itemID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "移除组合包子产品成功")
|
||||
}
|
||||
|
||||
// ReorderPackageItems 重新排序组合包子产品
|
||||
// @Summary 重新排序组合包子产品
|
||||
// @Description 管理员重新排序组合包子产品
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "组合包ID"
|
||||
// @Param command body commands.ReorderPackageItemsCommand true "重新排序命令"
|
||||
// @Success 200 {object} map[string]interface{} "排序成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/package-items/reorder [put]
|
||||
func (h *ProductAdminHandler) ReorderPackageItems(c *gin.Context) {
|
||||
packageID := c.Param("id")
|
||||
if packageID == "" {
|
||||
h.responseBuilder.BadRequest(c, "组合包ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.ReorderPackageItemsCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.productAppService.ReorderPackageItems(c.Request.Context(), packageID, &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("重新排序组合包子产品失败", zap.Error(err), zap.String("package_id", packageID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "重新排序组合包子产品成功")
|
||||
}
|
||||
|
||||
// UpdatePackageItems 批量更新组合包子产品
|
||||
// @Summary 批量更新组合包子产品
|
||||
// @Description 管理员批量更新组合包子产品配置
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "组合包ID"
|
||||
// @Param command body commands.UpdatePackageItemsCommand true "批量更新命令"
|
||||
// @Success 200 {object} map[string]interface{} "更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/package-items/batch [put]
|
||||
func (h *ProductAdminHandler) UpdatePackageItems(c *gin.Context) {
|
||||
packageID := c.Param("id")
|
||||
if packageID == "" {
|
||||
h.responseBuilder.BadRequest(c, "组合包ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.UpdatePackageItemsCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
err := h.productAppService.UpdatePackageItems(c.Request.Context(), packageID, &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("批量更新组合包子产品失败", zap.Error(err), zap.String("package_id", packageID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "批量更新组合包子产品成功")
|
||||
}
|
||||
|
||||
// ListCategories 获取分类列表(管理员)
|
||||
// @Summary 获取分类列表
|
||||
// @Description 管理员获取产品分类列表
|
||||
@@ -467,3 +754,166 @@ func (h *ProductAdminHandler) GetSubscriptionStats(c *gin.Context) {
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取订阅统计成功")
|
||||
}
|
||||
|
||||
// GetProductApiConfig 获取产品API配置
|
||||
// @Summary 获取产品API配置
|
||||
// @Description 管理员获取产品的API配置信息,如果不存在则返回空配置
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "产品ID"
|
||||
// @Success 200 {object} responses.ProductApiConfigResponse "获取API配置成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/api-config [get]
|
||||
func (h *ProductAdminHandler) GetProductApiConfig(c *gin.Context) {
|
||||
productID := c.Param("id")
|
||||
if productID == "" {
|
||||
h.responseBuilder.BadRequest(c, "产品ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.productAppService.GetProductApiConfig(c.Request.Context(), productID)
|
||||
if err != nil {
|
||||
// 如果是配置不存在的错误,返回空配置而不是错误
|
||||
if err.Error() == "record not found" || err.Error() == "产品API配置不存在" {
|
||||
// 返回空的配置结构,让前端可以创建新配置
|
||||
emptyConfig := &responses.ProductApiConfigResponse{
|
||||
ID: "",
|
||||
ProductID: productID,
|
||||
RequestParams: []responses.RequestParamResponse{},
|
||||
ResponseFields: []responses.ResponseFieldResponse{},
|
||||
ResponseExample: map[string]interface{}{},
|
||||
}
|
||||
h.responseBuilder.Success(c, emptyConfig, "获取API配置成功")
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Error("获取产品API配置失败", zap.Error(err), zap.String("product_id", productID))
|
||||
h.responseBuilder.NotFound(c, "产品不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取API配置成功")
|
||||
}
|
||||
|
||||
// CreateProductApiConfig 创建产品API配置
|
||||
// @Summary 创建产品API配置
|
||||
// @Description 管理员为产品创建API配置
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "产品ID"
|
||||
// @Param request body responses.ProductApiConfigResponse true "API配置信息"
|
||||
// @Success 201 {object} map[string]interface{} "API配置创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 409 {object} map[string]interface{} "API配置已存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/api-config [post]
|
||||
func (h *ProductAdminHandler) CreateProductApiConfig(c *gin.Context) {
|
||||
productID := c.Param("id")
|
||||
if productID == "" {
|
||||
h.responseBuilder.BadRequest(c, "产品ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var configResponse responses.ProductApiConfigResponse
|
||||
if err := h.validator.BindAndValidate(c, &configResponse); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.productAppService.CreateProductApiConfig(c.Request.Context(), productID, &configResponse); err != nil {
|
||||
h.logger.Error("创建产品API配置失败", zap.Error(err), zap.String("product_id", productID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, nil, "API配置创建成功")
|
||||
}
|
||||
|
||||
// UpdateProductApiConfig 更新产品API配置
|
||||
// @Summary 更新产品API配置
|
||||
// @Description 管理员更新产品的API配置
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "产品ID"
|
||||
// @Param request body responses.ProductApiConfigResponse true "API配置信息"
|
||||
// @Success 200 {object} map[string]interface{} "API配置更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品或配置不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/api-config [put]
|
||||
func (h *ProductAdminHandler) UpdateProductApiConfig(c *gin.Context) {
|
||||
productID := c.Param("id")
|
||||
if productID == "" {
|
||||
h.responseBuilder.BadRequest(c, "产品ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var configResponse responses.ProductApiConfigResponse
|
||||
if err := h.validator.BindAndValidate(c, &configResponse); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 先获取现有配置以获取配置ID
|
||||
existingConfig, err := h.productAppService.GetProductApiConfig(c.Request.Context(), productID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取现有API配置失败", zap.Error(err), zap.String("product_id", productID))
|
||||
h.responseBuilder.NotFound(c, "产品API配置不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.productAppService.UpdateProductApiConfig(c.Request.Context(), existingConfig.ID, &configResponse); err != nil {
|
||||
h.logger.Error("更新产品API配置失败", zap.Error(err), zap.String("product_id", productID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "API配置更新成功")
|
||||
}
|
||||
|
||||
// DeleteProductApiConfig 删除产品API配置
|
||||
// @Summary 删除产品API配置
|
||||
// @Description 管理员删除产品的API配置
|
||||
// @Tags 产品管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "产品ID"
|
||||
// @Success 200 {object} map[string]interface{} "API配置删除成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "产品或配置不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/products/{id}/api-config [delete]
|
||||
func (h *ProductAdminHandler) DeleteProductApiConfig(c *gin.Context) {
|
||||
productID := c.Param("id")
|
||||
if productID == "" {
|
||||
h.responseBuilder.BadRequest(c, "产品ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 先获取现有配置以获取配置ID
|
||||
existingConfig, err := h.productAppService.GetProductApiConfig(c.Request.Context(), productID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取现有API配置失败", zap.Error(err), zap.String("product_id", productID))
|
||||
h.responseBuilder.NotFound(c, "产品API配置不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.productAppService.DeleteProductApiConfig(c.Request.Context(), existingConfig.ID); err != nil {
|
||||
h.logger.Error("删除产品API配置失败", zap.Error(err), zap.String("product_id", productID))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "API配置删除成功")
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"tyapi-server/internal/application/product"
|
||||
"tyapi-server/internal/application/product/dto/commands"
|
||||
"tyapi-server/internal/application/product/dto/queries"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
// ProductHandler 产品相关HTTP处理器
|
||||
type ProductHandler struct {
|
||||
appService product.ProductApplicationService
|
||||
apiConfigService product.ProductApiConfigApplicationService
|
||||
categoryService product.CategoryApplicationService
|
||||
subAppService product.SubscriptionApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
@@ -23,6 +25,7 @@ type ProductHandler struct {
|
||||
// NewProductHandler 创建产品HTTP处理器
|
||||
func NewProductHandler(
|
||||
appService product.ProductApplicationService,
|
||||
apiConfigService product.ProductApiConfigApplicationService,
|
||||
categoryService product.CategoryApplicationService,
|
||||
subAppService product.SubscriptionApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
@@ -31,6 +34,7 @@ func NewProductHandler(
|
||||
) *ProductHandler {
|
||||
return &ProductHandler{
|
||||
appService: appService,
|
||||
apiConfigService: apiConfigService,
|
||||
categoryService: categoryService,
|
||||
subAppService: subAppService,
|
||||
responseBuilder: responseBuilder,
|
||||
@@ -49,8 +53,6 @@ func NewProductHandler(
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param keyword query string false "搜索关键词"
|
||||
// @Param category_id query string false "分类ID"
|
||||
// @Param min_price query number false "最低价格"
|
||||
// @Param max_price query number false "最高价格"
|
||||
// @Param is_enabled query bool false "是否启用"
|
||||
// @Param is_visible query bool false "是否可见"
|
||||
// @Param is_package query bool false "是否组合包"
|
||||
@@ -61,23 +63,65 @@ func NewProductHandler(
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/products [get]
|
||||
func (h *ProductHandler) ListProducts(c *gin.Context) {
|
||||
var query queries.ListProductsQuery
|
||||
if err := h.validator.ValidateQuery(c, &query); err != nil {
|
||||
return
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 搜索关键词筛选
|
||||
if keyword := c.Query("keyword"); keyword != "" {
|
||||
filters["keyword"] = keyword
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
if query.PageSize > 100 {
|
||||
query.PageSize = 100
|
||||
// 分类ID筛选
|
||||
if categoryID := c.Query("category_id"); categoryID != "" {
|
||||
filters["category_id"] = categoryID
|
||||
}
|
||||
|
||||
result, err := h.appService.ListProducts(c.Request.Context(), &query)
|
||||
// 启用状态筛选
|
||||
if isEnabled := c.Query("is_enabled"); isEnabled != "" {
|
||||
if enabled, err := strconv.ParseBool(isEnabled); err == nil {
|
||||
filters["is_enabled"] = enabled
|
||||
}
|
||||
}
|
||||
|
||||
// 可见状态筛选
|
||||
if isVisible := c.Query("is_visible"); isVisible != "" {
|
||||
if visible, err := strconv.ParseBool(isVisible); err == nil {
|
||||
filters["is_visible"] = visible
|
||||
}
|
||||
}
|
||||
|
||||
// 产品类型筛选
|
||||
if isPackage := c.Query("is_package"); isPackage != "" {
|
||||
if pkg, err := strconv.ParseBool(isPackage); err == nil {
|
||||
filters["is_package"] = pkg
|
||||
}
|
||||
}
|
||||
|
||||
// 排序字段
|
||||
sortBy := c.Query("sort_by")
|
||||
if sortBy == "" {
|
||||
sortBy = "created_at"
|
||||
}
|
||||
|
||||
// 排序方向
|
||||
sortOrder := c.Query("sort_order")
|
||||
if sortOrder == "" {
|
||||
sortOrder = "desc"
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: sortBy,
|
||||
Order: sortOrder,
|
||||
}
|
||||
|
||||
result, err := h.appService.ListProducts(c.Request.Context(), filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("获取产品列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取产品列表失败")
|
||||
@@ -87,6 +131,16 @@ func (h *ProductHandler) ListProducts(c *gin.Context) {
|
||||
h.responseBuilder.Success(c, result, "获取产品列表成功")
|
||||
}
|
||||
|
||||
// getIntQuery 获取整数查询参数
|
||||
func (h *ProductHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
if value := c.Query(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetProductDetail 获取产品详情
|
||||
// @Summary 获取产品详情
|
||||
// @Description 根据产品ID获取产品详细信息
|
||||
@@ -176,6 +230,62 @@ func (h *ProductHandler) GetProductStats(c *gin.Context) {
|
||||
h.responseBuilder.Success(c, result, "获取产品统计成功")
|
||||
}
|
||||
|
||||
// GetProductApiConfig 获取产品API配置
|
||||
// @Summary 获取产品API配置
|
||||
// @Description 根据产品ID获取API配置信息
|
||||
// @Tags 产品API配置
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "产品ID"
|
||||
// @Success 200 {object} responses.ProductApiConfigResponse "获取成功"
|
||||
// @Failure 400 {object} interfaces.APIResponse "请求参数错误"
|
||||
// @Failure 404 {object} interfaces.APIResponse "配置不存在"
|
||||
// @Router /api/v1/products/{id}/api-config [get]
|
||||
func (h *ProductHandler) GetProductApiConfig(c *gin.Context) {
|
||||
productID := c.Param("id")
|
||||
if productID == "" {
|
||||
h.responseBuilder.BadRequest(c, "产品ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
config, err := h.apiConfigService.GetProductApiConfig(c.Request.Context(), productID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取产品API配置失败", zap.Error(err), zap.String("product_id", productID))
|
||||
h.responseBuilder.NotFound(c, "产品API配置不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, config, "获取产品API配置成功")
|
||||
}
|
||||
|
||||
// GetProductApiConfigByCode 根据产品代码获取API配置
|
||||
// @Summary 根据产品代码获取API配置
|
||||
// @Description 根据产品代码获取API配置信息
|
||||
// @Tags 产品API配置
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param product_code path string true "产品代码"
|
||||
// @Success 200 {object} responses.ProductApiConfigResponse "获取成功"
|
||||
// @Failure 400 {object} interfaces.APIResponse "请求参数错误"
|
||||
// @Failure 404 {object} interfaces.APIResponse "配置不存在"
|
||||
// @Router /api/v1/products/code/{product_code}/api-config [get]
|
||||
func (h *ProductHandler) GetProductApiConfigByCode(c *gin.Context) {
|
||||
productCode := c.Param("product_code")
|
||||
if productCode == "" {
|
||||
h.responseBuilder.BadRequest(c, "产品代码不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
config, err := h.apiConfigService.GetProductApiConfigByCode(c.Request.Context(), productCode)
|
||||
if err != nil {
|
||||
h.logger.Error("根据产品代码获取API配置失败", zap.Error(err), zap.String("product_code", productCode))
|
||||
h.responseBuilder.NotFound(c, "产品API配置不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, config, "获取产品API配置成功")
|
||||
}
|
||||
|
||||
// ================ 分类相关方法 ================
|
||||
|
||||
// ListCategories 获取分类列表
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/user"
|
||||
"tyapi-server/internal/application/user/dto/commands"
|
||||
"tyapi-server/internal/application/user/dto/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
)
|
||||
@@ -240,6 +243,109 @@ func (h *UserHandler) ResetPassword(c *gin.Context) {
|
||||
h.response.Success(c, nil, "密码重置成功")
|
||||
}
|
||||
|
||||
// ListUsers 管理员查看用户列表
|
||||
// @Summary 管理员查看用户列表
|
||||
// @Description 管理员查看用户列表,支持分页和筛选
|
||||
// @Tags 用户管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param phone query string false "手机号筛选"
|
||||
// @Param user_type query string false "用户类型筛选" Enums(user,admin)
|
||||
// @Param is_active query bool false "是否激活筛选"
|
||||
// @Param is_certified query bool false "是否已认证筛选"
|
||||
// @Param company_name query string false "企业名称筛选"
|
||||
// @Param start_date query string false "开始日期" format(date)
|
||||
// @Param end_date query string false "结束日期" format(date)
|
||||
// @Success 200 {object} responses.UserListResponse "用户列表"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/admin/list [get]
|
||||
func (h *UserHandler) ListUsers(c *gin.Context) {
|
||||
// 检查管理员权限
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// 构建查询参数
|
||||
query := &queries.ListUsersQuery{
|
||||
Page: 1,
|
||||
PageSize: 10,
|
||||
}
|
||||
|
||||
// 从查询参数中获取筛选条件
|
||||
if page := c.Query("page"); page != "" {
|
||||
if pageNum, err := strconv.Atoi(page); err == nil && pageNum > 0 {
|
||||
query.Page = pageNum
|
||||
}
|
||||
}
|
||||
|
||||
if pageSize := c.Query("page_size"); pageSize != "" {
|
||||
if size, err := strconv.Atoi(pageSize); err == nil && size > 0 && size <= 100 {
|
||||
query.PageSize = size
|
||||
}
|
||||
}
|
||||
|
||||
query.Phone = c.Query("phone")
|
||||
query.UserType = c.Query("user_type")
|
||||
query.CompanyName = c.Query("company_name")
|
||||
query.StartDate = c.Query("start_date")
|
||||
query.EndDate = c.Query("end_date")
|
||||
|
||||
// 处理布尔值参数
|
||||
if isActive := c.Query("is_active"); isActive != "" {
|
||||
if active, err := strconv.ParseBool(isActive); err == nil {
|
||||
query.IsActive = &active
|
||||
}
|
||||
}
|
||||
|
||||
if isCertified := c.Query("is_certified"); isCertified != "" {
|
||||
if certified, err := strconv.ParseBool(isCertified); err == nil {
|
||||
query.IsCertified = &certified
|
||||
}
|
||||
}
|
||||
|
||||
// 调用应用服务
|
||||
resp, err := h.appService.ListUsers(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户列表失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "获取用户列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, resp, "获取用户列表成功")
|
||||
}
|
||||
|
||||
// GetUserStats 管理员获取用户统计信息
|
||||
// @Summary 管理员获取用户统计信息
|
||||
// @Description 管理员获取用户统计信息,包括总用户数、活跃用户数、已认证用户数
|
||||
// @Tags 用户管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.UserStatsResponse "用户统计信息"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/admin/stats [get]
|
||||
func (h *UserHandler) GetUserStats(c *gin.Context) {
|
||||
// 调用应用服务
|
||||
resp, err := h.appService.GetUserStats(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户统计信息失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "获取用户统计信息失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, resp, "获取用户统计信息成功")
|
||||
}
|
||||
|
||||
// getCurrentUserID 获取当前用户ID
|
||||
func (h *UserHandler) getCurrentUserID(c *gin.Context) string {
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
|
||||
59
internal/infrastructure/http/routes/api_routes.go
Normal file
59
internal/infrastructure/http/routes/api_routes.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/infrastructure/http/handlers"
|
||||
sharedhttp "tyapi-server/internal/shared/http"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ApiRoutes API路由注册器
|
||||
type ApiRoutes struct {
|
||||
apiHandler *handlers.ApiHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
domainAuthMiddleware *middleware.DomainAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewApiRoutes 创建API路由注册器
|
||||
func NewApiRoutes(
|
||||
apiHandler *handlers.ApiHandler,
|
||||
authMiddleware *middleware.JWTAuthMiddleware,
|
||||
domainAuthMiddleware *middleware.DomainAuthMiddleware,
|
||||
logger *zap.Logger,
|
||||
) *ApiRoutes {
|
||||
return &ApiRoutes{
|
||||
apiHandler: apiHandler,
|
||||
authMiddleware: authMiddleware,
|
||||
domainAuthMiddleware: domainAuthMiddleware,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册相关路由
|
||||
func (r *ApiRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// API路由组,需要用户认证
|
||||
engine := router.GetEngine()
|
||||
apiGroup := engine.Group("/api/v1")
|
||||
|
||||
{
|
||||
apiGroup.POST("/:api_name", r.domainAuthMiddleware.Handle(""), r.apiHandler.HandleApiCall)
|
||||
|
||||
// 加密接口(用于前端调试)
|
||||
apiGroup.POST("/encrypt", r.authMiddleware.Handle(), r.apiHandler.EncryptParams)
|
||||
|
||||
// API密钥管理接口
|
||||
apiGroup.GET("/api-keys", r.authMiddleware.Handle(), r.apiHandler.GetUserApiKeys)
|
||||
|
||||
// 白名单管理接口
|
||||
apiGroup.GET("/white-list", r.authMiddleware.Handle(), r.apiHandler.GetUserWhiteList)
|
||||
apiGroup.POST("/white-list", r.authMiddleware.Handle(), r.apiHandler.AddWhiteListIP)
|
||||
apiGroup.DELETE("/white-list/:ip", r.authMiddleware.Handle(), r.apiHandler.DeleteWhiteListIP)
|
||||
|
||||
// API调用记录接口
|
||||
apiGroup.GET("/my/api-calls", r.authMiddleware.Handle(), r.apiHandler.GetUserApiCalls)
|
||||
}
|
||||
|
||||
r.logger.Info("API路由注册完成")
|
||||
}
|
||||
@@ -43,31 +43,29 @@ func (r *CertificationRoutes) Register(router *http.GinRouter) {
|
||||
authGroup := certificationGroup.Group("")
|
||||
authGroup.Use(r.auth.Handle())
|
||||
{
|
||||
authGroup.GET("/user", r.handler.GetUserCertifications) // 获取用户认证列表
|
||||
authGroup.GET("", r.handler.ListCertifications) // 查询认证列表(管理员)
|
||||
authGroup.GET("/statistics", r.handler.GetCertificationStatistics) // 获取认证统计
|
||||
authGroup.GET("", r.handler.ListCertifications) // 查询认证列表(管理员)
|
||||
|
||||
// 1. 获取认证详情
|
||||
authGroup.GET("/:id", r.handler.GetCertification)
|
||||
authGroup.GET("/details", r.handler.GetCertification)
|
||||
|
||||
// 2. 提交企业信息
|
||||
authGroup.POST("/:id/enterprise-info", r.handler.SubmitEnterpriseInfo)
|
||||
authGroup.POST("/enterprise-info", r.handler.SubmitEnterpriseInfo)
|
||||
|
||||
// 合同管理
|
||||
authGroup.POST("/apply-contract", r.handler.ApplyContract) // 申请合同签署
|
||||
// 3. 申请合同签署
|
||||
authGroup.POST("/apply-contract", r.handler.ApplyContract)
|
||||
|
||||
// 重试操作
|
||||
authGroup.POST("/retry", r.handler.RetryOperation) // 重试操作
|
||||
// 前端确认是否完成认证
|
||||
authGroup.POST("/confirm-auth", r.handler.ConfirmAuth)
|
||||
|
||||
// 前端确认是否完成签署
|
||||
authGroup.POST("/confirm-sign", r.handler.ConfirmSign)
|
||||
|
||||
// 管理员操作
|
||||
authGroup.POST("/force-transition", r.handler.ForceTransitionStatus) // 强制状态转换
|
||||
authGroup.GET("/monitoring", r.handler.GetSystemMonitoring) // 获取系统监控数据
|
||||
}
|
||||
|
||||
// 回调路由(不需要认证,但需要验证签名)
|
||||
callbackGroup := certificationGroup.Group("/callbacks")
|
||||
{
|
||||
callbackGroup.POST("", r.handler.HandleEsignCallback) // e签宝回调(统一处理企业认证和合同签署回调)
|
||||
callbackGroup.POST("/esign", r.handler.HandleEsignCallback) // e签宝回调(统一处理企业认证和合同签署回调)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,7 +85,7 @@ func (r *CertificationRoutes) GetRoutes() []RouteInfo {
|
||||
{Method: "POST", Path: "/api/v1/certifications/retry", Handler: "RetryOperation", Auth: true},
|
||||
{Method: "POST", Path: "/api/v1/certifications/force-transition", Handler: "ForceTransitionStatus", Auth: true},
|
||||
{Method: "GET", Path: "/api/v1/certifications/monitoring", Handler: "GetSystemMonitoring", Auth: true},
|
||||
{Method: "POST", Path: "/api/v1/certifications/callbacks", Handler: "HandleEsignCallback", Auth: false},
|
||||
{Method: "POST", Path: "/api/v1/certifications/callbacks/esign", Handler: "HandleEsignCallback", Auth: false},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,51 +10,61 @@ import (
|
||||
|
||||
// FinanceRoutes 财务路由注册器
|
||||
type FinanceRoutes struct {
|
||||
financeHandler *handlers.FinanceHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
logger *zap.Logger
|
||||
financeHandler *handlers.FinanceHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
adminAuthMiddleware *middleware.AdminAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewFinanceRoutes 创建财务路由注册器
|
||||
func NewFinanceRoutes(
|
||||
financeHandler *handlers.FinanceHandler,
|
||||
authMiddleware *middleware.JWTAuthMiddleware,
|
||||
adminAuthMiddleware *middleware.AdminAuthMiddleware,
|
||||
logger *zap.Logger,
|
||||
) *FinanceRoutes {
|
||||
return &FinanceRoutes{
|
||||
financeHandler: financeHandler,
|
||||
authMiddleware: authMiddleware,
|
||||
logger: logger,
|
||||
financeHandler: financeHandler,
|
||||
authMiddleware: authMiddleware,
|
||||
adminAuthMiddleware: adminAuthMiddleware,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册财务相关路由
|
||||
func (r *FinanceRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// 财务路由组,需要用户认证
|
||||
engine := router.GetEngine()
|
||||
|
||||
// 支付宝回调路由(不需要认证)
|
||||
alipayGroup := engine.Group("/api/v1/finance/alipay")
|
||||
{
|
||||
alipayGroup.POST("/callback", r.financeHandler.HandleAlipayCallback) // 支付宝异步回调
|
||||
alipayGroup.GET("/return", r.financeHandler.HandleAlipayReturn) // 支付宝同步回调
|
||||
}
|
||||
|
||||
// 财务路由组,需要用户认证
|
||||
financeGroup := engine.Group("/api/v1/finance")
|
||||
financeGroup.Use(r.authMiddleware.Handle())
|
||||
{
|
||||
// 钱包相关路由
|
||||
walletGroup := financeGroup.Group("/wallet")
|
||||
{
|
||||
walletGroup.POST("", r.financeHandler.CreateWallet) // 创建钱包
|
||||
walletGroup.GET("", r.financeHandler.GetWallet) // 获取钱包信息
|
||||
walletGroup.PUT("", r.financeHandler.UpdateWallet) // 更新钱包
|
||||
walletGroup.POST("/recharge", r.financeHandler.Recharge) // 充值
|
||||
walletGroup.POST("/withdraw", r.financeHandler.Withdraw) // 提现
|
||||
walletGroup.POST("/transaction", r.financeHandler.WalletTransaction) // 钱包交易
|
||||
walletGroup.GET("/stats", r.financeHandler.GetWalletStats) // 获取钱包统计
|
||||
walletGroup.GET("", r.financeHandler.GetWallet) // 获取钱包信息
|
||||
walletGroup.GET("/transactions", r.financeHandler.GetUserWalletTransactions) // 获取钱包交易记录
|
||||
walletGroup.GET("/recharge-config", r.financeHandler.GetRechargeConfig) // 获取充值配置
|
||||
walletGroup.POST("/alipay-recharge", r.financeHandler.CreateAlipayRecharge) // 创建支付宝充值订单
|
||||
walletGroup.GET("/recharge-records", r.financeHandler.GetUserRechargeRecords) // 用户充值记录分页
|
||||
walletGroup.GET("/alipay-order-status", r.financeHandler.GetAlipayOrderStatus) // 获取支付宝订单状态
|
||||
}
|
||||
}
|
||||
|
||||
// 用户密钥相关路由
|
||||
secretsGroup := financeGroup.Group("/secrets")
|
||||
{
|
||||
secretsGroup.POST("", r.financeHandler.CreateUserSecrets) // 创建用户密钥
|
||||
secretsGroup.GET("", r.financeHandler.GetUserSecrets) // 获取用户密钥
|
||||
secretsGroup.POST("/regenerate", r.financeHandler.RegenerateAccessKey) // 重新生成访问密钥
|
||||
secretsGroup.POST("/deactivate", r.financeHandler.DeactivateUserSecrets) // 停用用户密钥
|
||||
}
|
||||
// 管理员财务路由组
|
||||
adminFinanceGroup := engine.Group("/api/v1/admin/finance")
|
||||
adminFinanceGroup.Use(r.adminAuthMiddleware.Handle())
|
||||
{
|
||||
adminFinanceGroup.POST("/transfer-recharge", r.financeHandler.TransferRecharge) // 对公转账充值
|
||||
adminFinanceGroup.POST("/gift-recharge", r.financeHandler.GiftRecharge) // 赠送充值
|
||||
adminFinanceGroup.GET("/recharge-records", r.financeHandler.GetAdminRechargeRecords) // 管理员充值记录分页
|
||||
}
|
||||
|
||||
r.logger.Info("财务路由注册完成")
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
package routes
|
||||
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/infrastructure/http/handlers"
|
||||
sharedhttp "tyapi-server/internal/shared/http"
|
||||
@@ -31,19 +31,33 @@ func (r *ProductAdminRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// 管理员路由组
|
||||
engine := router.GetEngine()
|
||||
adminGroup := engine.Group("/api/v1/admin")
|
||||
adminGroup.Use(r.auth.Handle()) // JWT认证
|
||||
adminGroup.Use(r.auth.Handle()) // JWT认证
|
||||
adminGroup.Use(r.admin.Handle()) // 管理员权限验证
|
||||
{
|
||||
// 产品管理
|
||||
products := adminGroup.Group("/products")
|
||||
{
|
||||
products.GET("", r.handler.ListProducts)
|
||||
products.GET("/available", r.handler.GetAvailableProducts)
|
||||
products.GET("/:id", r.handler.GetProductDetail)
|
||||
products.POST("", r.handler.CreateProduct)
|
||||
products.PUT("/:id", r.handler.UpdateProduct)
|
||||
products.DELETE("/:id", r.handler.DeleteProduct)
|
||||
|
||||
// 组合包管理
|
||||
products.POST("/:id/package-items", r.handler.AddPackageItem)
|
||||
products.PUT("/:id/package-items/:item_id", r.handler.UpdatePackageItem)
|
||||
products.DELETE("/:id/package-items/:item_id", r.handler.RemovePackageItem)
|
||||
products.PUT("/:id/package-items/reorder", r.handler.ReorderPackageItems)
|
||||
products.PUT("/:id/package-items/batch", r.handler.UpdatePackageItems)
|
||||
|
||||
// API配置管理
|
||||
products.GET("/:id/api-config", r.handler.GetProductApiConfig)
|
||||
products.POST("/:id/api-config", r.handler.CreateProductApiConfig)
|
||||
products.PUT("/:id/api-config", r.handler.UpdateProductApiConfig)
|
||||
products.DELETE("/:id/api-config", r.handler.DeleteProductApiConfig)
|
||||
}
|
||||
|
||||
|
||||
// 分类管理
|
||||
categories := adminGroup.Group("/product-categories")
|
||||
{
|
||||
@@ -53,7 +67,7 @@ func (r *ProductAdminRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
categories.PUT("/:id", r.handler.UpdateCategory)
|
||||
categories.DELETE("/:id", r.handler.DeleteCategory)
|
||||
}
|
||||
|
||||
|
||||
// 订阅管理
|
||||
subscriptions := adminGroup.Group("/subscriptions")
|
||||
{
|
||||
@@ -62,4 +76,4 @@ func (r *ProductAdminRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
subscriptions.PUT("/:id/price", r.handler.UpdateSubscriptionPrice)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,12 +38,16 @@ func (r *ProductRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// 获取产品列表(分页+筛选)
|
||||
products.GET("", r.productHandler.ListProducts)
|
||||
|
||||
// 获取产品详情
|
||||
products.GET("/:id", r.productHandler.GetProductDetail)
|
||||
|
||||
// 获取产品统计
|
||||
products.GET("/stats", r.productHandler.GetProductStats)
|
||||
|
||||
// 根据产品代码获取API配置
|
||||
products.GET("/code/:product_code/api-config", r.productHandler.GetProductApiConfigByCode)
|
||||
|
||||
// 产品详情和API配置 - 使用具体路径避免冲突
|
||||
products.GET("/:id", r.productHandler.GetProductDetail)
|
||||
products.GET("/:id/api-config", r.productHandler.GetProductApiConfig)
|
||||
|
||||
// 订阅产品(需要认证)
|
||||
products.POST("/:id/subscribe", r.auth.Handle(), r.productHandler.SubscribeProduct)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
type UserRoutes struct {
|
||||
handler *handlers.UserHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
adminAuthMiddleware *middleware.AdminAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
@@ -19,11 +20,13 @@ type UserRoutes struct {
|
||||
func NewUserRoutes(
|
||||
handler *handlers.UserHandler,
|
||||
authMiddleware *middleware.JWTAuthMiddleware,
|
||||
adminAuthMiddleware *middleware.AdminAuthMiddleware,
|
||||
logger *zap.Logger,
|
||||
) *UserRoutes {
|
||||
return &UserRoutes{
|
||||
handler: handler,
|
||||
authMiddleware: authMiddleware,
|
||||
adminAuthMiddleware: adminAuthMiddleware,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
@@ -48,6 +51,14 @@ func (r *UserRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
authenticated.GET("/me", r.handler.GetProfile) // 获取当前用户信息
|
||||
authenticated.PUT("/me/password", r.handler.ChangePassword) // 修改密码
|
||||
}
|
||||
|
||||
// 管理员路由
|
||||
adminGroup := usersGroup.Group("/admin")
|
||||
adminGroup.Use(r.adminAuthMiddleware.Handle())
|
||||
{
|
||||
adminGroup.GET("/list", r.handler.ListUsers) // 管理员查看用户列表
|
||||
adminGroup.GET("/stats", r.handler.GetUserStats) // 管理员获取用户统计信息
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("用户路由注册完成")
|
||||
|
||||
Reference in New Issue
Block a user