f
This commit is contained in:
385
internal/infrastructure/cache/redis_cache.go
vendored
Normal file
385
internal/infrastructure/cache/redis_cache.go
vendored
Normal file
@@ -0,0 +1,385 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// RedisCache Redis缓存实现
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
logger *zap.Logger
|
||||
prefix string
|
||||
|
||||
// 统计信息
|
||||
hits int64
|
||||
misses int64
|
||||
}
|
||||
|
||||
// NewRedisCache 创建Redis缓存实例
|
||||
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
logger: logger,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (r *RedisCache) Name() string {
|
||||
return "redis-cache"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (r *RedisCache) Initialize(ctx context.Context) error {
|
||||
// 测试连接
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to connect to Redis", zap.Error(err))
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Redis cache service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (r *RedisCache) HealthCheck(ctx context.Context) error {
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (r *RedisCache) Shutdown(ctx context.Context) error {
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
val, err := r.client.Get(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
r.misses++
|
||||
return fmt.Errorf("cache miss: key %s not found", key)
|
||||
}
|
||||
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.hits++
|
||||
return json.Unmarshal([]byte(val), dest)
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
r.logger.Error("序列化缓存数据失败", zap.String("key", key), zap.Error(err))
|
||||
return fmt.Errorf("failed to marshal value: %w", err)
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
case string:
|
||||
expiration, _ = time.ParseDuration(v)
|
||||
default:
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
|
||||
err = r.client.Set(ctx, fullKey, data, expiration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("设置缓存失败", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Debug("设置缓存成功", zap.String("key", key), zap.Duration("ttl", expiration))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
err := r.client.Del(ctx, fullKeys...).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
count, err := r.client.Exists(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取
|
||||
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
values, err := r.client.MGet(ctx, fullKeys...).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("批量获取缓存失败", zap.Strings("keys", keys), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for i, val := range values {
|
||||
if val != nil {
|
||||
var data interface{}
|
||||
// 修复:改进JSON反序列化错误处理
|
||||
if err := json.Unmarshal([]byte(val.(string)), &data); err != nil {
|
||||
r.logger.Warn("反序列化缓存数据失败",
|
||||
zap.String("key", keys[i]),
|
||||
zap.String("value", val.(string)),
|
||||
zap.Error(err))
|
||||
continue
|
||||
}
|
||||
result[keys[i]] = data
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置
|
||||
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
default:
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
for key, value := range data {
|
||||
fullKey := r.getFullKey(key)
|
||||
jsonData, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pipe.Set(ctx, fullKey, jsonData, expiration)
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeletePattern 按模式删除
|
||||
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
|
||||
// 修复:避免重复添加前缀
|
||||
var fullPattern string
|
||||
if strings.HasPrefix(pattern, r.prefix+":") {
|
||||
fullPattern = pattern
|
||||
} else {
|
||||
fullPattern = r.getFullKey(pattern)
|
||||
}
|
||||
|
||||
// 检查上下文是否已取消
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
var cursor uint64
|
||||
var totalDeleted int64
|
||||
maxIterations := 100 // 防止无限循环
|
||||
iteration := 0
|
||||
|
||||
|
||||
for {
|
||||
// 检查迭代次数限制
|
||||
iteration++
|
||||
if iteration > maxIterations {
|
||||
r.logger.Warn("缓存删除操作达到最大迭代次数限制",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int("max_iterations", maxIterations),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
)
|
||||
break
|
||||
}
|
||||
|
||||
// 检查上下文是否已取消
|
||||
if ctx.Err() != nil {
|
||||
r.logger.Warn("缓存删除操作被取消",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
zap.Error(ctx.Err()),
|
||||
)
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// 执行SCAN操作
|
||||
keys, next, err := r.client.Scan(ctx, cursor, fullPattern, 1000).Result()
|
||||
if err != nil {
|
||||
// 如果是上下文取消错误,直接返回
|
||||
if err == context.Canceled || err == context.DeadlineExceeded {
|
||||
r.logger.Warn("缓存删除操作被取消",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Error("扫描缓存键失败",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 批量删除找到的键
|
||||
if len(keys) > 0 {
|
||||
// 使用pipeline批量删除,提高性能
|
||||
pipe := r.client.Pipeline()
|
||||
pipe.Del(ctx, keys...)
|
||||
|
||||
cmds, err := pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
r.logger.Error("批量删除缓存键失败",
|
||||
zap.Strings("keys", keys),
|
||||
zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 统计删除的键数量
|
||||
for _, cmd := range cmds {
|
||||
if delCmd, ok := cmd.(*redis.IntCmd); ok {
|
||||
if deleted, err := delCmd.Result(); err == nil {
|
||||
totalDeleted += deleted
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Debug("批量删除缓存键",
|
||||
zap.Strings("keys", keys),
|
||||
zap.Int("batch_size", len(keys)),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
)
|
||||
}
|
||||
|
||||
cursor = next
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Debug("缓存模式删除完成",
|
||||
zap.String("pattern", fullPattern),
|
||||
zap.Int64("total_deleted", totalDeleted),
|
||||
zap.Int("iterations", iteration),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys 获取匹配的键
|
||||
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 移除前缀
|
||||
result := make([]string, len(keys))
|
||||
prefixLen := len(r.prefix) + 1 // +1 for ":"
|
||||
for i, key := range keys {
|
||||
if len(key) > prefixLen {
|
||||
result[i] = key[prefixLen:]
|
||||
} else {
|
||||
result[i] = key
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Stats 获取缓存统计
|
||||
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
|
||||
dbSize, _ := r.client.DBSize(ctx).Result()
|
||||
|
||||
return interfaces.CacheStats{
|
||||
Hits: r.hits,
|
||||
Misses: r.misses,
|
||||
Keys: dbSize,
|
||||
Memory: 0, // 暂时设为0,后续可解析Redis info
|
||||
Connections: 0, // 暂时设为0,后续可解析Redis info
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getFullKey 获取完整键名
|
||||
func (r *RedisCache) getFullKey(key string) string {
|
||||
if r.prefix == "" {
|
||||
return key
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", r.prefix, key)
|
||||
}
|
||||
|
||||
// Flush 清空所有缓存
|
||||
func (r *RedisCache) Flush(ctx context.Context) error {
|
||||
if r.prefix == "" {
|
||||
return r.client.FlushDB(ctx).Err()
|
||||
}
|
||||
|
||||
// 只删除带前缀的键
|
||||
return r.DeletePattern(ctx, "*")
|
||||
}
|
||||
|
||||
// GetClient 获取原始Redis客户端
|
||||
func (r *RedisCache) GetClient() *redis.Client {
|
||||
return r.client
|
||||
}
|
||||
160
internal/infrastructure/database/database.go
Normal file
160
internal/infrastructure/database/database.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Config 数据库配置
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
Name string
|
||||
SSLMode string
|
||||
Timezone string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
}
|
||||
|
||||
// DB 数据库包装器
|
||||
type DB struct {
|
||||
*gorm.DB
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewConnection 创建新的数据库连接
|
||||
func NewConnection(config Config) (*DB, error) {
|
||||
// 构建DSN
|
||||
dsn := buildDSN(config)
|
||||
|
||||
// 配置GORM
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
SingularTable: true, // 使用单数表名
|
||||
},
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
NowFunc: func() time.Time {
|
||||
return time.Now().In(time.FixedZone("CST", 8*3600)) // 强制使用北京时间
|
||||
},
|
||||
PrepareStmt: true,
|
||||
DisableAutomaticPing: false,
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取底层sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
||||
}
|
||||
|
||||
return &DB{
|
||||
DB: db,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildDSN 构建数据库连接字符串
|
||||
func buildDSN(config Config) string {
|
||||
return fmt.Sprintf(
|
||||
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s options='-c timezone=%s'",
|
||||
config.Host,
|
||||
config.User,
|
||||
config.Password,
|
||||
config.Name,
|
||||
config.Port,
|
||||
config.SSLMode,
|
||||
config.Timezone,
|
||||
config.Timezone,
|
||||
)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
// Ping 检查数据库连接
|
||||
func (db *DB) Ping() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Ping()
|
||||
}
|
||||
|
||||
// GetStats 获取连接池统计信息
|
||||
func (db *DB) GetStats() (map[string]interface{}, error) {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := sqlDB.Stats()
|
||||
return map[string]interface{}{
|
||||
"max_open_connections": stats.MaxOpenConnections,
|
||||
"open_connections": stats.OpenConnections,
|
||||
"in_use": stats.InUse,
|
||||
"idle": stats.Idle,
|
||||
"wait_count": stats.WaitCount,
|
||||
"wait_duration": stats.WaitDuration,
|
||||
"max_idle_closed": stats.MaxIdleClosed,
|
||||
"max_idle_time_closed": stats.MaxIdleTimeClosed,
|
||||
"max_lifetime_closed": stats.MaxLifetimeClosed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BeginTx 开始事务(已废弃,请使用shared/database.TransactionManager)
|
||||
// @deprecated 请使用 shared/database.TransactionManager
|
||||
func (db *DB) BeginTx() *gorm.DB {
|
||||
return db.DB.Begin()
|
||||
}
|
||||
|
||||
// Migrate 执行数据库迁移
|
||||
func (db *DB) Migrate(models ...interface{}) error {
|
||||
return db.DB.AutoMigrate(models...)
|
||||
}
|
||||
|
||||
// IsHealthy 检查数据库健康状态
|
||||
func (db *DB) IsHealthy() bool {
|
||||
return db.Ping() == nil
|
||||
}
|
||||
|
||||
// WithContext 返回带上下文的数据库实例
|
||||
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
|
||||
if c, ok := ctx.(context.Context); ok {
|
||||
return db.DB.WithContext(c)
|
||||
}
|
||||
return db.DB
|
||||
}
|
||||
|
||||
// 注意:事务相关功能已迁移到 shared/database.TransactionManager
|
||||
// 请使用 TransactionManager 进行事务管理
|
||||
@@ -0,0 +1,556 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/api/entities"
|
||||
"hyapi-server/internal/domains/api/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiCallsTable = "api_calls"
|
||||
ApiCallCacheTTL = 10 * time.Minute
|
||||
)
|
||||
|
||||
// ApiCallWithProduct 包含产品名称的API调用记录
|
||||
type ApiCallWithProduct struct {
|
||||
entities.ApiCall
|
||||
ProductName string `json:"product_name" gorm:"column:product_name"`
|
||||
}
|
||||
|
||||
type GormApiCallRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ApiCallRepository = (*GormApiCallRepository)(nil)
|
||||
|
||||
func NewGormApiCallRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiCallRepository {
|
||||
return &GormApiCallRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiCallsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) Create(ctx context.Context, call *entities.ApiCall) error {
|
||||
return r.CreateEntity(ctx, call)
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) Update(ctx context.Context, call *entities.ApiCall) error {
|
||||
return r.UpdateEntity(ctx, call)
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindById(ctx context.Context, id string) (*entities.ApiCall, error) {
|
||||
var call entities.ApiCall
|
||||
err := r.SmartGetByID(ctx, id, &call)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &call, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindByUserId(ctx context.Context, userId string, limit, offset int) ([]*entities.ApiCall, error) {
|
||||
var calls []*entities.ApiCall
|
||||
options := database.CacheListOptions{
|
||||
Where: "user_id = ?",
|
||||
Args: []interface{}{userId},
|
||||
Order: "created_at DESC",
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
err := r.ListWithCache(ctx, &calls, ApiCallCacheTTL, options)
|
||||
return calls, err
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
|
||||
var calls []*entities.ApiCall
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
|
||||
return calls, total, err
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
|
||||
var calls []*entities.ApiCall
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品ID筛选
|
||||
if productId, ok := filters["product_id"].(string); ok && productId != "" {
|
||||
whereCondition += " AND product_id = ?"
|
||||
whereArgs = append(whereArgs, productId)
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
|
||||
return calls, total, err
|
||||
}
|
||||
|
||||
// ListByUserIdWithFiltersAndProductName 根据用户ID和筛选条件获取API调用记录(包含产品名称)
|
||||
func (r *GormApiCallRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
|
||||
var callsWithProduct []*ApiCallWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "ac.user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND ac.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND ac.status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
query := r.GetDB(ctx).Table("api_calls ac").
|
||||
Select("ac.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON ac.product_id = p.id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("ac." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("ac.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&callsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.ApiCall并构建产品名称映射
|
||||
var calls []*entities.ApiCall
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, c := range callsWithProduct {
|
||||
call := c.ApiCall
|
||||
calls = append(calls, &call)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if c.ProductName != "" {
|
||||
productNameMap[call.ID] = c.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, calls, total, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ?", userId)
|
||||
}
|
||||
|
||||
// CountByUserIdAndProductId 按用户ID和产品ID统计API调用次数
|
||||
func (r *GormApiCallRepository) CountByUserIdAndProductId(ctx context.Context, userId string, productId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ? AND product_id = ?", userId, productId)
|
||||
}
|
||||
|
||||
// CountByUserIdAndDateRange 按用户ID和日期范围统计API调用次数
|
||||
func (r *GormApiCallRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
|
||||
}
|
||||
|
||||
// GetDailyStatsByUserId 获取用户每日API调用统计
|
||||
func (r *GormApiCallRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE user_id = $1
|
||||
AND DATE(created_at) >= $2
|
||||
AND DATE(created_at) <= $3
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetMonthlyStatsByUserId 获取用户每月API调用统计
|
||||
func (r *GormApiCallRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE user_id = $1
|
||||
AND created_at >= $2
|
||||
AND created_at <= $3
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (r *GormApiCallRepository) FindByTransactionId(ctx context.Context, transactionId string) (*entities.ApiCall, error) {
|
||||
var call entities.ApiCall
|
||||
err := r.FindOne(ctx, &call, "transaction_id = ?", transactionId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &call, nil
|
||||
}
|
||||
|
||||
// ListWithFiltersAndProductName 管理端:根据条件筛选所有API调用记录(包含产品名称)
|
||||
func (r *GormApiCallRepository) ListWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
|
||||
var callsWithProduct []*ApiCallWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "1=1"
|
||||
whereArgs := []interface{}{}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 用户ID筛选(支持单个user_id和多个user_ids)
|
||||
// 如果同时存在,优先使用user_ids(批量查询)
|
||||
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
|
||||
// 解析逗号分隔的用户ID列表
|
||||
userIdsList := strings.Split(userIds, ",")
|
||||
// 去除空白字符
|
||||
var cleanUserIds []string
|
||||
for _, id := range userIdsList {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" {
|
||||
cleanUserIds = append(cleanUserIds, id)
|
||||
}
|
||||
}
|
||||
if len(cleanUserIds) > 0 {
|
||||
placeholders := strings.Repeat("?,", len(cleanUserIds))
|
||||
placeholders = placeholders[:len(placeholders)-1] // 移除最后一个逗号
|
||||
whereCondition += " AND ac.user_id IN (" + placeholders + ")"
|
||||
for _, id := range cleanUserIds {
|
||||
whereArgs = append(whereArgs, id)
|
||||
}
|
||||
}
|
||||
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
|
||||
// 单个用户ID筛选
|
||||
whereCondition += " AND ac.user_id = ?"
|
||||
whereArgs = append(whereArgs, userId)
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND ac.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// TransactionID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND ac.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 企业名称筛选
|
||||
if companyName, ok := filters["company_name"].(string); ok && companyName != "" {
|
||||
whereCondition += " AND ei.company_name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+companyName+"%")
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status, ok := filters["status"].(string); ok && status != "" {
|
||||
whereCondition += " AND ac.status = ?"
|
||||
whereArgs = append(whereArgs, status)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
// 需要JOIN product表获取产品名称,JOIN users和enterprise_infos表获取企业名称
|
||||
query := r.GetDB(ctx).Table("api_calls ac").
|
||||
Select("ac.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON ac.product_id = p.id").
|
||||
Joins("LEFT JOIN users u ON ac.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("ac." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("ac.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&callsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.ApiCall并构建产品名称映射
|
||||
var calls []*entities.ApiCall
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, c := range callsWithProduct {
|
||||
call := c.ApiCall
|
||||
calls = append(calls, &call)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if c.ProductName != "" {
|
||||
productNameMap[call.ID] = c.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, calls, total, nil
|
||||
}
|
||||
|
||||
// GetSystemTotalCalls 获取系统总API调用次数
|
||||
func (r *GormApiCallRepository) GetSystemTotalCalls(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ApiCall{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetSystemCallsByDateRange 获取系统指定时间范围内的API调用次数
|
||||
// endDate 应该是结束日期当天的次日00:00:00(日统计)或下个月1号00:00:00(月统计),使用 < 而不是 <=
|
||||
func (r *GormApiCallRepository) GetSystemCallsByDateRange(ctx context.Context, startDate, endDate time.Time) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ApiCall{}).
|
||||
Where("created_at >= ? AND created_at < ?", startDate, endDate).
|
||||
Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetSystemDailyStats 获取系统每日API调用统计
|
||||
func (r *GormApiCallRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE DATE(created_at) >= $1
|
||||
AND DATE(created_at) <= $2
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyStats 获取系统每月API调用统计
|
||||
func (r *GormApiCallRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as calls
|
||||
FROM api_calls
|
||||
WHERE created_at >= $1
|
||||
AND created_at < $2
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetApiPopularityRanking 获取API受欢迎程度排行榜
|
||||
func (r *GormApiCallRepository) GetApiPopularityRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
p.id as product_id,
|
||||
p.name as api_name,
|
||||
p.description as api_description,
|
||||
COUNT(ac.id) as call_count
|
||||
FROM product p
|
||||
LEFT JOIN api_calls ac ON p.id = ac.product_id
|
||||
AND DATE(ac.created_at) = CURRENT_DATE
|
||||
WHERE p.deleted_at IS NULL
|
||||
GROUP BY p.id, p.name, p.description
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY call_count DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
p.id as product_id,
|
||||
p.name as api_name,
|
||||
p.description as api_description,
|
||||
COUNT(ac.id) as call_count
|
||||
FROM product p
|
||||
LEFT JOIN api_calls ac ON p.id = ac.product_id
|
||||
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE p.deleted_at IS NULL
|
||||
GROUP BY p.id, p.name, p.description
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY call_count DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
p.id as product_id,
|
||||
p.name as api_name,
|
||||
p.description as api_description,
|
||||
COUNT(ac.id) as call_count
|
||||
FROM product p
|
||||
LEFT JOIN api_calls ac ON p.id = ac.product_id
|
||||
WHERE p.deleted_at IS NULL
|
||||
GROUP BY p.id, p.name, p.description
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY call_count DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hyapi-server/internal/domains/api/entities"
|
||||
"hyapi-server/internal/domains/api/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ApiUsersTable = "api_users"
|
||||
ApiUserCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
type GormApiUserRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ApiUserRepository = (*GormApiUserRepository)(nil)
|
||||
|
||||
func NewGormApiUserRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiUserRepository {
|
||||
return &GormApiUserRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiUsersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) Create(ctx context.Context, user *entities.ApiUser) error {
|
||||
return r.CreateEntity(ctx, user)
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) Update(ctx context.Context, user *entities.ApiUser) error {
|
||||
return r.UpdateEntity(ctx, user)
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) FindByAccessId(ctx context.Context, accessId string) (*entities.ApiUser, error) {
|
||||
var user entities.ApiUser
|
||||
err := r.SmartGetByField(ctx, &user, "access_id", accessId, ApiUserCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormApiUserRepository) FindByUserId(ctx context.Context, userId string) (*entities.ApiUser, error) {
|
||||
var user entities.ApiUser
|
||||
err := r.SmartGetByField(ctx, &user, "user_id", userId, ApiUserCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"hyapi-server/internal/domains/api/entities"
|
||||
"hyapi-server/internal/domains/api/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ReportsTable = "reports"
|
||||
)
|
||||
|
||||
// GormReportRepository 报告记录 GORM 仓储实现
|
||||
type GormReportRepository struct {
|
||||
*database.BaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ReportRepository = (*GormReportRepository)(nil)
|
||||
|
||||
// NewGormReportRepository 创建报告记录仓储实现
|
||||
func NewGormReportRepository(db *gorm.DB, logger *zap.Logger) repositories.ReportRepository {
|
||||
return &GormReportRepository{
|
||||
BaseRepositoryImpl: database.NewBaseRepositoryImpl(db, logger),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建报告记录
|
||||
func (r *GormReportRepository) Create(ctx context.Context, report *entities.Report) error {
|
||||
return r.CreateEntity(ctx, report)
|
||||
}
|
||||
|
||||
// FindByReportID 根据报告编号查询记录
|
||||
func (r *GormReportRepository) FindByReportID(ctx context.Context, reportID string) (*entities.Report, error) {
|
||||
var report entities.Report
|
||||
if err := r.FindOneByField(ctx, &report, "report_id", reportID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &report, nil
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
repoQueries "hyapi-server/internal/domains/article/repositories/queries"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormAnnouncementRepository GORM公告仓储实现
|
||||
type GormAnnouncementRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.AnnouncementRepository = (*GormAnnouncementRepository)(nil)
|
||||
|
||||
// NewGormAnnouncementRepository 创建GORM公告仓储
|
||||
func NewGormAnnouncementRepository(db *gorm.DB, logger *zap.Logger) *GormAnnouncementRepository {
|
||||
return &GormAnnouncementRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建公告
|
||||
func (r *GormAnnouncementRepository) Create(ctx context.Context, entity entities.Announcement) (entities.Announcement, error) {
|
||||
r.logger.Info("创建公告", zap.String("id", entity.ID), zap.String("title", entity.Title))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("创建公告失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取公告
|
||||
func (r *GormAnnouncementRepository) GetByID(ctx context.Context, id string) (entities.Announcement, error) {
|
||||
var entity entities.Announcement
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("id = ?", id).
|
||||
First(&entity).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("公告不存在")
|
||||
}
|
||||
r.logger.Error("获取公告失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新公告
|
||||
func (r *GormAnnouncementRepository) Update(ctx context.Context, entity entities.Announcement) error {
|
||||
r.logger.Info("更新公告", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新公告失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除公告
|
||||
func (r *GormAnnouncementRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除公告", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除公告失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找公告
|
||||
func (r *GormAnnouncementRepository) FindByStatus(ctx context.Context, status entities.AnnouncementStatus) ([]*entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&announcements).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据状态查找公告失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Announcement, len(announcements))
|
||||
for i := range announcements {
|
||||
result[i] = &announcements[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindScheduled 查找定时发布的公告
|
||||
func (r *GormAnnouncementRepository) FindScheduled(ctx context.Context) ([]*entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
now := time.Now()
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.AnnouncementStatusDraft, now).
|
||||
Order("scheduled_at ASC").
|
||||
Find(&announcements).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("查找定时发布公告失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Announcement, len(announcements))
|
||||
for i := range announcements {
|
||||
result[i] = &announcements[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListAnnouncements 获取公告列表
|
||||
func (r *GormAnnouncementRepository) ListAnnouncements(ctx context.Context, query *repoQueries.ListAnnouncementQuery) ([]*entities.Announcement, int64, error) {
|
||||
var announcements []entities.Announcement
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
if query.Title != "" {
|
||||
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取公告列表总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&announcements).Error; err != nil {
|
||||
r.logger.Error("获取公告列表失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Announcement, len(announcements))
|
||||
for i := range announcements {
|
||||
result[i] = &announcements[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// CountByStatus 根据状态统计公告数量
|
||||
func (r *GormAnnouncementRepository) CountByStatus(ctx context.Context, status entities.AnnouncementStatus) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.Announcement{}).
|
||||
Where("status = ?", status).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("统计公告数量失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// UpdateStatistics 更新统计信息
|
||||
// 注意:公告实体目前没有统计字段,此方法预留扩展
|
||||
func (r *GormAnnouncementRepository) UpdateStatistics(ctx context.Context, announcementID string) error {
|
||||
r.logger.Info("更新公告统计信息", zap.String("announcement_id", announcementID))
|
||||
// TODO: 如果将来需要统计字段(如阅读量等),可以在这里实现
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 实现 BaseRepository 接口的其他方法 ================
|
||||
|
||||
// Count 统计数量
|
||||
func (r *GormAnnouncementRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查是否存在
|
||||
func (r *GormAnnouncementRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Announcement{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除
|
||||
func (r *GormAnnouncementRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除
|
||||
func (r *GormAnnouncementRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Announcement{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建
|
||||
func (r *GormAnnouncementRepository) CreateBatch(ctx context.Context, entities []entities.Announcement) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取
|
||||
func (r *GormAnnouncementRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&announcements).Error
|
||||
return announcements, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新
|
||||
func (r *GormAnnouncementRepository) UpdateBatch(ctx context.Context, entities []entities.Announcement) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除
|
||||
func (r *GormAnnouncementRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 列表查询
|
||||
func (r *GormAnnouncementRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Announcement, error) {
|
||||
var announcements []entities.Announcement
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = strings.ToUpper(options.Order)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&announcements).Error
|
||||
return announcements, err
|
||||
}
|
||||
@@ -0,0 +1,592 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
repoQueries "hyapi-server/internal/domains/article/repositories/queries"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormArticleRepository GORM文章仓储实现
|
||||
type GormArticleRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.ArticleRepository = (*GormArticleRepository)(nil)
|
||||
|
||||
// NewGormArticleRepository 创建GORM文章仓储
|
||||
func NewGormArticleRepository(db *gorm.DB, logger *zap.Logger) *GormArticleRepository {
|
||||
return &GormArticleRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建文章
|
||||
func (r *GormArticleRepository) Create(ctx context.Context, entity entities.Article) (entities.Article, error) {
|
||||
r.logger.Info("创建文章", zap.String("id", entity.ID), zap.String("title", entity.Title))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建文章失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取文章
|
||||
func (r *GormArticleRepository) GetByID(ctx context.Context, id string) (entities.Article, error) {
|
||||
var entity entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("id = ?", id).
|
||||
First(&entity).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("文章不存在")
|
||||
}
|
||||
r.logger.Error("获取文章失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新文章
|
||||
func (r *GormArticleRepository) Update(ctx context.Context, entity entities.Article) error {
|
||||
r.logger.Info("更新文章", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新文章失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除文章
|
||||
func (r *GormArticleRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除文章", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Article{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除文章失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByAuthorID 根据作者ID查找文章
|
||||
func (r *GormArticleRepository) FindByAuthorID(ctx context.Context, authorID string) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("author_id = ?", authorID).
|
||||
Order("created_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据作者ID查找文章失败", zap.String("author_id", authorID), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByCategoryID 根据分类ID查找文章
|
||||
func (r *GormArticleRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("category_id = ?", categoryID).
|
||||
Order("created_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据分类ID查找文章失败", zap.String("category_id", categoryID), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找文章
|
||||
func (r *GormArticleRepository) FindByStatus(ctx context.Context, status entities.ArticleStatus) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据状态查找文章失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindFeatured 查找推荐文章
|
||||
func (r *GormArticleRepository) FindFeatured(ctx context.Context) ([]*entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Category").
|
||||
Preload("Tags").
|
||||
Where("is_featured = ? AND status = ?", true, entities.ArticleStatusPublished).
|
||||
Order("published_at DESC").
|
||||
Find(&articles).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("查找推荐文章失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Search 搜索文章
|
||||
func (r *GormArticleRepository) Search(ctx context.Context, query *repoQueries.SearchArticleQuery) ([]*entities.Article, int64, error) {
|
||||
var articles []entities.Article
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用搜索条件
|
||||
if query.Keyword != "" {
|
||||
keyword := "%" + query.Keyword + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ? OR summary LIKE ?", keyword, keyword, keyword)
|
||||
}
|
||||
|
||||
if query.CategoryID != "" {
|
||||
// 如果指定了分类ID,只查询该分类的文章(包括没有分类的文章,当CategoryID为空字符串时)
|
||||
if query.CategoryID == "null" || query.CategoryID == "" {
|
||||
// 查询没有分类的文章
|
||||
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
|
||||
} else {
|
||||
// 查询指定分类的文章
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
}
|
||||
|
||||
if query.AuthorID != "" {
|
||||
dbQuery = dbQuery.Where("author_id = ?", query.AuthorID)
|
||||
}
|
||||
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取搜索结果总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
dbQuery = dbQuery.Preload("Category").Preload("Tags")
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&articles).Error; err != nil {
|
||||
r.logger.Error("搜索文章失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// ListArticles 获取文章列表(用户端)
|
||||
func (r *GormArticleRepository) ListArticles(ctx context.Context, query *repoQueries.ListArticleQuery) ([]*entities.Article, int64, error) {
|
||||
var articles []entities.Article
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 用户端不显示归档文章
|
||||
dbQuery = dbQuery.Where("status != ?", entities.ArticleStatusArchived)
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
if query.CategoryID != "" {
|
||||
// 如果指定了分类ID,只查询该分类的文章(包括没有分类的文章,当CategoryID为空字符串时)
|
||||
if query.CategoryID == "null" || query.CategoryID == "" {
|
||||
// 查询没有分类的文章
|
||||
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
|
||||
} else {
|
||||
// 查询指定分类的文章
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
}
|
||||
|
||||
if query.TagID != "" {
|
||||
// 如果指定了标签ID,只查询有关联该标签的文章
|
||||
// 使用子查询而不是JOIN,避免影响其他查询条件
|
||||
subQuery := r.db.WithContext(ctx).Table("article_tag_relations").
|
||||
Select("article_id").
|
||||
Where("tag_id = ?", query.TagID)
|
||||
dbQuery = dbQuery.Where("id IN (?)", subQuery)
|
||||
}
|
||||
|
||||
if query.Title != "" {
|
||||
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
|
||||
}
|
||||
|
||||
if query.Summary != "" {
|
||||
dbQuery = dbQuery.Where("summary ILIKE ?", "%"+query.Summary+"%")
|
||||
}
|
||||
|
||||
if query.IsFeatured != nil {
|
||||
dbQuery = dbQuery.Where("is_featured = ?", *query.IsFeatured)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取文章列表总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
dbQuery = dbQuery.Preload("Category").Preload("Tags")
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&articles).Error; err != nil {
|
||||
r.logger.Error("获取文章列表失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// ListArticlesForAdmin 获取文章列表(管理员端)
|
||||
func (r *GormArticleRepository) ListArticlesForAdmin(ctx context.Context, query *repoQueries.ListArticleQuery) ([]*entities.Article, int64, error) {
|
||||
var articles []entities.Article
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
|
||||
if query.CategoryID != "" {
|
||||
// 如果指定了分类ID,只查询该分类的文章(包括没有分类的文章,当CategoryID为空字符串时)
|
||||
if query.CategoryID == "null" || query.CategoryID == "" {
|
||||
// 查询没有分类的文章
|
||||
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
|
||||
} else {
|
||||
// 查询指定分类的文章
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
}
|
||||
|
||||
if query.TagID != "" {
|
||||
// 如果指定了标签ID,只查询有关联该标签的文章
|
||||
// 使用子查询而不是JOIN,避免影响其他查询条件
|
||||
subQuery := r.db.WithContext(ctx).Table("article_tag_relations").
|
||||
Select("article_id").
|
||||
Where("tag_id = ?", query.TagID)
|
||||
dbQuery = dbQuery.Where("id IN (?)", subQuery)
|
||||
}
|
||||
|
||||
if query.Title != "" {
|
||||
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
|
||||
}
|
||||
|
||||
if query.Summary != "" {
|
||||
dbQuery = dbQuery.Where("summary ILIKE ?", "%"+query.Summary+"%")
|
||||
}
|
||||
|
||||
if query.IsFeatured != nil {
|
||||
dbQuery = dbQuery.Where("is_featured = ?", *query.IsFeatured)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取文章列表总数失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.OrderBy != "" {
|
||||
orderDir := "DESC"
|
||||
if query.OrderDir != "" {
|
||||
orderDir = strings.ToUpper(query.OrderDir)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
dbQuery = dbQuery.Preload("Category").Preload("Tags")
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&articles).Error; err != nil {
|
||||
r.logger.Error("获取文章列表失败", zap.Error(err))
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Article, len(articles))
|
||||
for i := range articles {
|
||||
result[i] = &articles[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// CountByCategoryID 统计分类文章数量
|
||||
func (r *GormArticleRepository) CountByCategoryID(ctx context.Context, categoryID string) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.Article{}).
|
||||
Where("category_id = ?", categoryID).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("统计分类文章数量失败", zap.String("category_id", categoryID), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByStatus 统计状态文章数量
|
||||
func (r *GormArticleRepository) CountByStatus(ctx context.Context, status entities.ArticleStatus) (int64, error) {
|
||||
var count int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
if status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", status)
|
||||
}
|
||||
|
||||
err := dbQuery.Count(&count).Error
|
||||
if err != nil {
|
||||
r.logger.Error("统计状态文章数量失败", zap.String("status", string(status)), zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// IncrementViewCount 增加阅读量
|
||||
func (r *GormArticleRepository) IncrementViewCount(ctx context.Context, articleID string) error {
|
||||
err := r.db.WithContext(ctx).Model(&entities.Article{}).
|
||||
Where("id = ?", articleID).
|
||||
UpdateColumn("view_count", gorm.Expr("view_count + ?", 1)).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("增加阅读量失败", zap.String("article_id", articleID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
// 实现 BaseRepository 接口的其他方法
|
||||
func (r *GormArticleRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Article{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Article{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Article{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) CreateBatch(ctx context.Context, entities []entities.Article) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&articles).Error
|
||||
return articles, err
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) UpdateBatch(ctx context.Context, entities []entities.Article) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Article{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormArticleRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Article, error) {
|
||||
var articles []entities.Article
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = strings.ToUpper(options.Order)
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&articles).Error
|
||||
return articles, err
|
||||
}
|
||||
@@ -0,0 +1,247 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormCategoryRepository GORM分类仓储实现
|
||||
type GormCategoryRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.CategoryRepository = (*GormCategoryRepository)(nil)
|
||||
|
||||
// NewGormCategoryRepository 创建GORM分类仓储
|
||||
func NewGormCategoryRepository(db *gorm.DB, logger *zap.Logger) *GormCategoryRepository {
|
||||
return &GormCategoryRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建分类
|
||||
func (r *GormCategoryRepository) Create(ctx context.Context, entity entities.Category) (entities.Category, error) {
|
||||
r.logger.Info("创建分类", zap.String("id", entity.ID), zap.String("name", entity.Name))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建分类失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取分类
|
||||
func (r *GormCategoryRepository) GetByID(ctx context.Context, id string) (entities.Category, error) {
|
||||
var entity entities.Category
|
||||
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("分类不存在")
|
||||
}
|
||||
r.logger.Error("获取分类失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新分类
|
||||
func (r *GormCategoryRepository) Update(ctx context.Context, entity entities.Category) error {
|
||||
r.logger.Info("更新分类", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新分类失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除分类
|
||||
func (r *GormCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除分类", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Category{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除分类失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindActive 查找启用的分类
|
||||
func (r *GormCategoryRepository) FindActive(ctx context.Context) ([]*entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("active = ?", true).
|
||||
Order("sort_order ASC, created_at ASC").
|
||||
Find(&categories).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("查找启用分类失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Category, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindBySortOrder 按排序查找分类
|
||||
func (r *GormCategoryRepository) FindBySortOrder(ctx context.Context) ([]*entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Order("sort_order ASC, created_at ASC").
|
||||
Find(&categories).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("按排序查找分类失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Category, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CountActive 统计启用分类数量
|
||||
func (r *GormCategoryRepository) CountActive(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
|
||||
err := r.db.WithContext(ctx).Model(&entities.Category{}).
|
||||
Where("active = ?", true).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("统计启用分类数量失败", zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// 实现 BaseRepository 接口的其他方法
|
||||
func (r *GormCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Category{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ?", search, search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Category{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Category{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Category{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) CreateBatch(ctx context.Context, entities []entities.Category) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) UpdateBatch(ctx context.Context, entities []entities.Category) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Category{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Category, error) {
|
||||
var categories []entities.Category
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Category{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ?", search, search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("sort_order ASC, created_at ASC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormScheduledTaskRepository GORM定时任务仓储实现
|
||||
type GormScheduledTaskRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.ScheduledTaskRepository = (*GormScheduledTaskRepository)(nil)
|
||||
|
||||
// NewGormScheduledTaskRepository 创建GORM定时任务仓储
|
||||
func NewGormScheduledTaskRepository(db *gorm.DB, logger *zap.Logger) *GormScheduledTaskRepository {
|
||||
return &GormScheduledTaskRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建定时任务记录
|
||||
func (r *GormScheduledTaskRepository) Create(ctx context.Context, task entities.ScheduledTask) (entities.ScheduledTask, error) {
|
||||
r.logger.Info("创建定时任务记录", zap.String("task_id", task.TaskID), zap.String("article_id", task.ArticleID))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&task).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建定时任务记录失败", zap.Error(err))
|
||||
return task, err
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// GetByTaskID 根据Asynq任务ID获取任务记录
|
||||
func (r *GormScheduledTaskRepository) GetByTaskID(ctx context.Context, taskID string) (entities.ScheduledTask, error) {
|
||||
var task entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("task_id = ?", taskID).
|
||||
First(&task).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return task, fmt.Errorf("定时任务不存在")
|
||||
}
|
||||
r.logger.Error("获取定时任务失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
return task, err
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// GetByArticleID 根据文章ID获取任务记录
|
||||
func (r *GormScheduledTaskRepository) GetByArticleID(ctx context.Context, articleID string) (entities.ScheduledTask, error) {
|
||||
var task entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("article_id = ? AND status IN (?)", articleID, []string{"pending", "running"}).
|
||||
First(&task).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return task, fmt.Errorf("文章没有活动的定时任务")
|
||||
}
|
||||
r.logger.Error("获取文章定时任务失败", zap.String("article_id", articleID), zap.Error(err))
|
||||
return task, err
|
||||
}
|
||||
|
||||
return task, nil
|
||||
}
|
||||
|
||||
// Update 更新任务记录
|
||||
func (r *GormScheduledTaskRepository) Update(ctx context.Context, task entities.ScheduledTask) error {
|
||||
r.logger.Info("更新定时任务记录", zap.String("task_id", task.TaskID), zap.String("status", string(task.Status)))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&task).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新定时任务记录失败", zap.String("task_id", task.TaskID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除任务记录
|
||||
func (r *GormScheduledTaskRepository) Delete(ctx context.Context, taskID string) error {
|
||||
r.logger.Info("删除定时任务记录", zap.String("task_id", taskID))
|
||||
|
||||
err := r.db.WithContext(ctx).Where("task_id = ?", taskID).Delete(&entities.ScheduledTask{}).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除定时任务记录失败", zap.String("task_id", taskID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAsCancelled 标记任务为已取消
|
||||
func (r *GormScheduledTaskRepository) MarkAsCancelled(ctx context.Context, taskID string) error {
|
||||
r.logger.Info("标记定时任务为已取消", zap.String("task_id", taskID))
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.ScheduledTask{}).
|
||||
Where("task_id = ? AND status IN (?)", taskID, []string{"pending", "running"}).
|
||||
Updates(map[string]interface{}{
|
||||
"status": entities.TaskStatusCancelled,
|
||||
"completed_at": time.Now(),
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
r.logger.Error("标记定时任务为已取消失败", zap.String("task_id", taskID), zap.Error(result.Error))
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
r.logger.Warn("没有找到需要取消的定时任务", zap.String("task_id", taskID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActiveTasks 获取活动状态的任务列表
|
||||
func (r *GormScheduledTaskRepository) GetActiveTasks(ctx context.Context) ([]entities.ScheduledTask, error) {
|
||||
var tasks []entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("status IN (?)", []string{"pending", "running"}).
|
||||
Order("scheduled_at ASC").
|
||||
Find(&tasks).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("获取活动定时任务列表失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
|
||||
// GetExpiredTasks 获取过期的任务列表
|
||||
func (r *GormScheduledTaskRepository) GetExpiredTasks(ctx context.Context) ([]entities.ScheduledTask, error) {
|
||||
var tasks []entities.ScheduledTask
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Preload("Article").
|
||||
Where("status = ? AND scheduled_at < ?", entities.TaskStatusPending, time.Now()).
|
||||
Order("scheduled_at ASC").
|
||||
Find(&tasks).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("获取过期定时任务列表失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tasks, nil
|
||||
}
|
||||
@@ -0,0 +1,279 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/article/entities"
|
||||
"hyapi-server/internal/domains/article/repositories"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormTagRepository GORM标签仓储实现
|
||||
type GormTagRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.TagRepository = (*GormTagRepository)(nil)
|
||||
|
||||
// NewGormTagRepository 创建GORM标签仓储
|
||||
func NewGormTagRepository(db *gorm.DB, logger *zap.Logger) *GormTagRepository {
|
||||
return &GormTagRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建标签
|
||||
func (r *GormTagRepository) Create(ctx context.Context, entity entities.Tag) (entities.Tag, error) {
|
||||
r.logger.Info("创建标签", zap.String("id", entity.ID), zap.String("name", entity.Name))
|
||||
|
||||
err := r.db.WithContext(ctx).Create(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("创建标签失败", zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取标签
|
||||
func (r *GormTagRepository) GetByID(ctx context.Context, id string) (entities.Tag, error) {
|
||||
var entity entities.Tag
|
||||
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entity, fmt.Errorf("标签不存在")
|
||||
}
|
||||
r.logger.Error("获取标签失败", zap.String("id", id), zap.Error(err))
|
||||
return entity, err
|
||||
}
|
||||
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
// Update 更新标签
|
||||
func (r *GormTagRepository) Update(ctx context.Context, entity entities.Tag) error {
|
||||
r.logger.Info("更新标签", zap.String("id", entity.ID))
|
||||
|
||||
err := r.db.WithContext(ctx).Save(&entity).Error
|
||||
if err != nil {
|
||||
r.logger.Error("更新标签失败", zap.String("id", entity.ID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除标签
|
||||
func (r *GormTagRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除标签", zap.String("id", id))
|
||||
|
||||
err := r.db.WithContext(ctx).Delete(&entities.Tag{}, "id = ?", id).Error
|
||||
if err != nil {
|
||||
r.logger.Error("删除标签失败", zap.String("id", id), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByArticleID 根据文章ID查找标签
|
||||
func (r *GormTagRepository) FindByArticleID(ctx context.Context, articleID string) ([]*entities.Tag, error) {
|
||||
var tags []entities.Tag
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("JOIN article_tag_relations ON article_tag_relations.tag_id = tags.id").
|
||||
Where("article_tag_relations.article_id = ?", articleID).
|
||||
Find(&tags).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("根据文章ID查找标签失败", zap.String("article_id", articleID), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Tag, len(tags))
|
||||
for i := range tags {
|
||||
result[i] = &tags[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByName 根据名称查找标签
|
||||
func (r *GormTagRepository) FindByName(ctx context.Context, name string) (*entities.Tag, error) {
|
||||
var tag entities.Tag
|
||||
|
||||
err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
r.logger.Error("根据名称查找标签失败", zap.String("name", name), zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tag, nil
|
||||
}
|
||||
|
||||
// AddTagToArticle 为文章添加标签
|
||||
func (r *GormTagRepository) AddTagToArticle(ctx context.Context, articleID string, tagID string) error {
|
||||
// 检查关联是否已存在
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Table("article_tag_relations").
|
||||
Where("article_id = ? AND tag_id = ?", articleID, tagID).
|
||||
Count(&count).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("检查标签关联失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
// 关联已存在,不需要重复添加
|
||||
return nil
|
||||
}
|
||||
|
||||
// 创建关联
|
||||
err = r.db.WithContext(ctx).Exec(`
|
||||
INSERT INTO article_tag_relations (article_id, tag_id)
|
||||
VALUES (?, ?)
|
||||
`, articleID, tagID).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("添加标签到文章失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("添加标签到文章成功", zap.String("article_id", articleID), zap.String("tag_id", tagID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTagFromArticle 从文章移除标签
|
||||
func (r *GormTagRepository) RemoveTagFromArticle(ctx context.Context, articleID string, tagID string) error {
|
||||
err := r.db.WithContext(ctx).Exec(`
|
||||
DELETE FROM article_tag_relations
|
||||
WHERE article_id = ? AND tag_id = ?
|
||||
`, articleID, tagID).Error
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("从文章移除标签失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("从文章移除标签成功", zap.String("article_id", articleID), zap.String("tag_id", tagID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetArticleTags 获取文章的所有标签
|
||||
func (r *GormTagRepository) GetArticleTags(ctx context.Context, articleID string) ([]*entities.Tag, error) {
|
||||
return r.FindByArticleID(ctx, articleID)
|
||||
}
|
||||
|
||||
// 实现 BaseRepository 接口的其他方法
|
||||
func (r *GormTagRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Tag{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ?", search)
|
||||
}
|
||||
|
||||
var count int64
|
||||
err := dbQuery.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Tag{}).
|
||||
Where("id = ?", id).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Tag{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Tag{}).
|
||||
Where("id = ?", id).
|
||||
Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) CreateBatch(ctx context.Context, entities []entities.Tag) error {
|
||||
return r.db.WithContext(ctx).Create(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Tag, error) {
|
||||
var tags []entities.Tag
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&tags).Error
|
||||
return tags, err
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) UpdateBatch(ctx context.Context, entities []entities.Tag) error {
|
||||
return r.db.WithContext(ctx).Save(&entities).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.Tag{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormTagRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Tag, error) {
|
||||
var tags []entities.Tag
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Tag{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
dbQuery = dbQuery.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
search := "%" + options.Search + "%"
|
||||
dbQuery = dbQuery.Where("name LIKE ?", search)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "DESC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at ASC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 预加载关联数据
|
||||
if len(options.Include) > 0 {
|
||||
for _, include := range options.Include {
|
||||
dbQuery = dbQuery.Preload(include)
|
||||
}
|
||||
}
|
||||
|
||||
err := dbQuery.Find(&tags).Error
|
||||
return tags, err
|
||||
}
|
||||
@@ -0,0 +1,370 @@
|
||||
package certification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/certification/entities"
|
||||
"hyapi-server/internal/domains/certification/enums"
|
||||
"hyapi-server/internal/domains/certification/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// ================ 常量定义 ================
|
||||
|
||||
const (
|
||||
// 表名常量
|
||||
CertificationsTable = "certifications"
|
||||
|
||||
// 缓存时间常量
|
||||
CacheTTLPrimaryQuery = 30 * time.Minute // 主键查询缓存时间
|
||||
CacheTTLBusinessQuery = 15 * time.Minute // 业务查询缓存时间
|
||||
CacheTTLUserQuery = 10 * time.Minute // 用户相关查询缓存时间
|
||||
CacheTTLWarmupLong = 30 * time.Minute // 预热长期缓存
|
||||
CacheTTLWarmupMedium = 15 * time.Minute // 预热中期缓存
|
||||
|
||||
// 缓存键模式常量
|
||||
CachePatternTable = "gorm_cache:certifications:*"
|
||||
CachePatternUser = "certification:user_id:*"
|
||||
)
|
||||
|
||||
// ================ Repository 实现 ================
|
||||
|
||||
// GormCertificationCommandRepository 认证命令仓储GORM实现
|
||||
//
|
||||
// 特性说明:
|
||||
// - 基于 CachedBaseRepositoryImpl 实现自动缓存管理
|
||||
// - 支持多级缓存策略(主键查询30分钟,业务查询15分钟)
|
||||
// - 自动缓存失效:写操作时自动清理相关缓存
|
||||
// - 智能缓存选择:根据查询复杂度自动选择缓存策略
|
||||
// - 内置监控支持:提供缓存统计和性能监控
|
||||
type GormCertificationCommandRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.CertificationCommandRepository = (*GormCertificationCommandRepository)(nil)
|
||||
|
||||
// NewGormCertificationCommandRepository 创建认证命令仓储
|
||||
//
|
||||
// 参数:
|
||||
// - db: GORM数据库连接实例
|
||||
// - logger: 日志记录器
|
||||
//
|
||||
// 返回:
|
||||
// - repositories.CertificationCommandRepository: 仓储接口实现
|
||||
func NewGormCertificationCommandRepository(db *gorm.DB, logger *zap.Logger) repositories.CertificationCommandRepository {
|
||||
return &GormCertificationCommandRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, CertificationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建认证
|
||||
//
|
||||
// 业务说明:
|
||||
// - 创建新的认证申请
|
||||
// - 自动触发相关缓存失效
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - cert: 认证实体
|
||||
//
|
||||
// 返回:
|
||||
// - error: 创建失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) Create(ctx context.Context, cert entities.Certification) error {
|
||||
r.GetLogger().Info("创建认证申请",
|
||||
zap.String("user_id", cert.UserID),
|
||||
zap.String("status", string(cert.Status)))
|
||||
|
||||
return r.CreateEntity(ctx, &cert)
|
||||
}
|
||||
|
||||
// Update 更新认证
|
||||
//
|
||||
// 缓存影响:
|
||||
// - GORM缓存插件会自动失效相关缓存
|
||||
// - 无需手动管理缓存一致性
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - cert: 认证实体
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) Update(ctx context.Context, cert entities.Certification) error {
|
||||
r.GetLogger().Info("更新认证",
|
||||
zap.String("id", cert.ID),
|
||||
zap.String("status", string(cert.Status)))
|
||||
|
||||
return r.UpdateEntity(ctx, &cert)
|
||||
}
|
||||
|
||||
// Delete 删除认证
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
//
|
||||
// 返回:
|
||||
// - error: 删除失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) Delete(ctx context.Context, id string) error {
|
||||
r.GetLogger().Info("删除认证", zap.String("id", id))
|
||||
return r.DeleteEntity(ctx, id, &entities.Certification{})
|
||||
}
|
||||
|
||||
// ================ 业务特定的更新操作 ================
|
||||
|
||||
// UpdateStatus 更新认证状态
|
||||
//
|
||||
// 业务说明:
|
||||
// - 更新认证的状态
|
||||
// - 自动更新时间戳
|
||||
//
|
||||
// 缓存影响:
|
||||
// - GORM缓存插件会自动失效表相关的缓存
|
||||
// - 状态更新会影响列表查询和统计结果
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - status: 新状态
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateStatus(ctx context.Context, id string, status enums.CertificationStatus) error {
|
||||
r.GetLogger().Info("更新认证状态",
|
||||
zap.String("id", id),
|
||||
zap.String("status", string(status)))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateAuthFlowID 更新认证流程ID
|
||||
//
|
||||
// 业务说明:
|
||||
// - 记录e签宝企业认证流程ID
|
||||
// - 用于回调处理和状态跟踪
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - authFlowID: 认证流程ID
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateAuthFlowID(ctx context.Context, id string, authFlowID string) error {
|
||||
r.GetLogger().Info("更新认证流程ID",
|
||||
zap.String("id", id),
|
||||
zap.String("auth_flow_id", authFlowID))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"auth_flow_id": authFlowID,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateContractInfo 更新合同信息
|
||||
//
|
||||
// 业务说明:
|
||||
// - 记录合同相关的ID和URL信息
|
||||
// - 用于合同管理和用户下载
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - contractFileID: 合同文件ID
|
||||
// - esignFlowID: e签宝流程ID
|
||||
// - contractURL: 合同URL
|
||||
// - contractSignURL: 合同签署URL
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateContractInfo(ctx context.Context, id string, contractFileID, esignFlowID, contractURL, contractSignURL string) error {
|
||||
r.GetLogger().Info("更新合同信息",
|
||||
zap.String("id", id),
|
||||
zap.String("contract_file_id", contractFileID),
|
||||
zap.String("esign_flow_id", esignFlowID))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"contract_file_id": contractFileID,
|
||||
"esign_flow_id": esignFlowID,
|
||||
"contract_url": contractURL,
|
||||
"contract_sign_url": contractSignURL,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// UpdateFailureInfo 更新失败信息
|
||||
//
|
||||
// 业务说明:
|
||||
// - 记录认证失败的原因和详细信息
|
||||
// - 用于错误分析和用户提示
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
// - reason: 失败原因
|
||||
// - message: 失败详细信息
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) UpdateFailureInfo(ctx context.Context, id string, reason enums.FailureReason, message string) error {
|
||||
r.GetLogger().Info("更新失败信息",
|
||||
zap.String("id", id),
|
||||
zap.String("reason", string(reason)),
|
||||
zap.String("message", message))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"failure_reason": reason,
|
||||
"failure_message": message,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
return r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// ================ 批量操作 ================
|
||||
|
||||
// BatchUpdateStatus 批量更新状态
|
||||
//
|
||||
// 业务说明:
|
||||
// - 批量更新多个认证的状态
|
||||
// - 适用于管理员批量操作
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - ids: 认证ID列表
|
||||
// - status: 新状态
|
||||
//
|
||||
// 返回:
|
||||
// - error: 更新失败时的错误信息
|
||||
func (r *GormCertificationCommandRepository) BatchUpdateStatus(ctx context.Context, ids []string, status enums.CertificationStatus) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("批量更新状态:ID列表不能为空")
|
||||
}
|
||||
|
||||
r.GetLogger().Info("批量更新认证状态",
|
||||
zap.Strings("ids", ids),
|
||||
zap.String("status", string(status)))
|
||||
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
"updated_at": time.Now(),
|
||||
}
|
||||
|
||||
result := r.GetDB(ctx).Model(&entities.Certification{}).
|
||||
Where("id IN ?", ids).
|
||||
Updates(updates)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量更新认证状态失败: %w", result.Error)
|
||||
}
|
||||
|
||||
r.GetLogger().Info("批量更新完成", zap.Int64("affected_rows", result.RowsAffected))
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 事务支持 ================
|
||||
|
||||
// WithTx 使用事务
|
||||
//
|
||||
// 业务说明:
|
||||
// - 返回支持事务的仓储实例
|
||||
// - 用于复杂业务操作的事务一致性保证
|
||||
//
|
||||
// 参数:
|
||||
// - tx: 事务对象
|
||||
//
|
||||
// 返回:
|
||||
// - repositories.CertificationCommandRepository: 支持事务的仓储实例
|
||||
func (r *GormCertificationCommandRepository) WithTx(tx interfaces.Transaction) repositories.CertificationCommandRepository {
|
||||
// 获取事务的底层*gorm.DB
|
||||
txDB := tx.GetDB()
|
||||
if gormDB, ok := txDB.(*gorm.DB); ok {
|
||||
return &GormCertificationCommandRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormDB, r.GetLogger(), CertificationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
r.GetLogger().Warn("不支持的事务类型,返回原始仓储")
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 缓存管理方法 ================
|
||||
|
||||
// WarmupCache 预热认证缓存
|
||||
//
|
||||
// 业务说明:
|
||||
// - 系统启动时预热常用查询的缓存
|
||||
// - 提升首次访问的响应速度
|
||||
//
|
||||
// 预热策略:
|
||||
// - 活跃认证:30分钟长期缓存
|
||||
// - 最近创建:15分钟中期缓存
|
||||
func (r *GormCertificationCommandRepository) WarmupCache(ctx context.Context) error {
|
||||
r.GetLogger().Info("开始预热认证缓存")
|
||||
|
||||
queries := []database.WarmupQuery{
|
||||
{
|
||||
Name: "active_certifications",
|
||||
TTL: CacheTTLWarmupLong,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
{
|
||||
Name: "recent_certifications",
|
||||
TTL: CacheTTLWarmupMedium,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
}
|
||||
|
||||
return r.WarmupCommonQueries(ctx, queries)
|
||||
}
|
||||
|
||||
// RefreshCache 刷新认证缓存
|
||||
//
|
||||
// 业务说明:
|
||||
// - 手动刷新认证相关的所有缓存
|
||||
// - 适用于数据迁移或批量更新后的缓存清理
|
||||
func (r *GormCertificationCommandRepository) RefreshCache(ctx context.Context) error {
|
||||
r.GetLogger().Info("刷新认证缓存")
|
||||
return r.CachedBaseRepositoryImpl.RefreshCache(ctx, CachePatternTable)
|
||||
}
|
||||
|
||||
// GetCacheStats 获取缓存统计信息
|
||||
//
|
||||
// 返回当前Repository的缓存使用统计,包括:
|
||||
// - 基础缓存信息(命中率、键数量等)
|
||||
// - 特定的缓存模式列表
|
||||
// - 性能指标
|
||||
func (r *GormCertificationCommandRepository) GetCacheStats() map[string]interface{} {
|
||||
stats := r.GetCacheInfo()
|
||||
stats["specific_patterns"] = []string{
|
||||
CachePatternTable,
|
||||
CachePatternUser,
|
||||
}
|
||||
return stats
|
||||
}
|
||||
@@ -0,0 +1,469 @@
|
||||
package certification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/certification/entities"
|
||||
"hyapi-server/internal/domains/certification/enums"
|
||||
"hyapi-server/internal/domains/certification/repositories"
|
||||
"hyapi-server/internal/domains/certification/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// ================ 常量定义 ================
|
||||
|
||||
const (
|
||||
// 缓存时间常量
|
||||
QueryCacheTTLPrimaryQuery = 30 * time.Minute // 主键查询缓存时间
|
||||
QueryCacheTTLBusinessQuery = 15 * time.Minute // 业务查询缓存时间
|
||||
QueryCacheTTLUserQuery = 10 * time.Minute // 用户相关查询缓存时间
|
||||
QueryCacheTTLSearchQuery = 2 * time.Minute // 搜索查询缓存时间
|
||||
QueryCacheTTLActiveRecords = 5 * time.Minute // 活跃记录查询缓存时间
|
||||
QueryCacheTTLWarmupLong = 30 * time.Minute // 预热长期缓存
|
||||
QueryCacheTTLWarmupMedium = 15 * time.Minute // 预热中期缓存
|
||||
|
||||
// 缓存键模式常量
|
||||
QueryCachePatternTable = "gorm_cache:certifications:*"
|
||||
QueryCachePatternUser = "certification:user_id:*"
|
||||
)
|
||||
|
||||
// ================ Repository 实现 ================
|
||||
|
||||
// GormCertificationQueryRepository 认证查询仓储GORM实现
|
||||
//
|
||||
// 特性说明:
|
||||
// - 基于 CachedBaseRepositoryImpl 实现自动缓存管理
|
||||
// - 支持多级缓存策略(主键查询30分钟,业务查询15分钟,搜索2分钟)
|
||||
// - 自动缓存失效:写操作时自动清理相关缓存
|
||||
// - 智能缓存选择:根据查询复杂度自动选择缓存策略
|
||||
// - 内置监控支持:提供缓存统计和性能监控
|
||||
type GormCertificationQueryRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.CertificationQueryRepository = (*GormCertificationQueryRepository)(nil)
|
||||
|
||||
// NewGormCertificationQueryRepository 创建认证查询仓储
|
||||
//
|
||||
// 参数:
|
||||
// - db: GORM数据库连接实例
|
||||
// - logger: 日志记录器
|
||||
//
|
||||
// 返回:
|
||||
// - repositories.CertificationQueryRepository: 仓储接口实现
|
||||
func NewGormCertificationQueryRepository(
|
||||
db *gorm.DB,
|
||||
logger *zap.Logger,
|
||||
) repositories.CertificationQueryRepository {
|
||||
return &GormCertificationQueryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, CertificationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础查询操作 ================
|
||||
|
||||
// GetByID 根据ID获取认证
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 使用智能主键查询,自动缓存30分钟
|
||||
// - 主键查询命中率高,适合长期缓存
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - id: 认证ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) GetByID(ctx context.Context, id string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
if err := r.SmartGetByID(ctx, id, &cert); err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("认证记录不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询认证记录失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取认证
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 业务查询,缓存15分钟
|
||||
// - 用户查询频率较高,适合中期缓存
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - userID: 用户ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) GetByUserID(ctx context.Context, userID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.SmartGetByField(ctx, &cert, "user_id", userID, QueryCacheTTLUserQuery)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("用户尚未创建认证申请")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户认证记录失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// Exists 检查认证是否存在
|
||||
func (r *GormCertificationQueryRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.Certification{})
|
||||
}
|
||||
|
||||
func (r *GormCertificationQueryRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Certification{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("查询用户认证是否存在失败: %w", err)
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// ================ 列表查询 ================
|
||||
|
||||
// List 分页列表查询
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 搜索查询:短期缓存2分钟(避免频繁数据库查询但保证实时性)
|
||||
// - 常规列表:智能缓存(根据查询复杂度自动选择缓存策略)
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - query: 列表查询条件
|
||||
//
|
||||
// 返回:
|
||||
// - []*entities.Certification: 查询结果列表
|
||||
// - int64: 总记录数
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) List(ctx context.Context, query *queries.ListCertificationsQuery) ([]*entities.Certification, int64, error) {
|
||||
db := r.GetDB(ctx).Model(&entities.Certification{})
|
||||
|
||||
// 应用过滤条件
|
||||
if query.UserID != "" {
|
||||
db = db.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.Status != "" {
|
||||
db = db.Where("status = ?", query.Status)
|
||||
}
|
||||
if len(query.Statuses) > 0 {
|
||||
db = db.Where("status IN ?", query.Statuses)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
var total int64
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("查询认证总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 应用排序和分页
|
||||
if query.SortBy != "" {
|
||||
orderClause := query.SortBy
|
||||
if query.SortOrder != "" {
|
||||
orderClause += " " + strings.ToUpper(query.SortOrder)
|
||||
}
|
||||
db = db.Order(orderClause)
|
||||
} else {
|
||||
db = db.Order("created_at DESC")
|
||||
}
|
||||
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
db = db.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 执行查询
|
||||
var certifications []*entities.Certification
|
||||
if err := db.Find(&certifications).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("查询认证列表失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, total, nil
|
||||
}
|
||||
|
||||
// ListByUserIDs 根据用户ID列表查询
|
||||
func (r *GormCertificationQueryRepository) ListByUserIDs(ctx context.Context, userIDs []string) ([]*entities.Certification, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return []*entities.Certification{}, nil
|
||||
}
|
||||
|
||||
var certifications []*entities.Certification
|
||||
if err := r.GetDB(ctx).Where("user_id IN ?", userIDs).Order("created_at DESC").Find(&certifications).Error; err != nil {
|
||||
return nil, fmt.Errorf("根据用户ID列表查询认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态查询
|
||||
func (r *GormCertificationQueryRepository) ListByStatus(ctx context.Context, status enums.CertificationStatus, limit int) ([]*entities.Certification, error) {
|
||||
db := r.GetDB(ctx).Where("status = ?", status).Order("created_at DESC")
|
||||
if limit > 0 {
|
||||
db = db.Limit(limit)
|
||||
}
|
||||
|
||||
var certifications []*entities.Certification
|
||||
if err := db.Find(&certifications).Error; err != nil {
|
||||
return nil, fmt.Errorf("根据状态查询认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// ================ 业务查询 ================
|
||||
|
||||
// FindByAuthFlowID 根据认证流程ID查询
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 业务查询,缓存15分钟
|
||||
// - 回调查询频率较高
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - authFlowID: 认证流程ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) FindByAuthFlowID(ctx context.Context, authFlowID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.SmartGetByField(ctx, &cert, "auth_flow_id", authFlowID, QueryCacheTTLBusinessQuery)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("认证流程不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("根据认证流程ID查询失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// FindByEsignFlowID 根据e签宝流程ID查询
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 业务查询,缓存15分钟
|
||||
// - 回调查询频率较高
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - esignFlowID: e签宝流程ID
|
||||
//
|
||||
// 返回:
|
||||
// - *entities.Certification: 查询到的认证,未找到时返回nil
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) FindByEsignFlowID(ctx context.Context, esignFlowID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.SmartGetByField(ctx, &cert, "esign_flow_id", esignFlowID, QueryCacheTTLBusinessQuery)
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("e签宝流程不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("根据e签宝流程ID查询失败: %w", err)
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// ListPendingRetry 查询待重试的认证
|
||||
//
|
||||
// 缓存策略:
|
||||
// - 管理查询,不缓存保证数据实时性
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - maxRetryCount: 最大重试次数
|
||||
//
|
||||
// 返回:
|
||||
// - []*entities.Certification: 待重试的认证列表
|
||||
// - error: 查询失败时的错误信息
|
||||
func (r *GormCertificationQueryRepository) ListPendingRetry(ctx context.Context, maxRetryCount int) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.WithoutCache().GetDB(ctx).
|
||||
Where("status IN ? AND retry_count < ?",
|
||||
[]enums.CertificationStatus{
|
||||
enums.StatusInfoRejected,
|
||||
enums.StatusContractRejected,
|
||||
enums.StatusContractExpired,
|
||||
},
|
||||
maxRetryCount).
|
||||
Order("created_at ASC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询待重试认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetPendingCertifications 获取待处理认证
|
||||
func (r *GormCertificationQueryRepository) GetPendingCertifications(ctx context.Context) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.WithoutCache().GetDB(ctx).
|
||||
Where("status IN ?", []enums.CertificationStatus{
|
||||
enums.StatusPending,
|
||||
enums.StatusInfoSubmitted,
|
||||
}).
|
||||
Order("created_at ASC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询待处理认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetExpiredContracts 获取过期合同
|
||||
func (r *GormCertificationQueryRepository) GetExpiredContracts(ctx context.Context) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.WithoutCache().GetDB(ctx).
|
||||
Where("status = ?", enums.StatusContractExpired).
|
||||
Order("updated_at DESC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询过期合同失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetCertificationsByDateRange 根据日期范围获取认证
|
||||
func (r *GormCertificationQueryRepository) GetCertificationsByDateRange(ctx context.Context, startDate, endDate time.Time) ([]*entities.Certification, error) {
|
||||
var certifications []*entities.Certification
|
||||
err := r.GetDB(ctx).
|
||||
Where("created_at BETWEEN ? AND ?", startDate, endDate).
|
||||
Order("created_at DESC").
|
||||
Find(&certifications).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("根据日期范围查询认证失败: %w", err)
|
||||
}
|
||||
|
||||
return certifications, nil
|
||||
}
|
||||
|
||||
// GetUserActiveCertification 获取用户当前活跃认证
|
||||
func (r *GormCertificationQueryRepository) GetUserActiveCertification(ctx context.Context, userID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND status NOT IN ?", userID, []enums.CertificationStatus{
|
||||
enums.StatusContractSigned,
|
||||
enums.StatusInfoRejected,
|
||||
enums.StatusContractRejected,
|
||||
enums.StatusContractExpired,
|
||||
}).
|
||||
First(&cert).Error
|
||||
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("用户没有活跃的认证申请")
|
||||
}
|
||||
return nil, fmt.Errorf("查询用户活跃认证失败: %w", err)
|
||||
}
|
||||
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// ================ 统计查询 ================
|
||||
|
||||
|
||||
|
||||
// CountByFailureReason 按失败原因统计
|
||||
func (r *GormCertificationQueryRepository) CountByFailureReason(ctx context.Context, reason enums.FailureReason) (int64, error) {
|
||||
var count int64
|
||||
if err := r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).Where("failure_reason = ?", reason).Count(&count).Error; err != nil {
|
||||
return 0, fmt.Errorf("按失败原因统计认证失败: %w", err)
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetProgressStatistics 获取进度统计
|
||||
func (r *GormCertificationQueryRepository) GetProgressStatistics(ctx context.Context) (*repositories.CertificationProgressStats, error) {
|
||||
// 简化实现
|
||||
return &repositories.CertificationProgressStats{
|
||||
StatusProgress: make(map[enums.CertificationStatus]int64),
|
||||
ProgressDistribution: make(map[int]int64),
|
||||
StageTimeStats: make(map[string]*repositories.CertificationStageTimeInfo),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SearchByCompanyName 按公司名搜索
|
||||
func (r *GormCertificationQueryRepository) SearchByCompanyName(ctx context.Context, companyName string, limit int) ([]*entities.Certification, error) {
|
||||
// 简化实现,暂时返回空结果
|
||||
r.GetLogger().Warn("按公司名搜索功能待实现,需要企业信息服务支持")
|
||||
return []*entities.Certification{}, nil
|
||||
}
|
||||
|
||||
// SearchByLegalPerson 按法人搜索
|
||||
func (r *GormCertificationQueryRepository) SearchByLegalPerson(ctx context.Context, legalPersonName string, limit int) ([]*entities.Certification, error) {
|
||||
// 简化实现,暂时返回空结果
|
||||
r.GetLogger().Warn("按法人搜索功能待实现,需要企业信息服务支持")
|
||||
return []*entities.Certification{}, nil
|
||||
}
|
||||
|
||||
// InvalidateCache 清除缓存
|
||||
func (r *GormCertificationQueryRepository) InvalidateCache(ctx context.Context, keys ...string) error {
|
||||
// 简化实现,暂不处理缓存
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshCache 刷新缓存
|
||||
func (r *GormCertificationQueryRepository) RefreshCache(ctx context.Context, certificationID string) error {
|
||||
// 简化实现,暂不处理缓存
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 缓存管理方法 ================
|
||||
|
||||
// WarmupCache 预热认证查询缓存
|
||||
//
|
||||
// 业务说明:
|
||||
// - 系统启动时预热常用查询的缓存
|
||||
// - 提升首次访问的响应速度
|
||||
//
|
||||
// 预热策略:
|
||||
// - 活跃认证:30分钟长期缓存
|
||||
// - 待处理认证:15分钟中期缓存
|
||||
func (r *GormCertificationQueryRepository) WarmupCache(ctx context.Context) error {
|
||||
r.GetLogger().Info("开始预热认证查询缓存")
|
||||
|
||||
queries := []database.WarmupQuery{
|
||||
{
|
||||
Name: "active_certifications",
|
||||
TTL: QueryCacheTTLWarmupLong,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
{
|
||||
Name: "pending_certifications",
|
||||
TTL: QueryCacheTTLWarmupMedium,
|
||||
Dest: &[]entities.Certification{},
|
||||
},
|
||||
}
|
||||
|
||||
return r.WarmupCommonQueries(ctx, queries)
|
||||
}
|
||||
|
||||
// GetCacheStats 获取缓存统计信息
|
||||
//
|
||||
// 返回当前Repository的缓存使用统计,包括:
|
||||
// - 基础缓存信息(命中率、键数量等)
|
||||
// - 特定的缓存模式列表
|
||||
// - 性能指标
|
||||
func (r *GormCertificationQueryRepository) GetCacheStats() map[string]interface{} {
|
||||
stats := r.GetCacheInfo()
|
||||
stats["specific_patterns"] = []string{
|
||||
QueryCachePatternTable,
|
||||
QueryCachePatternUser,
|
||||
}
|
||||
return stats
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package certification
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hyapi-server/internal/domains/certification/entities"
|
||||
"hyapi-server/internal/domains/certification/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
EnterpriseInfoSubmitRecordsTable = "enterprise_info_submit_records"
|
||||
)
|
||||
|
||||
type GormEnterpriseInfoSubmitRecordRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.EnterpriseInfoSubmitRecord{})
|
||||
}
|
||||
|
||||
func NewGormEnterpriseInfoSubmitRecordRepository(db *gorm.DB, logger *zap.Logger) *GormEnterpriseInfoSubmitRecordRepository {
|
||||
return &GormEnterpriseInfoSubmitRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, EnterpriseInfoSubmitRecordsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Create(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
|
||||
return r.CreateEntity(ctx, record)
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Update(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
|
||||
return r.UpdateEntity(ctx, record)
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) Exists(ctx context.Context, ID string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, ID, &entities.EnterpriseInfoSubmitRecord{})
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindByID(ctx context.Context, id string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).Where("id = ?", id).First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Order("submit_at DESC").
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestVerifiedByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
|
||||
var record entities.EnterpriseInfoSubmitRecord
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND status = ?", userID, "verified").
|
||||
Order("verified_at DESC").
|
||||
First(&record).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// ExistsByUnifiedSocialCodeExcludeUser 检查该统一社会信用代码是否已被其他用户占用(已提交或已通过验证的记录)
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) ExistsByUnifiedSocialCodeExcludeUser(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
if unifiedSocialCode == "" {
|
||||
return false, nil
|
||||
}
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{}).
|
||||
Where("unified_social_code = ? AND status IN (?, ?)", unifiedSocialCode, "submitted", "verified")
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("user_id != ?", excludeUserID)
|
||||
}
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormEnterpriseInfoSubmitRecordRepository) List(ctx context.Context, filter repositories.ListSubmitRecordsFilter) (*repositories.ListSubmitRecordsResult, error) {
|
||||
base := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{})
|
||||
if filter.CertificationStatus != "" {
|
||||
base = base.Joins("JOIN certifications ON certifications.user_id = enterprise_info_submit_records.user_id AND certifications.deleted_at IS NULL").
|
||||
Where("certifications.status = ?", filter.CertificationStatus)
|
||||
}
|
||||
if filter.CompanyName != "" {
|
||||
base = base.Where("enterprise_info_submit_records.company_name LIKE ?", "%"+filter.CompanyName+"%")
|
||||
}
|
||||
if filter.LegalPersonPhone != "" {
|
||||
base = base.Where("enterprise_info_submit_records.legal_person_phone = ?", filter.LegalPersonPhone)
|
||||
}
|
||||
if filter.LegalPersonName != "" {
|
||||
base = base.Where("enterprise_info_submit_records.legal_person_name LIKE ?", "%"+filter.LegalPersonName+"%")
|
||||
}
|
||||
var total int64
|
||||
if err := base.Count(&total).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if filter.PageSize <= 0 {
|
||||
filter.PageSize = 10
|
||||
}
|
||||
if filter.Page <= 0 {
|
||||
filter.Page = 1
|
||||
}
|
||||
offset := (filter.Page - 1) * filter.PageSize
|
||||
var records []*entities.EnterpriseInfoSubmitRecord
|
||||
q := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{})
|
||||
if filter.CertificationStatus != "" {
|
||||
q = q.Joins("JOIN certifications ON certifications.user_id = enterprise_info_submit_records.user_id AND certifications.deleted_at IS NULL").
|
||||
Where("certifications.status = ?", filter.CertificationStatus)
|
||||
}
|
||||
if filter.CompanyName != "" {
|
||||
q = q.Where("enterprise_info_submit_records.company_name LIKE ?", "%"+filter.CompanyName+"%")
|
||||
}
|
||||
if filter.LegalPersonPhone != "" {
|
||||
q = q.Where("enterprise_info_submit_records.legal_person_phone = ?", filter.LegalPersonPhone)
|
||||
}
|
||||
if filter.LegalPersonName != "" {
|
||||
q = q.Where("enterprise_info_submit_records.legal_person_name LIKE ?", "%"+filter.LegalPersonName+"%")
|
||||
}
|
||||
err := q.Order("enterprise_info_submit_records.submit_at DESC").Offset(offset).Limit(filter.PageSize).Find(&records).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &repositories.ListSubmitRecordsResult{Records: records, Total: total}, nil
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
AlipayOrdersTable = "typay_orders"
|
||||
)
|
||||
|
||||
type GormAlipayOrderRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.AlipayOrderRepository = (*GormAlipayOrderRepository)(nil)
|
||||
|
||||
func NewGormAlipayOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.AlipayOrderRepository {
|
||||
return &GormAlipayOrderRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, AlipayOrdersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Create(ctx context.Context, order entities.AlipayOrder) (entities.AlipayOrder, error) {
|
||||
err := r.CreateEntity(ctx, &order)
|
||||
return order, err
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByID(ctx context.Context, id string) (entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.SmartGetByID(ctx, id, &order)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.AlipayOrder{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.AlipayOrder{}, err
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.AlipayOrder, error) {
|
||||
var order entities.AlipayOrder
|
||||
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.AlipayOrder, error) {
|
||||
var orders []entities.AlipayOrder
|
||||
err := r.GetDB(ctx).
|
||||
Joins("JOIN recharge_records ON typay_orders.recharge_id = recharge_records.id").
|
||||
Where("recharge_records.user_id = ?", userID).
|
||||
Order("typay_orders.created_at DESC").
|
||||
Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Update(ctx context.Context, order entities.AlipayOrder) error {
|
||||
return r.UpdateEntity(ctx, &order)
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.AlipayOrderStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.AlipayOrder{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormAlipayOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,352 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
"hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
PurchaseOrdersTable = "ty_purchase_orders"
|
||||
)
|
||||
|
||||
type GormPurchaseOrderRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.PurchaseOrderRepository = (*GormPurchaseOrderRepository)(nil)
|
||||
|
||||
func NewGormPurchaseOrderRepository(db *gorm.DB, logger *zap.Logger) repositories.PurchaseOrderRepository {
|
||||
return &GormPurchaseOrderRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, PurchaseOrdersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) Create(ctx context.Context, order *entities.PurchaseOrder) (*entities.PurchaseOrder, error) {
|
||||
err := r.CreateEntity(ctx, order)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) Update(ctx context.Context, order *entities.PurchaseOrder) error {
|
||||
return r.UpdateEntity(ctx, order)
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByID(ctx context.Context, id string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.SmartGetByID(ctx, id, &order)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByOrderNo(ctx context.Context, orderNo string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).Where("order_no = ?", orderNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.PurchaseOrder, int64, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
var count int64
|
||||
|
||||
db := r.GetDB(ctx).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
err := db.Model(&entities.PurchaseOrder{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
err = db.Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
result := make([]*entities.PurchaseOrder, len(orders))
|
||||
for i := range orders {
|
||||
result[i] = &orders[i]
|
||||
}
|
||||
|
||||
return result, count, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByUserIDAndProductID(ctx context.Context, userID, productID string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND product_id = ? AND status = ?", userID, productID, entities.PurchaseOrderStatusPaid).
|
||||
First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByPaymentTypeAndTransactionID(ctx context.Context, paymentType, transactionID string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Where("payment_type = ? AND trade_no = ?", paymentType, transactionID).
|
||||
First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByTradeNo(ctx context.Context, tradeNo string) (*entities.PurchaseOrder, error) {
|
||||
var order entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).Where("trade_no = ?", tradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) UpdatePaymentStatus(ctx context.Context, orderID string, status entities.PurchaseOrderStatus, tradeNo *string, payAmount, receiptAmount *string, paymentTime *time.Time) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if tradeNo != nil {
|
||||
updates["trade_no"] = *tradeNo
|
||||
}
|
||||
|
||||
if payAmount != nil {
|
||||
updates["pay_amount"] = *payAmount
|
||||
}
|
||||
|
||||
if receiptAmount != nil {
|
||||
updates["receipt_amount"] = *receiptAmount
|
||||
}
|
||||
|
||||
if paymentTime != nil {
|
||||
updates["pay_time"] = *paymentTime
|
||||
updates["notify_time"] = *paymentTime
|
||||
}
|
||||
|
||||
err := r.GetDB(ctx).
|
||||
Model(&entities.PurchaseOrder{}).
|
||||
Where("id = ?", orderID).
|
||||
Updates(updates).Error
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetUserPurchasedProductCodes(ctx context.Context, userID string) ([]string, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Select("product_code").
|
||||
Where("user_id = ? AND status = ?", userID, entities.PurchaseOrderStatusPaid).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
codesMap := make(map[string]bool)
|
||||
for _, order := range orders {
|
||||
// 添加主产品编号
|
||||
if order.ProductCode != "" {
|
||||
codesMap[order.ProductCode] = true
|
||||
}
|
||||
}
|
||||
|
||||
codes := make([]string, 0, len(codesMap))
|
||||
for code := range codesMap {
|
||||
codes = append(codes, code)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetUserPaidProductIDs(ctx context.Context, userID string) ([]string, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
err := r.GetDB(ctx).
|
||||
Select("product_id").
|
||||
Where("user_id = ? AND status = ?", userID, entities.PurchaseOrderStatusPaid).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idsMap := make(map[string]bool)
|
||||
for _, order := range orders {
|
||||
// 添加主产品ID
|
||||
if order.ProductID != "" {
|
||||
idsMap[order.ProductID] = true
|
||||
}
|
||||
}
|
||||
|
||||
ids := make([]string, 0, len(idsMap))
|
||||
for id := range idsMap {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) HasUserPurchased(ctx context.Context, userID string, productCode string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.PurchaseOrder{}).
|
||||
Where("user_id = ? AND product_code = ? AND status = ?", userID, productCode, entities.PurchaseOrderStatusPaid).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetExpiringOrders(ctx context.Context, before time.Time, limit int) ([]*entities.PurchaseOrder, error) {
|
||||
// 购买订单实体没有过期时间字段,此方法返回空结果
|
||||
return []*entities.PurchaseOrder{}, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetExpiredOrders(ctx context.Context, limit int) ([]*entities.PurchaseOrder, error) {
|
||||
// 购买订单实体没有过期时间字段,此方法返回空结果
|
||||
return []*entities.PurchaseOrder{}, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByStatus(ctx context.Context, status entities.PurchaseOrderStatus, limit, offset int) ([]*entities.PurchaseOrder, int64, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
var count int64
|
||||
|
||||
db := r.GetDB(ctx).Where("status = ?", status)
|
||||
|
||||
// 获取总数
|
||||
err := db.Model(&entities.PurchaseOrder{}).Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
err = db.Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Offset(offset).
|
||||
Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
result := make([]*entities.PurchaseOrder, len(orders))
|
||||
for i := range orders {
|
||||
result[i] = &orders[i]
|
||||
}
|
||||
|
||||
return result, count, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) GetByFilters(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.PurchaseOrder, error) {
|
||||
var orders []entities.PurchaseOrder
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
if userID, ok := filters["user_id"]; ok {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if status, ok := filters["status"]; ok && status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if paymentType, ok := filters["payment_type"]; ok && paymentType != "" {
|
||||
db = db.Where("payment_type = ?", paymentType)
|
||||
}
|
||||
if payChannel, ok := filters["pay_channel"]; ok && payChannel != "" {
|
||||
db = db.Where("pay_channel = ?", payChannel)
|
||||
}
|
||||
if startTime, ok := filters["start_time"]; ok && startTime != "" {
|
||||
db = db.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"]; ok && endTime != "" {
|
||||
db = db.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用排序和分页
|
||||
// 默认按创建时间倒序
|
||||
db = db.Order("created_at DESC")
|
||||
|
||||
// 应用分页
|
||||
if options.PageSize > 0 {
|
||||
db = db.Limit(options.PageSize)
|
||||
}
|
||||
|
||||
if options.Page > 0 {
|
||||
db = db.Offset((options.Page - 1) * options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err := db.Find(&orders).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.PurchaseOrder, len(orders))
|
||||
for i := range orders {
|
||||
result[i] = &orders[i]
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormPurchaseOrderRepository) CountByFilters(ctx context.Context, filters map[string]interface{}) (int64, error) {
|
||||
var count int64
|
||||
|
||||
db := r.GetDB(ctx).Model(&entities.PurchaseOrder{})
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
if userID, ok := filters["user_id"]; ok {
|
||||
db = db.Where("user_id = ?", userID)
|
||||
}
|
||||
if status, ok := filters["status"]; ok && status != "" {
|
||||
db = db.Where("status = ?", status)
|
||||
}
|
||||
if paymentType, ok := filters["payment_type"]; ok && paymentType != "" {
|
||||
db = db.Where("payment_type = ?", paymentType)
|
||||
}
|
||||
if payChannel, ok := filters["pay_channel"]; ok && payChannel != "" {
|
||||
db = db.Where("pay_channel = ?", payChannel)
|
||||
}
|
||||
if startTime, ok := filters["start_time"]; ok && startTime != "" {
|
||||
db = db.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"]; ok && endTime != "" {
|
||||
db = db.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
|
||||
// 执行计数
|
||||
err := db.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,509 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
RechargeRecordsTable = "recharge_records"
|
||||
)
|
||||
|
||||
type GormRechargeRecordRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.RechargeRecordRepository = (*GormRechargeRecordRepository)(nil)
|
||||
|
||||
func NewGormRechargeRecordRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.RechargeRecordRepository {
|
||||
return &GormRechargeRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, RechargeRecordsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Create(ctx context.Context, record entities.RechargeRecord) (entities.RechargeRecord, error) {
|
||||
err := r.CreateEntity(ctx, &record)
|
||||
return record, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByID(ctx context.Context, id string) (entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.SmartGetByID(ctx, id, &record)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.RechargeRecord{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.RechargeRecord{}, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByUserID(ctx context.Context, userID string) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByAlipayOrderID(ctx context.Context, alipayOrderID string) (*entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("alipay_order_id = ?", alipayOrderID).First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByTransferOrderID(ctx context.Context, transferOrderID string) (*entities.RechargeRecord, error) {
|
||||
var record entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("transfer_order_id = ?", transferOrderID).First(&record).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Update(ctx context.Context, record entities.RechargeRecord) error {
|
||||
return r.UpdateEntity(ctx, &record)
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) UpdateStatus(ctx context.Context, id string, status entities.RechargeStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
|
||||
// 检查是否有 company_name 筛选,如果有则需要 JOIN 表
|
||||
hasCompanyNameFilter := false
|
||||
if options.Filters != nil {
|
||||
if companyName, ok := options.Filters["company_name"].(string); ok && companyName != "" {
|
||||
hasCompanyNameFilter = true
|
||||
}
|
||||
}
|
||||
|
||||
var query *gorm.DB
|
||||
if hasCompanyNameFilter {
|
||||
// 使用 JOIN 查询以支持企业名称筛选
|
||||
query = r.GetDB(ctx).Table("recharge_records rr").
|
||||
Joins("LEFT JOIN users u ON rr.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id")
|
||||
} else {
|
||||
// 普通查询
|
||||
query = r.GetDB(ctx).Model(&entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
// 特殊处理时间范围过滤器
|
||||
if key == "start_time" {
|
||||
if startTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at >= ?", startTime)
|
||||
} else {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
}
|
||||
} else if key == "end_time" {
|
||||
if endTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at <= ?", endTime)
|
||||
} else {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
} else if key == "company_name" {
|
||||
// 处理企业名称筛选
|
||||
if companyName, ok := value.(string); ok && companyName != "" {
|
||||
query = query.Where("ei.company_name LIKE ?", "%"+companyName+"%")
|
||||
}
|
||||
} else if key == "min_amount" {
|
||||
// 处理最小金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount >= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount >= ?", amount)
|
||||
}
|
||||
}
|
||||
} else if key == "max_amount" {
|
||||
// 处理最大金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount <= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount <= ?", amount)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 其他过滤器使用等值查询
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr."+key+" = ?", value)
|
||||
} else {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.user_id LIKE ? OR rr.transfer_order_id LIKE ? OR rr.alipay_order_id LIKE ? OR rr.wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
} else {
|
||||
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ? OR wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
}
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
|
||||
// 检查是否有 company_name 筛选,如果有则需要 JOIN 表
|
||||
hasCompanyNameFilter := false
|
||||
if options.Filters != nil {
|
||||
if companyName, ok := options.Filters["company_name"].(string); ok && companyName != "" {
|
||||
hasCompanyNameFilter = true
|
||||
}
|
||||
}
|
||||
|
||||
var query *gorm.DB
|
||||
if hasCompanyNameFilter {
|
||||
// 使用 JOIN 查询以支持企业名称筛选
|
||||
query = r.GetDB(ctx).Table("recharge_records rr").
|
||||
Select("rr.*").
|
||||
Joins("LEFT JOIN users u ON rr.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id")
|
||||
} else {
|
||||
// 普通查询
|
||||
query = r.GetDB(ctx).Model(&entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
// 特殊处理 user_ids 过滤器
|
||||
if key == "user_ids" {
|
||||
if userIds, ok := value.(string); ok && userIds != "" {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.user_id IN ?", strings.Split(userIds, ","))
|
||||
} else {
|
||||
query = query.Where("user_id IN ?", strings.Split(userIds, ","))
|
||||
}
|
||||
}
|
||||
} else if key == "company_name" {
|
||||
// 处理企业名称筛选
|
||||
if companyName, ok := value.(string); ok && companyName != "" {
|
||||
query = query.Where("ei.company_name LIKE ?", "%"+companyName+"%")
|
||||
}
|
||||
} else if key == "start_time" {
|
||||
// 处理开始时间范围
|
||||
if startTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at >= ?", startTime)
|
||||
} else {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
}
|
||||
} else if key == "end_time" {
|
||||
// 处理结束时间范围
|
||||
if endTime, ok := value.(time.Time); ok {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.created_at <= ?", endTime)
|
||||
} else {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
}
|
||||
} else if key == "min_amount" {
|
||||
// 处理最小金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount >= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount >= ?", amount)
|
||||
}
|
||||
}
|
||||
} else if key == "max_amount" {
|
||||
// 处理最大金额,支持string、int、int64类型
|
||||
if amount, err := r.parseAmount(value); err == nil {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.amount <= ?", amount)
|
||||
} else {
|
||||
query = query.Where("amount <= ?", amount)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 其他过滤器使用等值查询
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr."+key+" = ?", value)
|
||||
} else {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Where("rr.user_id LIKE ? OR rr.transfer_order_id LIKE ? OR rr.alipay_order_id LIKE ? OR rr.wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
} else {
|
||||
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ? OR wechat_order_id LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Order("rr." + options.Sort + " " + order)
|
||||
} else {
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
} else {
|
||||
if hasCompanyNameFilter {
|
||||
query = query.Order("rr.created_at DESC")
|
||||
} else {
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) CreateBatch(ctx context.Context, records []entities.RechargeRecord) error {
|
||||
return r.GetDB(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.RechargeRecord, error) {
|
||||
var records []entities.RechargeRecord
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) UpdateBatch(ctx context.Context, records []entities.RechargeRecord) error {
|
||||
return r.GetDB(ctx).Save(&records).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.RechargeRecord{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.RechargeRecord] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormRechargeRecordRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), RechargeRecordsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
func (r *GormRechargeRecordRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.RechargeRecord{})
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserId 获取用户总充值金额(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ? AND status = ? AND recharge_type != ?", userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总充值金额(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ? AND status = ? AND recharge_type != ? AND created_at >= ? AND created_at < ?", userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetDailyStatsByUserId 获取用户每日充值统计(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE user_id = $1
|
||||
AND status = $2
|
||||
AND recharge_type != $3
|
||||
AND DATE(created_at) >= $4
|
||||
AND DATE(created_at) <= $5
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetMonthlyStatsByUserId 获取用户每月充值统计(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE user_id = $1
|
||||
AND status = $2
|
||||
AND recharge_type != $3
|
||||
AND created_at >= $4
|
||||
AND created_at <= $5
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemTotalAmount 获取系统总充值金额(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Where("status = ? AND recharge_type != ?", entities.RechargeStatusSuccess, entities.RechargeTypeGift).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemAmountByDateRange 获取系统指定时间范围内的充值金额(排除赠送)
|
||||
// endDate 应该是结束日期当天的次日00:00:00(日统计)或下个月1号00:00:00(月统计),使用 < 而不是 <=
|
||||
func (r *GormRechargeRecordRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
|
||||
Where("status = ? AND recharge_type != ? AND created_at >= ? AND created_at < ?", entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemDailyStats 获取系统每日充值统计(排除赠送)
|
||||
// startDate 和 endDate 应该是时间对象,endDate 应该是结束日期当天的次日00:00:00,使用 < 而不是 <=
|
||||
func (r *GormRechargeRecordRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE status = ?
|
||||
AND recharge_type != ?
|
||||
AND created_at >= ?
|
||||
AND created_at < ?
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyStats 获取系统每月充值统计(排除赠送)
|
||||
func (r *GormRechargeRecordRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM recharge_records
|
||||
WHERE status = ?
|
||||
AND recharge_type != ?
|
||||
AND created_at >= ?
|
||||
AND created_at < ?
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// parseAmount 解析金额值,支持string、int、int64类型,转换为decimal.Decimal
|
||||
func (r *GormRechargeRecordRepository) parseAmount(value interface{}) (decimal.Decimal, error) {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if v == "" {
|
||||
return decimal.Zero, fmt.Errorf("empty string")
|
||||
}
|
||||
return decimal.NewFromString(v)
|
||||
case int:
|
||||
return decimal.NewFromInt(int64(v)), nil
|
||||
case int64:
|
||||
return decimal.NewFromInt(v), nil
|
||||
case float64:
|
||||
return decimal.NewFromFloat(v), nil
|
||||
case decimal.Decimal:
|
||||
return v, nil
|
||||
default:
|
||||
return decimal.Zero, fmt.Errorf("unsupported type: %T", value)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,348 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
WalletsTable = "wallets"
|
||||
)
|
||||
|
||||
type GormWalletRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
|
||||
|
||||
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
|
||||
return &GormWalletRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
|
||||
err := r.CreateEntity(ctx, &wallet)
|
||||
return wallet, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.SmartGetByID(ctx, id, &wallet)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Wallet{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Wallet{}, err
|
||||
}
|
||||
return wallet, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
|
||||
return r.UpdateEntity(ctx, &wallet)
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.Wallet{})
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.Wallet{})
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
return r.GetDB(ctx).Create(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&wallets).Error
|
||||
return wallets, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
return r.GetDB(ctx).Save(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
query := r.GetDB(ctx).Model(&entities.Wallet{})
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
return wallets, query.Find(&wallets).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
|
||||
var total decimal.Decimal
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ================ 接口要求的方法 ================
|
||||
|
||||
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// UpdateBalanceWithVersion 乐观锁自动重试,最大重试maxRetry次
|
||||
func (r *GormWalletRepository) UpdateBalanceWithVersion(ctx context.Context, walletID string, amount decimal.Decimal, operation string) (bool, error) {
|
||||
maxRetry := 10
|
||||
for i := 0; i < maxRetry; i++ {
|
||||
// 每次重试都重新获取最新的钱包信息
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("id = ?", walletID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("获取钱包信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 重新计算新余额
|
||||
var newBalance decimal.Decimal
|
||||
switch operation {
|
||||
case "add":
|
||||
newBalance = wallet.Balance.Add(amount)
|
||||
case "subtract":
|
||||
newBalance = wallet.Balance.Sub(amount)
|
||||
default:
|
||||
return false, fmt.Errorf("不支持的操作类型: %s", operation)
|
||||
}
|
||||
|
||||
// 乐观锁更新
|
||||
result := r.GetDB(ctx).Model(&entities.Wallet{}).
|
||||
Where("id = ? AND version = ?", walletID, wallet.Version).
|
||||
Updates(map[string]interface{}{
|
||||
"balance": newBalance.String(),
|
||||
"version": wallet.Version + 1,
|
||||
})
|
||||
|
||||
if result.Error != nil {
|
||||
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 1 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 乐观锁冲突,继续重试
|
||||
// 注意:这里可以添加日志记录,但需要确保logger可用
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
|
||||
}
|
||||
|
||||
// UpdateBalanceByUserID 乐观锁更新(通过用户ID直接更新,使用原生SQL)
|
||||
func (r *GormWalletRepository) UpdateBalanceByUserID(ctx context.Context, userID string, amount decimal.Decimal, operation string) (bool, error) {
|
||||
maxRetry := 20 // 增加重试次数
|
||||
baseDelay := 1 // 基础延迟毫秒
|
||||
|
||||
for i := 0; i < maxRetry; i++ {
|
||||
// 每次重试都重新获取最新的钱包信息
|
||||
var wallet entities.Wallet
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("获取钱包信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 重新计算新余额
|
||||
var newBalance decimal.Decimal
|
||||
switch operation {
|
||||
case "add":
|
||||
newBalance = wallet.Balance.Add(amount)
|
||||
case "subtract":
|
||||
newBalance = wallet.Balance.Sub(amount)
|
||||
default:
|
||||
return false, fmt.Errorf("不支持的操作类型: %s", operation)
|
||||
}
|
||||
|
||||
// 使用原生SQL进行乐观锁更新
|
||||
newVersion := wallet.Version + 1
|
||||
result := r.GetDB(ctx).Exec(`
|
||||
UPDATE wallets
|
||||
SET balance = ?, version = ?, updated_at = NOW()
|
||||
WHERE user_id = ? AND version = ?
|
||||
`, newBalance.String(), newVersion, userID, wallet.Version)
|
||||
|
||||
if result.Error != nil {
|
||||
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 1 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// 乐观锁冲突,添加指数退避延迟
|
||||
if i < maxRetry-1 {
|
||||
delay := baseDelay * (1 << i) // 指数退避: 1ms, 2ms, 4ms, 8ms...
|
||||
if delay > 50 {
|
||||
delay = 50 // 最大延迟50ms
|
||||
}
|
||||
time.Sleep(time.Duration(delay) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 总钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 激活钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 用户钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户激活钱包数
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 用户今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
@@ -0,0 +1,643 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WalletTransactionWithProduct 包含产品名称的钱包交易记录
|
||||
type WalletTransactionWithProduct struct {
|
||||
entities.WalletTransaction
|
||||
ProductName string `json:"product_name" gorm:"column:product_name"`
|
||||
}
|
||||
|
||||
const (
|
||||
WalletTransactionsTable = "wallet_transactions"
|
||||
)
|
||||
|
||||
type GormWalletTransactionRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WalletTransactionRepository = (*GormWalletTransactionRepository)(nil)
|
||||
|
||||
func NewGormWalletTransactionRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletTransactionRepository {
|
||||
return &GormWalletTransactionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletTransactionsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Create(ctx context.Context, transaction entities.WalletTransaction) (entities.WalletTransaction, error) {
|
||||
err := r.CreateEntity(ctx, &transaction)
|
||||
return transaction, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Update(ctx context.Context, transaction entities.WalletTransaction) error {
|
||||
return r.UpdateEntity(ctx, &transaction)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByID(ctx context.Context, id string) (entities.WalletTransaction, error) {
|
||||
var transaction entities.WalletTransaction
|
||||
err := r.SmartGetByID(ctx, id, &transaction)
|
||||
return transaction, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.WalletTransaction, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
options := database.CacheListOptions{
|
||||
Where: "user_id = ?",
|
||||
Args: []interface{}{userID},
|
||||
Order: "created_at DESC",
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
err := r.ListWithCache(ctx, &transactions, 10*time.Minute, options)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByApiCallID(ctx context.Context, apiCallID string) (*entities.WalletTransaction, error) {
|
||||
var transaction entities.WalletTransaction
|
||||
err := r.FindOne(ctx, &transaction, "api_call_id = ?", apiCallID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &transaction, nil
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
|
||||
var transactions []*entities.WalletTransaction
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 关键词筛选(支持transaction_id和product_name)
|
||||
if keyword, ok := filters["keyword"].(string); ok && keyword != "" {
|
||||
whereCondition += " AND (transaction_id LIKE ? OR product_id IN (SELECT id FROM product WHERE name LIKE ?))"
|
||||
whereArgs = append(whereArgs, "%"+keyword+"%", "%"+keyword+"%")
|
||||
}
|
||||
|
||||
// API调用ID筛选
|
||||
if apiCallId, ok := filters["api_call_id"].(string); ok && apiCallId != "" {
|
||||
whereCondition += " AND api_call_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+apiCallId+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 使用基础仓储的分页查询方法
|
||||
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, total, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ?", userId)
|
||||
}
|
||||
|
||||
// CountByUserIdAndDateRange 按用户ID和日期范围统计钱包交易次数
|
||||
func (r *GormWalletTransactionRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
|
||||
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserId 获取用户总消费金额
|
||||
func (r *GormWalletTransactionRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ?", userId).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总消费金额
|
||||
func (r *GormWalletTransactionRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Where("user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate).
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetDailyStatsByUserId 获取用户每日消费统计
|
||||
func (r *GormWalletTransactionRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE user_id = $1
|
||||
AND DATE(created_at) >= $2
|
||||
AND DATE(created_at) <= $3
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetMonthlyStatsByUserId 获取用户每月消费统计
|
||||
func (r *GormWalletTransactionRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
// 构建SQL查询 - 使用PostgreSQL语法,使用具体的日期范围
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE user_id = $1
|
||||
AND created_at >= $2
|
||||
AND created_at <= $3
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// 实现interfaces.Repository接口的其他方法
|
||||
func (r *GormWalletTransactionRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.WalletTransaction, error) {
|
||||
var transactions []entities.WalletTransaction
|
||||
err := r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
return r.CountWithOptions(ctx, &entities.WalletTransaction{}, options)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) CreateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
|
||||
return r.CreateBatchEntity(ctx, &transactions)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.WalletTransaction, error) {
|
||||
var transactions []entities.WalletTransaction
|
||||
err := r.GetEntitiesByIDs(ctx, ids, &transactions)
|
||||
return transactions, err
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) UpdateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
|
||||
return r.UpdateBatchEntity(ctx, &transactions)
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.DeleteBatchEntity(ctx, ids, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) WithTx(tx interface{}) interfaces.Repository[entities.WalletTransaction] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletTransactionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletTransactionsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.SoftDeleteEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.RestoreEntity(ctx, id, &entities.WalletTransaction{})
|
||||
}
|
||||
|
||||
func (r *GormWalletTransactionRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
|
||||
var transactionsWithProduct []*WalletTransactionWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "wt.user_id = ?"
|
||||
whereArgs := []interface{}{userId}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND wt.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND wt.amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND wt.amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
query := r.GetDB(ctx).Table("wallet_transactions wt").
|
||||
Select("wt.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON wt.product_id = p.id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("wt." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("wt.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&transactionsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.WalletTransaction并构建产品名称映射
|
||||
var transactions []*entities.WalletTransaction
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, t := range transactionsWithProduct {
|
||||
transaction := t.WalletTransaction
|
||||
transactions = append(transactions, &transaction)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if t.ProductName != "" {
|
||||
productNameMap[transaction.ProductID] = t.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, transactions, total, nil
|
||||
}
|
||||
|
||||
// ListWithFiltersAndProductName 管理端:根据条件筛选所有钱包交易记录(包含产品名称)
|
||||
func (r *GormWalletTransactionRepository) ListWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
|
||||
var transactionsWithProduct []*WalletTransactionWithProduct
|
||||
var total int64
|
||||
|
||||
// 构建基础查询条件
|
||||
whereCondition := "1=1"
|
||||
whereArgs := []interface{}{}
|
||||
|
||||
// 应用筛选条件
|
||||
if filters != nil {
|
||||
// 用户ID筛选(支持单个和多个)
|
||||
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
|
||||
// 多个用户ID,逗号分隔
|
||||
userIdsList := strings.Split(userIds, ",")
|
||||
whereCondition += " AND wt.user_id IN ?"
|
||||
whereArgs = append(whereArgs, userIdsList)
|
||||
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
|
||||
// 单个用户ID
|
||||
whereCondition += " AND wt.user_id = ?"
|
||||
whereArgs = append(whereArgs, userId)
|
||||
}
|
||||
|
||||
// 产品ID筛选(支持多个)
|
||||
if productIds, ok := filters["product_ids"].(string); ok && productIds != "" {
|
||||
// 多个产品ID,逗号分隔
|
||||
productIdsList := strings.Split(productIds, ",")
|
||||
whereCondition += " AND wt.product_id IN ?"
|
||||
whereArgs = append(whereArgs, productIdsList)
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at >= ?"
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereCondition += " AND wt.created_at <= ?"
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereCondition += " AND wt.transaction_id LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereCondition += " AND p.name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 企业名称筛选
|
||||
if companyName, ok := filters["company_name"].(string); ok && companyName != "" {
|
||||
whereCondition += " AND ei.company_name LIKE ?"
|
||||
whereArgs = append(whereArgs, "%"+companyName+"%")
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereCondition += " AND wt.amount >= ?"
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereCondition += " AND wt.amount <= ?"
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
}
|
||||
|
||||
// 构建JOIN查询
|
||||
// 需要JOIN product表获取产品名称,JOIN users和enterprise_infos表获取企业名称
|
||||
query := r.GetDB(ctx).Table("wallet_transactions wt").
|
||||
Select("wt.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON wt.product_id = p.id").
|
||||
Joins("LEFT JOIN users u ON wt.user_id = u.id").
|
||||
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id").
|
||||
Where(whereCondition, whereArgs...)
|
||||
|
||||
// 获取总数
|
||||
var count int64
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
total = count
|
||||
|
||||
// 应用排序和分页
|
||||
if options.Sort != "" {
|
||||
query = query.Order("wt." + options.Sort + " " + options.Order)
|
||||
} else {
|
||||
query = query.Order("wt.created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
// 执行查询
|
||||
err = query.Find(&transactionsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为entities.WalletTransaction并构建产品名称映射
|
||||
var transactions []*entities.WalletTransaction
|
||||
productNameMap := make(map[string]string)
|
||||
|
||||
for _, t := range transactionsWithProduct {
|
||||
transaction := t.WalletTransaction
|
||||
transactions = append(transactions, &transaction)
|
||||
// 构建产品ID到产品名称的映射
|
||||
if t.ProductName != "" {
|
||||
productNameMap[transaction.ProductID] = t.ProductName
|
||||
}
|
||||
}
|
||||
|
||||
return productNameMap, transactions, total, nil
|
||||
}
|
||||
|
||||
// ExportWithFiltersAndProductName 导出钱包交易记录(包含产品名称和企业信息)
|
||||
func (r *GormWalletTransactionRepository) ExportWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}) ([]*entities.WalletTransaction, error) {
|
||||
var transactionsWithProduct []WalletTransactionWithProduct
|
||||
|
||||
// 构建查询
|
||||
query := r.GetDB(ctx).Table("wallet_transactions wt").
|
||||
Select("wt.*, p.name as product_name").
|
||||
Joins("LEFT JOIN product p ON wt.product_id = p.id")
|
||||
|
||||
// 构建WHERE条件
|
||||
var whereConditions []string
|
||||
var whereArgs []interface{}
|
||||
|
||||
// 用户ID筛选
|
||||
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
|
||||
whereConditions = append(whereConditions, "wt.user_id IN (?)")
|
||||
whereArgs = append(whereArgs, strings.Split(userIds, ","))
|
||||
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
|
||||
whereConditions = append(whereConditions, "wt.user_id = ?")
|
||||
whereArgs = append(whereArgs, userId)
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime, ok := filters["start_time"].(time.Time); ok {
|
||||
whereConditions = append(whereConditions, "wt.created_at >= ?")
|
||||
whereArgs = append(whereArgs, startTime)
|
||||
}
|
||||
if endTime, ok := filters["end_time"].(time.Time); ok {
|
||||
whereConditions = append(whereConditions, "wt.created_at <= ?")
|
||||
whereArgs = append(whereArgs, endTime)
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
|
||||
whereConditions = append(whereConditions, "wt.transaction_id LIKE ?")
|
||||
whereArgs = append(whereArgs, "%"+transactionId+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName, ok := filters["product_name"].(string); ok && productName != "" {
|
||||
whereConditions = append(whereConditions, "p.name LIKE ?")
|
||||
whereArgs = append(whereArgs, "%"+productName+"%")
|
||||
}
|
||||
|
||||
// 产品ID列表筛选
|
||||
if productIds, ok := filters["product_ids"].(string); ok && productIds != "" {
|
||||
whereConditions = append(whereConditions, "wt.product_id IN (?)")
|
||||
whereArgs = append(whereArgs, strings.Split(productIds, ","))
|
||||
}
|
||||
|
||||
// 金额范围筛选
|
||||
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
|
||||
whereConditions = append(whereConditions, "wt.amount >= ?")
|
||||
whereArgs = append(whereArgs, minAmount)
|
||||
}
|
||||
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
|
||||
whereConditions = append(whereConditions, "wt.amount <= ?")
|
||||
whereArgs = append(whereArgs, maxAmount)
|
||||
}
|
||||
|
||||
// 应用WHERE条件
|
||||
if len(whereConditions) > 0 {
|
||||
query = query.Where(strings.Join(whereConditions, " AND "), whereArgs...)
|
||||
}
|
||||
|
||||
// 排序
|
||||
query = query.Order("wt.created_at DESC")
|
||||
|
||||
// 执行查询
|
||||
err := query.Find(&transactionsWithProduct).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为entities.WalletTransaction
|
||||
var transactions []*entities.WalletTransaction
|
||||
for _, t := range transactionsWithProduct {
|
||||
transaction := t.WalletTransaction
|
||||
transactions = append(transactions, &transaction)
|
||||
}
|
||||
|
||||
return transactions, nil
|
||||
}
|
||||
|
||||
// GetSystemTotalAmount 获取系统总消费金额
|
||||
func (r *GormWalletTransactionRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemAmountByDateRange 获取系统指定时间范围内的消费金额
|
||||
// endDate 应该是结束日期当天的次日00:00:00(日统计)或下个月1号00:00:00(月统计),使用 < 而不是 <=
|
||||
func (r *GormWalletTransactionRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
|
||||
var total float64
|
||||
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
|
||||
Where("created_at >= ? AND created_at < ?", startDate, endDate).
|
||||
Select("COALESCE(SUM(amount), 0)").
|
||||
Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetSystemDailyStats 获取系统每日消费统计
|
||||
func (r *GormWalletTransactionRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE DATE(created_at) >= ?
|
||||
AND DATE(created_at) <= ?
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyStats 获取系统每月消费统计
|
||||
func (r *GormWalletTransactionRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COALESCE(SUM(amount), 0) as amount
|
||||
FROM wallet_transactions
|
||||
WHERE created_at >= ?
|
||||
AND created_at < ?
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
WechatOrdersTable = "typay_orders"
|
||||
)
|
||||
|
||||
type GormWechatOrderRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ domain_finance_repo.WechatOrderRepository = (*GormWechatOrderRepository)(nil)
|
||||
|
||||
func NewGormWechatOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WechatOrderRepository {
|
||||
return &GormWechatOrderRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WechatOrdersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Create(ctx context.Context, order entities.WechatOrder) (entities.WechatOrder, error) {
|
||||
err := r.CreateEntity(ctx, &order)
|
||||
return order, err
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByID(ctx context.Context, id string) (entities.WechatOrder, error) {
|
||||
var order entities.WechatOrder
|
||||
err := r.SmartGetByID(ctx, id, &order)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.WechatOrder{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.WechatOrder{}, err
|
||||
}
|
||||
return order, nil
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.WechatOrder, error) {
|
||||
var order entities.WechatOrder
|
||||
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.WechatOrder, error) {
|
||||
var order entities.WechatOrder
|
||||
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &order, nil
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.WechatOrder, error) {
|
||||
var orders []entities.WechatOrder
|
||||
// 需要通过充值记录关联查询,这里简化处理
|
||||
err := r.GetDB(ctx).Find(&orders).Error
|
||||
return orders, err
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Update(ctx context.Context, order entities.WechatOrder) error {
|
||||
return r.UpdateEntity(ctx, &order)
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.WechatOrderStatus) error {
|
||||
return r.GetDB(ctx).Model(&entities.WechatOrder{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.WechatOrder{})
|
||||
}
|
||||
|
||||
func (r *GormWechatOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.WechatOrder{})
|
||||
}
|
||||
@@ -0,0 +1,342 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
"hyapi-server/internal/domains/finance/repositories"
|
||||
"hyapi-server/internal/domains/finance/value_objects"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormInvoiceApplicationRepository 发票申请仓储的GORM实现
|
||||
type GormInvoiceApplicationRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormInvoiceApplicationRepository 创建发票申请仓储
|
||||
func NewGormInvoiceApplicationRepository(db *gorm.DB) repositories.InvoiceApplicationRepository {
|
||||
return &GormInvoiceApplicationRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建发票申请
|
||||
func (r *GormInvoiceApplicationRepository) Create(ctx context.Context, application *entities.InvoiceApplication) error {
|
||||
return r.db.WithContext(ctx).Create(application).Error
|
||||
}
|
||||
|
||||
// Update 更新发票申请
|
||||
func (r *GormInvoiceApplicationRepository) Update(ctx context.Context, application *entities.InvoiceApplication) error {
|
||||
return r.db.WithContext(ctx).Save(application).Error
|
||||
}
|
||||
|
||||
// Save 保存发票申请
|
||||
func (r *GormInvoiceApplicationRepository) Save(ctx context.Context, application *entities.InvoiceApplication) error {
|
||||
return r.db.WithContext(ctx).Save(application).Error
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByID(ctx context.Context, id string) (*entities.InvoiceApplication, error) {
|
||||
var application entities.InvoiceApplication
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&application).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &application, nil
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找发票申请列表
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
// 获取总数
|
||||
err := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID).Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = r.db.WithContext(ctx).Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindPendingApplications 查找待处理的发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindPendingApplications(ctx context.Context, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
// 获取总数
|
||||
err := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).
|
||||
Where("status = ?", entities.ApplicationStatusPending).
|
||||
Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = r.db.WithContext(ctx).
|
||||
Where("status = ?", entities.ApplicationStatusPending).
|
||||
Order("created_at ASC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByUserIDAndStatus 根据用户ID和状态查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserIDAndStatus(ctx context.Context, userID string, status entities.ApplicationStatus, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID)
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByUserIDAndStatusWithTimeRange 根据用户ID、状态和时间范围查找发票申请列表
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserIDAndStatusWithTimeRange(ctx context.Context, userID string, status entities.ApplicationStatus, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID)
|
||||
|
||||
// 添加状态筛选
|
||||
if status != "" {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
|
||||
// 添加时间范围筛选
|
||||
if startTime != nil {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime != nil {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByStatus(ctx context.Context, status entities.ApplicationStatus) ([]*entities.InvoiceApplication, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&applications).Error
|
||||
return applications, err
|
||||
}
|
||||
|
||||
// GetUserInvoiceInfo 获取用户发票信息
|
||||
|
||||
|
||||
|
||||
|
||||
// GetUserTotalInvoicedAmount 获取用户已开票总金额
|
||||
func (r *GormInvoiceApplicationRepository) GetUserTotalInvoicedAmount(ctx context.Context, userID string) (string, error) {
|
||||
var total string
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&entities.InvoiceApplication{}).
|
||||
Select("COALESCE(SUM(CAST(amount AS DECIMAL(10,2))), '0')").
|
||||
Where("user_id = ? AND status = ?", userID, entities.ApplicationStatusCompleted).
|
||||
Scan(&total).Error
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetUserTotalAppliedAmount 获取用户申请开票总金额
|
||||
func (r *GormInvoiceApplicationRepository) GetUserTotalAppliedAmount(ctx context.Context, userID string) (string, error) {
|
||||
var total string
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&entities.InvoiceApplication{}).
|
||||
Select("COALESCE(SUM(CAST(amount AS DECIMAL(10,2))), '0')").
|
||||
Where("user_id = ?", userID).
|
||||
Scan(&total).Error
|
||||
|
||||
return total, err
|
||||
}
|
||||
|
||||
// FindByUserIDAndInvoiceType 根据用户ID和发票类型查找申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByUserIDAndInvoiceType(ctx context.Context, userID string, invoiceType value_objects.InvoiceType, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ? AND invoice_type = ?", userID, invoiceType)
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByDateRange 根据日期范围查找申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByDateRange(ctx context.Context, startDate, endDate string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{})
|
||||
if startDate != "" {
|
||||
query = query.Where("DATE(created_at) >= ?", startDate)
|
||||
}
|
||||
if endDate != "" {
|
||||
query = query.Where("DATE(created_at) <= ?", endDate)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// SearchApplications 搜索发票申请
|
||||
func (r *GormInvoiceApplicationRepository) SearchApplications(ctx context.Context, keyword string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).
|
||||
Where("company_name LIKE ? OR email LIKE ? OR tax_number LIKE ?",
|
||||
fmt.Sprintf("%%%s%%", keyword),
|
||||
fmt.Sprintf("%%%s%%", keyword),
|
||||
fmt.Sprintf("%%%s%%", keyword))
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindByStatusWithTimeRange 根据状态和时间范围查找发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindByStatusWithTimeRange(ctx context.Context, status entities.ApplicationStatus, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("status = ?", status)
|
||||
|
||||
// 添加时间范围筛选
|
||||
if startTime != nil {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime != nil {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
|
||||
// FindAllWithTimeRange 根据时间范围查找所有发票申请
|
||||
func (r *GormInvoiceApplicationRepository) FindAllWithTimeRange(ctx context.Context, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
|
||||
var applications []*entities.InvoiceApplication
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{})
|
||||
|
||||
// 添加时间范围筛选
|
||||
if startTime != nil {
|
||||
query = query.Where("created_at >= ?", startTime)
|
||||
}
|
||||
if endTime != nil {
|
||||
query = query.Where("created_at <= ?", endTime)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
err := query.Count(&total).Error
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取分页数据
|
||||
offset := (page - 1) * pageSize
|
||||
err = query.Order("created_at DESC").
|
||||
Offset(offset).
|
||||
Limit(pageSize).
|
||||
Find(&applications).Error
|
||||
|
||||
return applications, total, err
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"hyapi-server/internal/domains/finance/entities"
|
||||
"hyapi-server/internal/domains/finance/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormUserInvoiceInfoRepository 用户开票信息仓储的GORM实现
|
||||
type GormUserInvoiceInfoRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormUserInvoiceInfoRepository 创建用户开票信息仓储
|
||||
func NewGormUserInvoiceInfoRepository(db *gorm.DB) repositories.UserInvoiceInfoRepository {
|
||||
return &GormUserInvoiceInfoRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建用户开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) Create(ctx context.Context, info *entities.UserInvoiceInfo) error {
|
||||
return r.db.WithContext(ctx).Create(info).Error
|
||||
}
|
||||
|
||||
// Update 更新用户开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) Update(ctx context.Context, info *entities.UserInvoiceInfo) error {
|
||||
return r.db.WithContext(ctx).Save(info).Error
|
||||
}
|
||||
|
||||
// Save 保存用户开票信息(创建或更新)
|
||||
func (r *GormUserInvoiceInfoRepository) Save(ctx context.Context, info *entities.UserInvoiceInfo) error {
|
||||
return r.db.WithContext(ctx).Save(info).Error
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) FindByUserID(ctx context.Context, userID string) (*entities.UserInvoiceInfo, error) {
|
||||
var info entities.UserInvoiceInfo
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&info).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) FindByID(ctx context.Context, id string) (*entities.UserInvoiceInfo, error) {
|
||||
var info entities.UserInvoiceInfo
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&info).Error
|
||||
if err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// Delete 删除用户开票信息
|
||||
func (r *GormUserInvoiceInfoRepository) Delete(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&entities.UserInvoiceInfo{}).Error
|
||||
}
|
||||
|
||||
// Exists 检查用户开票信息是否存在
|
||||
func (r *GormUserInvoiceInfoRepository) Exists(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserInvoiceInfo{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,189 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ComponentReportDownloadsTable = "component_report_downloads"
|
||||
)
|
||||
|
||||
type GormComponentReportRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ComponentReportRepository = (*GormComponentReportRepository)(nil)
|
||||
|
||||
func NewGormComponentReportRepository(db *gorm.DB, logger *zap.Logger) repositories.ComponentReportRepository {
|
||||
return &GormComponentReportRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ComponentReportDownloadsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) Create(ctx context.Context, download *entities.ComponentReportDownload) error {
|
||||
return r.CreateEntity(ctx, download)
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) UpdateDownload(ctx context.Context, download *entities.ComponentReportDownload) error {
|
||||
return r.UpdateEntity(ctx, download)
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetDownloadByID(ctx context.Context, id string) (*entities.ComponentReportDownload, error) {
|
||||
var download entities.ComponentReportDownload
|
||||
err := r.SmartGetByID(ctx, id, &download)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &download, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetUserDownloads(ctx context.Context, userID string, productID *string) ([]*entities.ComponentReportDownload, error) {
|
||||
var downloads []entities.ComponentReportDownload
|
||||
query := r.GetDB(ctx).Where("user_id = ?", userID)
|
||||
|
||||
if productID != nil && *productID != "" {
|
||||
query = query.Where("product_id = ?", *productID)
|
||||
}
|
||||
|
||||
err := query.Order("created_at DESC").Find(&downloads).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ComponentReportDownload, len(downloads))
|
||||
for i := range downloads {
|
||||
result[i] = &downloads[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) HasUserDownloaded(ctx context.Context, userID string, productCode string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).
|
||||
Where("user_id = ? AND product_code = ?", userID, productCode).
|
||||
Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetUserDownloadedProductCodes(ctx context.Context, userID string) ([]string, error) {
|
||||
var downloads []entities.ComponentReportDownload
|
||||
err := r.GetDB(ctx).
|
||||
Select("DISTINCT sub_product_codes").
|
||||
Where("user_id = ?", userID).
|
||||
Find(&downloads).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
codesMap := make(map[string]bool)
|
||||
for _, download := range downloads {
|
||||
if download.SubProductCodes != "" {
|
||||
var codes []string
|
||||
if err := json.Unmarshal([]byte(download.SubProductCodes), &codes); err == nil {
|
||||
for _, code := range codes {
|
||||
codesMap[code] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
// 也添加主产品编号
|
||||
if download.ProductCode != "" {
|
||||
codesMap[download.ProductCode] = true
|
||||
}
|
||||
}
|
||||
|
||||
codes := make([]string, 0, len(codesMap))
|
||||
for code := range codesMap {
|
||||
codes = append(codes, code)
|
||||
}
|
||||
return codes, nil
|
||||
}
|
||||
|
||||
func (r *GormComponentReportRepository) GetDownloadByPaymentOrderID(ctx context.Context, orderID string) (*entities.ComponentReportDownload, error) {
|
||||
var download entities.ComponentReportDownload
|
||||
err := r.GetDB(ctx).Where("order_id = ?", orderID).First(&download).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &download, nil
|
||||
}
|
||||
|
||||
// GetActiveDownload 获取用户有效的下载记录
|
||||
func (r *GormComponentReportRepository) GetActiveDownload(ctx context.Context, userID, productID string) (*entities.ComponentReportDownload, error) {
|
||||
var download entities.ComponentReportDownload
|
||||
|
||||
// 先尝试查找有支付订单号的下载记录(已支付)
|
||||
err := r.GetDB(ctx).
|
||||
Where("user_id = ? AND product_id = ? AND order_number IS NOT NULL AND deleted_at IS NULL", userID, productID).
|
||||
Order("created_at DESC").
|
||||
First(&download).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
// 如果没有找到有支付订单号的记录,尝试查找任何有效的下载记录
|
||||
err = r.GetDB(ctx).
|
||||
Where("user_id = ? AND product_id = ? AND deleted_at IS NULL", userID, productID).
|
||||
Order("created_at DESC").
|
||||
First(&download).Error
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 如果找到了下载记录,检查关联的购买订单状态
|
||||
if download.OrderID != nil {
|
||||
// 这里需要查询购买订单状态,但当前仓库没有依赖购买订单仓库
|
||||
// 所以只检查是否有过期时间设置,如果有则认为已支付
|
||||
if download.ExpiresAt == nil {
|
||||
return nil, nil // 没有过期时间,表示未支付
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否已过期
|
||||
if download.IsExpired() {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &download, nil
|
||||
}
|
||||
|
||||
// UpdateFilePath 更新下载记录文件路径
|
||||
func (r *GormComponentReportRepository) UpdateFilePath(ctx context.Context, downloadID, filePath string) error {
|
||||
return r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).Where("id = ?", downloadID).Update("file_path", filePath).Error
|
||||
}
|
||||
|
||||
// IncrementDownloadCount 增加下载次数
|
||||
func (r *GormComponentReportRepository) IncrementDownloadCount(ctx context.Context, downloadID string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).
|
||||
Where("id = ?", downloadID).
|
||||
Updates(map[string]interface{}{
|
||||
"download_count": gorm.Expr("download_count + 1"),
|
||||
"last_download_at": &now,
|
||||
}).Error
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductApiConfigsTable = "product_api_configs"
|
||||
ProductApiConfigCacheTTL = 30 * time.Minute
|
||||
)
|
||||
|
||||
type GormProductApiConfigRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ProductApiConfigRepository = (*GormProductApiConfigRepository)(nil)
|
||||
|
||||
func NewGormProductApiConfigRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductApiConfigRepository {
|
||||
return &GormProductApiConfigRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductApiConfigsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Create(ctx context.Context, config entities.ProductApiConfig) error {
|
||||
return r.CreateEntity(ctx, &config)
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Update(ctx context.Context, config entities.ProductApiConfig) error {
|
||||
return r.UpdateEntity(ctx, &config)
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductApiConfig{})
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) GetByID(ctx context.Context, id string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.SmartGetByID(ctx, id, &config)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.SmartGetByField(ctx, &config, "product_id", productID, ProductApiConfigCacheTTL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductCode(ctx context.Context, productCode string) (*entities.ProductApiConfig, error) {
|
||||
var config entities.ProductApiConfig
|
||||
err := r.GetDB(ctx).Joins("JOIN products ON products.id = product_api_configs.product_id").
|
||||
Where("products.code = ?", productCode).
|
||||
First(&config).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductApiConfig, error) {
|
||||
var configs []*entities.ProductApiConfig
|
||||
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&configs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func (r *GormProductApiConfigRepository) ExistsByProductID(ctx context.Context, productID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductApiConfig{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
@@ -0,0 +1,281 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/domains/product/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductCategoriesTable = "product_categories"
|
||||
)
|
||||
|
||||
type GormProductCategoryRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductCategory{})
|
||||
}
|
||||
|
||||
var _ repositories.ProductCategoryRepository = (*GormProductCategoryRepository)(nil)
|
||||
|
||||
func NewGormProductCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductCategoryRepository {
|
||||
return &GormProductCategoryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductCategoriesTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) Create(ctx context.Context, entity entities.ProductCategory) (entities.ProductCategory, error) {
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) GetByID(ctx context.Context, id string) (entities.ProductCategory, error) {
|
||||
var entity entities.ProductCategory
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.ProductCategory{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.ProductCategory{}, err
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
func (r *GormProductCategoryRepository) Update(ctx context.Context, entity entities.ProductCategory) error {
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// FindByCode 根据编号查找产品分类
|
||||
func (r *GormProductCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductCategory, error) {
|
||||
var entity entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindVisible 查找可见分类
|
||||
func (r *GormProductCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindEnabled 查找启用分类
|
||||
func (r *GormProductCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("is_enabled = ?", true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListCategories 获取分类列表
|
||||
func (r *GormProductCategoryRepository) ListCategories(ctx context.Context, query *queries.ListCategoriesQuery) ([]*entities.ProductCategory, int64, error) {
|
||||
var categories []entities.ProductCategory
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.IsEnabled != nil {
|
||||
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
|
||||
}
|
||||
if query.IsVisible != nil {
|
||||
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
// 默认按排序字段和创建时间排序
|
||||
dbQuery = dbQuery.Order("sort ASC, created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
if err := dbQuery.Find(&categories).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// CountEnabled 统计启用分类数量
|
||||
func (r *GormProductCategoryRepository) CountEnabled(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountVisible 统计可见分类数量
|
||||
func (r *GormProductCategoryRepository) CountVisible(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// 基础Repository接口方法
|
||||
|
||||
// Count 返回分类总数
|
||||
func (r *GormProductCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取分类
|
||||
func (r *GormProductCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建分类
|
||||
func (r *GormProductCategoryRepository) CreateBatch(ctx context.Context, categories []entities.ProductCategory) error {
|
||||
return r.GetDB(ctx).Create(&categories).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新分类
|
||||
func (r *GormProductCategoryRepository) UpdateBatch(ctx context.Context, categories []entities.ProductCategory) error {
|
||||
return r.GetDB(ctx).Save(&categories).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除分类
|
||||
func (r *GormProductCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取分类列表(基础方法)
|
||||
func (r *GormProductCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.ProductCategory, error) {
|
||||
var categories []entities.ProductCategory
|
||||
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按排序字段和创建时间倒序
|
||||
query = query.Order("sort ASC, created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&categories).Error
|
||||
return categories, err
|
||||
}
|
||||
|
||||
// Exists 检查分类是否存在
|
||||
func (r *GormProductCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除分类
|
||||
func (r *GormProductCategoryRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的分类
|
||||
func (r *GormProductCategoryRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.ProductCategory{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormProductCategoryRepository) WithTx(tx interface{}) interfaces.Repository[entities.ProductCategory] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormProductCategoryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductCategoriesTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductDocumentationsTable = "product_documentations"
|
||||
)
|
||||
|
||||
type GormProductDocumentationRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductDocumentationRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductDocumentation{})
|
||||
}
|
||||
|
||||
var _ repositories.ProductDocumentationRepository = (*GormProductDocumentationRepository)(nil)
|
||||
|
||||
func NewGormProductDocumentationRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductDocumentationRepository {
|
||||
return &GormProductDocumentationRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductDocumentationsTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建文档
|
||||
func (r *GormProductDocumentationRepository) Create(ctx context.Context, documentation *entities.ProductDocumentation) error {
|
||||
return r.CreateEntity(ctx, documentation)
|
||||
}
|
||||
|
||||
// Update 更新文档
|
||||
func (r *GormProductDocumentationRepository) Update(ctx context.Context, documentation *entities.ProductDocumentation) error {
|
||||
return r.UpdateEntity(ctx, documentation)
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找文档
|
||||
func (r *GormProductDocumentationRepository) FindByID(ctx context.Context, id string) (*entities.ProductDocumentation, error) {
|
||||
var entity entities.ProductDocumentation
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByProductID 根据产品ID查找文档
|
||||
func (r *GormProductDocumentationRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductDocumentation, error) {
|
||||
var entity entities.ProductDocumentation
|
||||
err := r.GetDB(ctx).Where("product_id = ?", productID).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByProductIDs 根据产品ID列表批量查找文档
|
||||
func (r *GormProductDocumentationRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductDocumentation, error) {
|
||||
var documentations []entities.ProductDocumentation
|
||||
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&documentations).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductDocumentation, len(documentations))
|
||||
for i := range documentations {
|
||||
result[i] = &documentations[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新文档
|
||||
func (r *GormProductDocumentationRepository) UpdateBatch(ctx context.Context, documentations []*entities.ProductDocumentation) error {
|
||||
if len(documentations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用事务进行批量更新
|
||||
return r.GetDB(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for _, doc := range documentations {
|
||||
if err := tx.Save(doc).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CountByProductID 统计指定产品的文档数量
|
||||
func (r *GormProductDocumentationRepository) CountByProductID(ctx context.Context, productID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ProductDocumentation{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,521 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/domains/product/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductsTable = "products"
|
||||
)
|
||||
|
||||
type GormProductRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Product{})
|
||||
}
|
||||
|
||||
var _ repositories.ProductRepository = (*GormProductRepository)(nil)
|
||||
|
||||
func NewGormProductRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductRepository {
|
||||
return &GormProductRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) Create(ctx context.Context, entity entities.Product) (entities.Product, error) {
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) GetByID(ctx context.Context, id string) (entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Product{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Product{}, err
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
func (r *GormProductRepository) Update(ctx context.Context, entity entities.Product) error {
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// 其它方法同理迁移,全部用r.GetDB(ctx)
|
||||
|
||||
// FindByCode 根据编号查找产品
|
||||
func (r *GormProductRepository) FindByCode(ctx context.Context, code string) (*entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.SmartGetByField(ctx, &entity, "code", code) // 自动缓存
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByOldID 根据旧ID查找产品
|
||||
func (r *GormProductRepository) FindByOldID(ctx context.Context, oldID string) (*entities.Product, error) {
|
||||
var entity entities.Product
|
||||
err := r.GetDB(ctx).Where("old_id = ?", oldID).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByCategoryID 根据分类ID查找产品
|
||||
func (r *GormProductRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Preload("Category").Where("category_id = ?", categoryID).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindVisible 查找可见产品
|
||||
func (r *GormProductRepository) FindVisible(ctx context.Context) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Preload("Category").Where("is_visible = ? AND is_enabled = ?", true, true).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindEnabled 查找启用产品
|
||||
func (r *GormProductRepository) FindEnabled(ctx context.Context) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Preload("Category").Where("is_enabled = ?", true).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ListProducts 获取产品列表
|
||||
func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, int64, error) {
|
||||
var productEntities []entities.Product
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Keyword != "" {
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
|
||||
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
if query.CategoryID != "" {
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
if query.MinPrice != nil {
|
||||
dbQuery = dbQuery.Where("price >= ?", *query.MinPrice)
|
||||
}
|
||||
if query.MaxPrice != nil {
|
||||
dbQuery = dbQuery.Where("price <= ?", *query.MaxPrice)
|
||||
}
|
||||
if query.IsEnabled != nil {
|
||||
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
|
||||
}
|
||||
if query.IsVisible != nil {
|
||||
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
|
||||
}
|
||||
if query.IsPackage != nil {
|
||||
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载分类信息并获取数据
|
||||
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// ListProductsWithSubscriptionStatus 获取产品列表(包含订阅状态)
|
||||
func (r *GormProductRepository) ListProductsWithSubscriptionStatus(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, map[string]bool, int64, error) {
|
||||
var productEntities []entities.Product
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Keyword != "" {
|
||||
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
|
||||
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
if query.CategoryID != "" {
|
||||
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
|
||||
}
|
||||
if query.MinPrice != nil {
|
||||
dbQuery = dbQuery.Where("price >= ?", *query.MinPrice)
|
||||
}
|
||||
if query.MaxPrice != nil {
|
||||
dbQuery = dbQuery.Where("price <= ?", *query.MaxPrice)
|
||||
}
|
||||
if query.IsEnabled != nil {
|
||||
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
|
||||
}
|
||||
if query.IsVisible != nil {
|
||||
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
|
||||
}
|
||||
if query.IsPackage != nil {
|
||||
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
|
||||
}
|
||||
|
||||
// 如果指定了用户ID,添加订阅状态筛选
|
||||
if query.UserID != "" && query.IsSubscribed != nil {
|
||||
if *query.IsSubscribed {
|
||||
// 筛选已订阅的产品
|
||||
dbQuery = dbQuery.Where("EXISTS (SELECT 1 FROM subscription WHERE subscription.product_id = product.id AND subscription.user_id = ?)", query.UserID)
|
||||
} else {
|
||||
// 筛选未订阅的产品
|
||||
dbQuery = dbQuery.Where("NOT EXISTS (SELECT 1 FROM subscription WHERE subscription.product_id = product.id AND subscription.user_id = ?)", query.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载分类信息并获取数据
|
||||
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
|
||||
// 获取订阅状态映射
|
||||
subscriptionStatusMap := make(map[string]bool)
|
||||
if query.UserID != "" && len(result) > 0 {
|
||||
productIDs := make([]string, len(result))
|
||||
for i, product := range result {
|
||||
productIDs[i] = product.ID
|
||||
}
|
||||
|
||||
// 查询用户的订阅状态
|
||||
var subscriptions []struct {
|
||||
ProductID string `gorm:"column:product_id"`
|
||||
}
|
||||
err := r.GetDB(ctx).Table("subscription").
|
||||
Select("product_id").
|
||||
Where("user_id = ? AND product_id IN ?", query.UserID, productIDs).
|
||||
Find(&subscriptions).Error
|
||||
|
||||
if err == nil {
|
||||
for _, sub := range subscriptions {
|
||||
subscriptionStatusMap[sub.ProductID] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, subscriptionStatusMap, total, nil
|
||||
}
|
||||
|
||||
// FindSubscribableProducts 查找可订阅产品
|
||||
func (r *GormProductRepository) FindSubscribableProducts(ctx context.Context, userID string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Where("is_enabled = ? AND is_visible = ?", true, true).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindProductsByIDs 根据ID列表查找产品
|
||||
func (r *GormProductRepository) FindProductsByIDs(ctx context.Context, ids []string) ([]*entities.Product, error) {
|
||||
var productEntities []entities.Product
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&productEntities).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Product, len(productEntities))
|
||||
for i := range productEntities {
|
||||
result[i] = &productEntities[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CountByCategory 统计分类下的产品数量
|
||||
func (r *GormProductRepository) CountByCategory(ctx context.Context, categoryID string) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
if categoryID != "" {
|
||||
query = query.Where("category_id = ?", categoryID)
|
||||
}
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountEnabled 统计启用产品数量
|
||||
func (r *GormProductRepository) CountEnabled(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_enabled = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountVisible 统计可见产品数量
|
||||
func (r *GormProductRepository) CountVisible(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Count 返回产品总数
|
||||
func (r *GormProductRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取产品
|
||||
func (r *GormProductRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Product, error) {
|
||||
var products []entities.Product
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建产品
|
||||
func (r *GormProductRepository) CreateBatch(ctx context.Context, products []entities.Product) error {
|
||||
return r.GetDB(ctx).Create(&products).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新产品
|
||||
func (r *GormProductRepository) UpdateBatch(ctx context.Context, products []entities.Product) error {
|
||||
return r.GetDB(ctx).Save(&products).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除产品
|
||||
func (r *GormProductRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.Product{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取产品列表(基础方法)
|
||||
func (r *GormProductRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Product, error) {
|
||||
var products []entities.Product
|
||||
query := r.GetDB(ctx).Model(&entities.Product{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&products).Error
|
||||
return products, err
|
||||
}
|
||||
|
||||
// Exists 检查产品是否存在
|
||||
func (r *GormProductRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.Product{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除产品
|
||||
func (r *GormProductRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.Product{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的产品
|
||||
func (r *GormProductRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.Product{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// GetPackageItems 获取组合包项目
|
||||
func (r *GormProductRepository) GetPackageItems(ctx context.Context, packageID string) ([]*entities.ProductPackageItem, error) {
|
||||
var packageItems []entities.ProductPackageItem
|
||||
err := r.GetDB(ctx).
|
||||
Preload("Product").
|
||||
Where("package_id = ?", packageID).
|
||||
Order("sort_order ASC").
|
||||
Find(&packageItems).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductPackageItem, len(packageItems))
|
||||
for i := range packageItems {
|
||||
result[i] = &packageItems[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// CreatePackageItem 创建组合包项目
|
||||
func (r *GormProductRepository) CreatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
|
||||
return r.GetDB(ctx).Create(packageItem).Error
|
||||
}
|
||||
|
||||
// GetPackageItemByID 根据ID获取组合包项目
|
||||
func (r *GormProductRepository) GetPackageItemByID(ctx context.Context, itemID string) (*entities.ProductPackageItem, error) {
|
||||
var packageItem entities.ProductPackageItem
|
||||
err := r.GetDB(ctx).
|
||||
Preload("Product").
|
||||
Preload("Package").
|
||||
Where("id = ?", itemID).
|
||||
First(&packageItem).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &packageItem, nil
|
||||
}
|
||||
|
||||
// UpdatePackageItem 更新组合包项目
|
||||
func (r *GormProductRepository) UpdatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
|
||||
return r.GetDB(ctx).Save(packageItem).Error
|
||||
}
|
||||
|
||||
// DeletePackageItem 删除组合包项目(硬删除)
|
||||
func (r *GormProductRepository) DeletePackageItem(ctx context.Context, itemID string) error {
|
||||
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "id = ?", itemID).Error
|
||||
}
|
||||
|
||||
// DeletePackageItemsByPackageID 根据组合包ID删除所有子产品(硬删除)
|
||||
func (r *GormProductRepository) DeletePackageItemsByPackageID(ctx context.Context, packageID string) error {
|
||||
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "package_id = ?", packageID).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormProductRepository) WithTx(tx interface{}) interfaces.Repository[entities.Product] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormProductRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ProductSubCategoriesTable = "product_sub_categories"
|
||||
)
|
||||
|
||||
type GormProductSubCategoryRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.ProductSubCategoryRepository = (*GormProductSubCategoryRepository)(nil)
|
||||
|
||||
func NewGormProductSubCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductSubCategoryRepository {
|
||||
return &GormProductSubCategoryRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductSubCategoriesTable),
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建二级分类
|
||||
func (r *GormProductSubCategoryRepository) Create(ctx context.Context, category entities.ProductSubCategory) (*entities.ProductSubCategory, error) {
|
||||
err := r.CreateEntity(ctx, &category)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &category, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取二级分类
|
||||
func (r *GormProductSubCategoryRepository) GetByID(ctx context.Context, id string) (*entities.ProductSubCategory, error) {
|
||||
var entity entities.ProductSubCategory
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// Update 更新二级分类
|
||||
func (r *GormProductSubCategoryRepository) Update(ctx context.Context, category entities.ProductSubCategory) error {
|
||||
return r.UpdateEntity(ctx, &category)
|
||||
}
|
||||
|
||||
// Delete 删除二级分类
|
||||
func (r *GormProductSubCategoryRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.ProductSubCategory{})
|
||||
}
|
||||
|
||||
// List 获取所有二级分类
|
||||
func (r *GormProductSubCategoryRepository) List(ctx context.Context) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByCode 根据编号查找二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductSubCategory, error) {
|
||||
var entity entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// FindByCategoryID 根据一级分类ID查找二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("category_id = ?", categoryID).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindVisible 查找可见的二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindEnabled 查找启用的二级分类
|
||||
func (r *GormProductSubCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductSubCategory, error) {
|
||||
var categories []entities.ProductSubCategory
|
||||
err := r.GetDB(ctx).Where("is_enabled = ?", true).Order("sort ASC, created_at DESC").Find(&categories).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.ProductSubCategory, len(categories))
|
||||
for i := range categories {
|
||||
result[i] = &categories[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -0,0 +1,80 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormProductUIComponentRepository 产品UI组件关联仓储实现
|
||||
type GormProductUIComponentRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormProductUIComponentRepository 创建产品UI组件关联仓储实例
|
||||
func NewGormProductUIComponentRepository(db *gorm.DB) repositories.ProductUIComponentRepository {
|
||||
return &GormProductUIComponentRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建产品UI组件关联
|
||||
func (r *GormProductUIComponentRepository) Create(ctx context.Context, relation entities.ProductUIComponent) (entities.ProductUIComponent, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&relation).Error; err != nil {
|
||||
return entities.ProductUIComponent{}, fmt.Errorf("创建产品UI组件关联失败: %w", err)
|
||||
}
|
||||
return relation, nil
|
||||
}
|
||||
|
||||
// GetByProductID 根据产品ID获取UI组件关联列表
|
||||
func (r *GormProductUIComponentRepository) GetByProductID(ctx context.Context, productID string) ([]entities.ProductUIComponent, error) {
|
||||
var relations []entities.ProductUIComponent
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("UIComponent").
|
||||
Where("product_id = ?", productID).
|
||||
Find(&relations).Error; err != nil {
|
||||
return nil, fmt.Errorf("获取产品UI组件关联列表失败: %w", err)
|
||||
}
|
||||
return relations, nil
|
||||
}
|
||||
|
||||
// GetByUIComponentID 根据UI组件ID获取产品关联列表
|
||||
func (r *GormProductUIComponentRepository) GetByUIComponentID(ctx context.Context, componentID string) ([]entities.ProductUIComponent, error) {
|
||||
var relations []entities.ProductUIComponent
|
||||
if err := r.db.WithContext(ctx).
|
||||
Preload("Product").
|
||||
Where("ui_component_id = ?", componentID).
|
||||
Find(&relations).Error; err != nil {
|
||||
return nil, fmt.Errorf("获取UI组件产品关联列表失败: %w", err)
|
||||
}
|
||||
return relations, nil
|
||||
}
|
||||
|
||||
// Delete 删除产品UI组件关联
|
||||
func (r *GormProductUIComponentRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.ProductUIComponent{}, id).Error; err != nil {
|
||||
return fmt.Errorf("删除产品UI组件关联失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByProductID 根据产品ID删除所有关联
|
||||
func (r *GormProductUIComponentRepository) DeleteByProductID(ctx context.Context, productID string) error {
|
||||
if err := r.db.WithContext(ctx).Where("product_id = ?", productID).Delete(&entities.ProductUIComponent{}).Error; err != nil {
|
||||
return fmt.Errorf("根据产品ID删除UI组件关联失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建产品UI组件关联
|
||||
func (r *GormProductUIComponentRepository) BatchCreate(ctx context.Context, relations []entities.ProductUIComponent) error {
|
||||
if len(relations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).CreateInBatches(relations, 100).Error; err != nil {
|
||||
return fmt.Errorf("批量创建产品UI组件关联失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,354 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
"hyapi-server/internal/domains/product/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
SubscriptionsTable = "subscription"
|
||||
SubscriptionCacheTTL = 60 * time.Minute
|
||||
)
|
||||
|
||||
type GormSubscriptionRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.Subscription{})
|
||||
}
|
||||
|
||||
var _ repositories.SubscriptionRepository = (*GormSubscriptionRepository)(nil)
|
||||
|
||||
func NewGormSubscriptionRepository(db *gorm.DB, logger *zap.Logger) repositories.SubscriptionRepository {
|
||||
return &GormSubscriptionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SubscriptionsTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) Create(ctx context.Context, entity entities.Subscription) (entities.Subscription, error) {
|
||||
err := r.CreateEntity(ctx, &entity)
|
||||
return entity, err
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) GetByID(ctx context.Context, id string) (entities.Subscription, error) {
|
||||
var entity entities.Subscription
|
||||
err := r.SmartGetByID(ctx, id, &entity)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.Subscription{}, gorm.ErrRecordNotFound
|
||||
}
|
||||
return entities.Subscription{}, err
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
func (r *GormSubscriptionRepository) Update(ctx context.Context, entity entities.Subscription) error {
|
||||
return r.UpdateEntity(ctx, &entity)
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&subscriptions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Subscription, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
result[i] = &subscriptions[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByProductID 根据产品ID查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByProductID(ctx context.Context, productID string) ([]*entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("product_id = ?", productID).Order("created_at DESC").Find(&subscriptions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Subscription, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
result[i] = &subscriptions[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// FindByUserAndProduct 根据用户和产品查找订阅
|
||||
func (r *GormSubscriptionRepository) FindByUserAndProduct(ctx context.Context, userID, productID string) (*entities.Subscription, error) {
|
||||
var entity entities.Subscription
|
||||
// 组合缓存key的条件
|
||||
where := "user_id = ? AND product_id = ?"
|
||||
ttl := SubscriptionCacheTTL // 缓存10分钟,可根据业务调整
|
||||
err := r.GetWithCache(ctx, &entity, ttl, where, userID, productID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// ListSubscriptions 获取订阅列表
|
||||
func (r *GormSubscriptionRepository) ListSubscriptions(ctx context.Context, query *queries.ListSubscriptionsQuery) ([]*entities.Subscription, int64, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
var total int64
|
||||
|
||||
dbQuery := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("subscription.user_id = ?", query.UserID)
|
||||
}
|
||||
|
||||
// 关键词搜索(产品名称或编码)
|
||||
if query.Keyword != "" {
|
||||
dbQuery = dbQuery.Joins("LEFT JOIN product ON product.id = subscription.product_id").
|
||||
Where("product.name LIKE ? OR product.code LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if query.ProductName != "" {
|
||||
dbQuery = dbQuery.Joins("LEFT JOIN product ON product.id = subscription.product_id").
|
||||
Where("product.name LIKE ?", "%"+query.ProductName+"%")
|
||||
}
|
||||
|
||||
// 企业名称筛选(需要关联用户和企业信息)
|
||||
if query.CompanyName != "" {
|
||||
dbQuery = dbQuery.Joins("LEFT JOIN users ON users.id = subscription.user_id").
|
||||
Joins("LEFT JOIN enterprise_infos ON enterprise_infos.user_id = users.id").
|
||||
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if query.StartTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", query.StartTime); err == nil {
|
||||
dbQuery = dbQuery.Where("subscription.created_at >= ?", t)
|
||||
}
|
||||
}
|
||||
if query.EndTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", query.EndTime); err == nil {
|
||||
dbQuery = dbQuery.Where("subscription.created_at <= ?", t)
|
||||
}
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if query.SortBy != "" {
|
||||
order := query.SortBy
|
||||
if query.SortOrder == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
dbQuery = dbQuery.Order(order)
|
||||
} else {
|
||||
dbQuery = dbQuery.Order("subscription.created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if query.Page > 0 && query.PageSize > 0 {
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
}
|
||||
|
||||
// 预加载Product的id、name、code、price、cost_price、is_package字段,并同时预加载ProductCategory的id、name、code字段
|
||||
if err := dbQuery.
|
||||
Preload("Product", func(db *gorm.DB) *gorm.DB {
|
||||
return db.Select("id", "name", "code", "price", "cost_price", "is_package", "category_id").
|
||||
Preload("Category", func(db2 *gorm.DB) *gorm.DB {
|
||||
return db2.Select("id", "name", "code")
|
||||
})
|
||||
}).
|
||||
Find(&subscriptions).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
result := make([]*entities.Subscription, len(subscriptions))
|
||||
for i := range subscriptions {
|
||||
result[i] = &subscriptions[i]
|
||||
}
|
||||
|
||||
return result, total, nil
|
||||
}
|
||||
|
||||
// CountByUser 统计用户订阅数量
|
||||
func (r *GormSubscriptionRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// CountByProduct 统计产品的订阅数量
|
||||
func (r *GormSubscriptionRepository) CountByProduct(ctx context.Context, productID string) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("product_id = ?", productID).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetTotalRevenue 获取总收入
|
||||
func (r *GormSubscriptionRepository) GetTotalRevenue(ctx context.Context) (float64, error) {
|
||||
var total decimal.Decimal
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Select("COALESCE(SUM(price), 0)").Scan(&total).Error
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return total.InexactFloat64(), nil
|
||||
}
|
||||
|
||||
// 基础Repository接口方法
|
||||
|
||||
// Count 返回订阅总数
|
||||
func (r *GormSubscriptionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR product_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取订阅
|
||||
func (r *GormSubscriptionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
err := r.GetDB(ctx).WithContext(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&subscriptions).Error
|
||||
return subscriptions, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建订阅
|
||||
func (r *GormSubscriptionRepository) CreateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Create(&subscriptions).Error
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新订阅
|
||||
func (r *GormSubscriptionRepository) UpdateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Save(&subscriptions).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除订阅
|
||||
func (r *GormSubscriptionRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取订阅列表(基础方法)
|
||||
func (r *GormSubscriptionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Subscription, error) {
|
||||
var subscriptions []entities.Subscription
|
||||
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR product_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := options.Sort
|
||||
if options.Order == "desc" {
|
||||
order += " DESC"
|
||||
} else {
|
||||
order += " ASC"
|
||||
}
|
||||
query = query.Order(order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&subscriptions).Error
|
||||
return subscriptions, err
|
||||
}
|
||||
|
||||
// Exists 检查订阅是否存在
|
||||
func (r *GormSubscriptionRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除订阅
|
||||
func (r *GormSubscriptionRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的订阅
|
||||
func (r *GormSubscriptionRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).WithContext(ctx).Unscoped().Model(&entities.Subscription{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormSubscriptionRepository) WithTx(tx interface{}) interfaces.Repository[entities.Subscription] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormSubscriptionRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), SubscriptionsTable),
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// IncrementAPIUsageWithOptimisticLock 使用乐观锁增加API使用次数
|
||||
func (r *GormSubscriptionRepository) IncrementAPIUsageWithOptimisticLock(ctx context.Context, subscriptionID string, increment int64) error {
|
||||
// 使用原生SQL进行乐观锁更新
|
||||
result := r.GetDB(ctx).WithContext(ctx).Exec(`
|
||||
UPDATE subscription
|
||||
SET api_used = api_used + ?, version = version + 1, updated_at = NOW()
|
||||
WHERE id = ? AND version = (
|
||||
SELECT version FROM subscription WHERE id = ?
|
||||
)
|
||||
`, increment, subscriptionID, subscriptionID)
|
||||
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
// 检查是否有行被更新
|
||||
if result.RowsAffected == 0 {
|
||||
return gorm.ErrRecordNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,130 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"hyapi-server/internal/domains/product/entities"
|
||||
"hyapi-server/internal/domains/product/repositories"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// GormUIComponentRepository UI组件仓储实现
|
||||
type GormUIComponentRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormUIComponentRepository 创建UI组件仓储实例
|
||||
func NewGormUIComponentRepository(db *gorm.DB) repositories.UIComponentRepository {
|
||||
return &GormUIComponentRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建UI组件
|
||||
func (r *GormUIComponentRepository) Create(ctx context.Context, component entities.UIComponent) (entities.UIComponent, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&component).Error; err != nil {
|
||||
return entities.UIComponent{}, fmt.Errorf("创建UI组件失败: %w", err)
|
||||
}
|
||||
return component, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取UI组件
|
||||
func (r *GormUIComponentRepository) GetByID(ctx context.Context, id string) (*entities.UIComponent, error) {
|
||||
var component entities.UIComponent
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&component).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("获取UI组件失败: %w", err)
|
||||
}
|
||||
return &component, nil
|
||||
}
|
||||
|
||||
// GetByCode 根据编码获取UI组件
|
||||
func (r *GormUIComponentRepository) GetByCode(ctx context.Context, code string) (*entities.UIComponent, error) {
|
||||
var component entities.UIComponent
|
||||
if err := r.db.WithContext(ctx).Where("component_code = ?", code).First(&component).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("获取UI组件失败: %w", err)
|
||||
}
|
||||
return &component, nil
|
||||
}
|
||||
|
||||
// List 获取UI组件列表
|
||||
func (r *GormUIComponentRepository) List(ctx context.Context, filters map[string]interface{}) ([]entities.UIComponent, int64, error) {
|
||||
var components []entities.UIComponent
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.UIComponent{})
|
||||
|
||||
// 应用过滤条件
|
||||
if isActive, ok := filters["is_active"]; ok {
|
||||
query = query.Where("is_active = ?", isActive)
|
||||
}
|
||||
|
||||
if keyword, ok := filters["keyword"]; ok && keyword != "" {
|
||||
query = query.Where("component_name LIKE ? OR component_code LIKE ? OR description LIKE ?",
|
||||
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("获取UI组件总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页
|
||||
if page, ok := filters["page"]; ok {
|
||||
if pageSize, ok := filters["page_size"]; ok {
|
||||
offset := (page.(int) - 1) * pageSize.(int)
|
||||
query = query.Offset(offset).Limit(pageSize.(int))
|
||||
}
|
||||
}
|
||||
|
||||
// 排序
|
||||
if sortBy, ok := filters["sort_by"]; ok {
|
||||
if sortOrder, ok := filters["sort_order"]; ok {
|
||||
query = query.Order(fmt.Sprintf("%s %s", sortBy, sortOrder))
|
||||
}
|
||||
} else {
|
||||
query = query.Order("sort_order ASC, created_at DESC")
|
||||
}
|
||||
|
||||
// 获取数据
|
||||
if err := query.Find(&components).Error; err != nil {
|
||||
return nil, 0, fmt.Errorf("获取UI组件列表失败: %w", err)
|
||||
}
|
||||
|
||||
return components, total, nil
|
||||
}
|
||||
|
||||
// Update 更新UI组件
|
||||
func (r *GormUIComponentRepository) Update(ctx context.Context, component entities.UIComponent) error {
|
||||
if err := r.db.WithContext(ctx).Save(&component).Error; err != nil {
|
||||
return fmt.Errorf("更新UI组件失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除UI组件
|
||||
func (r *GormUIComponentRepository) Delete(ctx context.Context, id string) error {
|
||||
// 记录删除操作的详细信息
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.UIComponent{}).Error; err != nil {
|
||||
return fmt.Errorf("删除UI组件失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByCodes 根据编码列表获取UI组件
|
||||
func (r *GormUIComponentRepository) GetByCodes(ctx context.Context, codes []string) ([]entities.UIComponent, error) {
|
||||
var components []entities.UIComponent
|
||||
if len(codes) == 0 {
|
||||
return components, nil
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("component_code IN ?", codes).Find(&components).Error; err != nil {
|
||||
return nil, fmt.Errorf("根据编码列表获取UI组件失败: %w", err)
|
||||
}
|
||||
|
||||
return components, nil
|
||||
}
|
||||
@@ -0,0 +1,461 @@
|
||||
package statistics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/statistics/entities"
|
||||
"hyapi-server/internal/domains/statistics/repositories"
|
||||
)
|
||||
|
||||
// GormStatisticsDashboardRepository GORM统计仪表板仓储实现
|
||||
type GormStatisticsDashboardRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormStatisticsDashboardRepository 创建GORM统计仪表板仓储
|
||||
func NewGormStatisticsDashboardRepository(db *gorm.DB) repositories.StatisticsDashboardRepository {
|
||||
return &GormStatisticsDashboardRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) Save(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
|
||||
if dashboard == nil {
|
||||
return fmt.Errorf("统计仪表板不能为空")
|
||||
}
|
||||
|
||||
// 验证仪表板
|
||||
if err := dashboard.Validate(); err != nil {
|
||||
return fmt.Errorf("统计仪表板验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
result := r.db.WithContext(ctx).Save(dashboard)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("保存统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsDashboard, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
var dashboard entities.StatisticsDashboard
|
||||
result := r.db.WithContext(ctx).Where("id = ?", id).First(&dashboard)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("统计仪表板不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &dashboard, nil
|
||||
}
|
||||
|
||||
// FindByUser 根据用户查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("created_by = ?", userID)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindByUserRole 根据用户角色查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByUserRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// Update 更新统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) Update(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
|
||||
if dashboard == nil {
|
||||
return fmt.Errorf("统计仪表板不能为空")
|
||||
}
|
||||
|
||||
if dashboard.ID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
// 验证仪表板
|
||||
if err := dashboard.Validate(); err != nil {
|
||||
return fmt.Errorf("统计仪表板验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
result := r.db.WithContext(ctx).Save(dashboard)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("统计仪表板不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByRole 根据角色查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindDefaultByRole 根据角色查找默认统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindDefaultByRole(ctx context.Context, userRole string) (*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboard entities.StatisticsDashboard
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("user_role = ? AND is_default = ? AND is_active = ?", userRole, true, true).
|
||||
First(&dashboard)
|
||||
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("默认统计仪表板不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询默认统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &dashboard, nil
|
||||
}
|
||||
|
||||
// FindActiveByRole 根据角色查找激活的统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindActiveByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if userRole == "" {
|
||||
return nil, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("user_role = ? AND is_active = ?", userRole, true)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询激活统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByStatus(ctx context.Context, isActive bool, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("is_active = ?", isActive)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// FindByAccessLevel 根据访问级别查找统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) FindByAccessLevel(ctx context.Context, accessLevel string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
|
||||
if accessLevel == "" {
|
||||
return nil, fmt.Errorf("访问级别不能为空")
|
||||
}
|
||||
|
||||
var dashboards []*entities.StatisticsDashboard
|
||||
query := r.db.WithContext(ctx).Where("access_level = ?", accessLevel)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&dashboards)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return dashboards, nil
|
||||
}
|
||||
|
||||
// CountByUser 根据用户统计数量
|
||||
func (r *GormStatisticsDashboardRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
if userID == "" {
|
||||
return 0, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("created_by = ?", userID).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByRole 根据角色统计数量
|
||||
func (r *GormStatisticsDashboardRepository) CountByRole(ctx context.Context, userRole string) (int64, error) {
|
||||
if userRole == "" {
|
||||
return 0, fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("user_role = ?", userRole).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByStatus 根据状态统计数量
|
||||
func (r *GormStatisticsDashboardRepository) CountByStatus(ctx context.Context, isActive bool) (int64, error) {
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("is_active = ?", isActive).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BatchSave 批量保存统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) BatchSave(ctx context.Context, dashboards []*entities.StatisticsDashboard) error {
|
||||
if len(dashboards) == 0 {
|
||||
return fmt.Errorf("统计仪表板列表不能为空")
|
||||
}
|
||||
|
||||
// 验证所有仪表板
|
||||
for _, dashboard := range dashboards {
|
||||
if err := dashboard.Validate(); err != nil {
|
||||
return fmt.Errorf("统计仪表板验证失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存
|
||||
result := r.db.WithContext(ctx).CreateInBatches(dashboards, 100)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量保存统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除统计仪表板
|
||||
func (r *GormStatisticsDashboardRepository) BatchDelete(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("仪表板ID列表不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id IN ?", ids)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量删除统计仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDefaultDashboard 设置默认仪表板
|
||||
func (r *GormStatisticsDashboardRepository) SetDefaultDashboard(ctx context.Context, dashboardID string) error {
|
||||
if dashboardID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
tx := r.db.WithContext(ctx).Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 先取消同角色的所有默认状态
|
||||
var dashboard entities.StatisticsDashboard
|
||||
if err := tx.Where("id = ?", dashboardID).First(&dashboard).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("查询仪表板失败: %w", err)
|
||||
}
|
||||
|
||||
// 取消同角色的所有默认状态
|
||||
if err := tx.Model(&entities.StatisticsDashboard{}).
|
||||
Where("user_role = ? AND is_default = ?", dashboard.UserRole, true).
|
||||
Update("is_default", false).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("取消默认状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置新的默认状态
|
||||
if err := tx.Model(&entities.StatisticsDashboard{}).
|
||||
Where("id = ?", dashboardID).
|
||||
Update("is_default", true).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return fmt.Errorf("设置默认状态失败: %w", err)
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return fmt.Errorf("提交事务失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveDefaultDashboard 移除默认仪表板
|
||||
func (r *GormStatisticsDashboardRepository) RemoveDefaultDashboard(ctx context.Context, userRole string) error {
|
||||
if userRole == "" {
|
||||
return fmt.Errorf("用户角色不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("user_role = ? AND is_default = ?", userRole, true).
|
||||
Update("is_default", false)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("移除默认仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ActivateDashboard 激活仪表板
|
||||
func (r *GormStatisticsDashboardRepository) ActivateDashboard(ctx context.Context, dashboardID string) error {
|
||||
if dashboardID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("id = ?", dashboardID).
|
||||
Update("is_active", true)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("激活仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("仪表板不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeactivateDashboard 停用仪表板
|
||||
func (r *GormStatisticsDashboardRepository) DeactivateDashboard(ctx context.Context, dashboardID string) error {
|
||||
if dashboardID == "" {
|
||||
return fmt.Errorf("仪表板ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsDashboard{}).
|
||||
Where("id = ?", dashboardID).
|
||||
Update("is_active", false)
|
||||
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("停用仪表板失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("仪表板不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
package statistics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/statistics/entities"
|
||||
"hyapi-server/internal/domains/statistics/repositories"
|
||||
)
|
||||
|
||||
// GormStatisticsReportRepository GORM统计报告仓储实现
|
||||
type GormStatisticsReportRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormStatisticsReportRepository 创建GORM统计报告仓储
|
||||
func NewGormStatisticsReportRepository(db *gorm.DB) repositories.StatisticsReportRepository {
|
||||
return &GormStatisticsReportRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存统计报告
|
||||
func (r *GormStatisticsReportRepository) Save(ctx context.Context, report *entities.StatisticsReport) error {
|
||||
if report == nil {
|
||||
return fmt.Errorf("统计报告不能为空")
|
||||
}
|
||||
|
||||
// 验证报告
|
||||
if err := report.Validate(); err != nil {
|
||||
return fmt.Errorf("统计报告验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
result := r.db.WithContext(ctx).Save(report)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("保存统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsReport, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("报告ID不能为空")
|
||||
}
|
||||
|
||||
var report entities.StatisticsReport
|
||||
result := r.db.WithContext(ctx).Where("id = ?", id).First(&report)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("统计报告不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &report, nil
|
||||
}
|
||||
|
||||
// FindByUser 根据用户查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).Where("generated_by = ?", userID)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByStatus 根据状态查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByStatus(ctx context.Context, status string) ([]*entities.StatisticsReport, error) {
|
||||
if status == "" {
|
||||
return nil, fmt.Errorf("报告状态不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("status = ?", status).
|
||||
Order("created_at DESC").
|
||||
Find(&reports)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// Update 更新统计报告
|
||||
func (r *GormStatisticsReportRepository) Update(ctx context.Context, report *entities.StatisticsReport) error {
|
||||
if report == nil {
|
||||
return fmt.Errorf("统计报告不能为空")
|
||||
}
|
||||
|
||||
if report.ID == "" {
|
||||
return fmt.Errorf("报告ID不能为空")
|
||||
}
|
||||
|
||||
// 验证报告
|
||||
if err := report.Validate(); err != nil {
|
||||
return fmt.Errorf("统计报告验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
result := r.db.WithContext(ctx).Save(report)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除统计报告
|
||||
func (r *GormStatisticsReportRepository) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("报告ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("统计报告不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByType 根据类型查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByType(ctx context.Context, reportType string, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if reportType == "" {
|
||||
return nil, fmt.Errorf("报告类型不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).Where("report_type = ?", reportType)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByTypeAndPeriod 根据类型和周期查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByTypeAndPeriod(ctx context.Context, reportType, period string, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if reportType == "" {
|
||||
return nil, fmt.Errorf("报告类型不能为空")
|
||||
}
|
||||
|
||||
if period == "" {
|
||||
return nil, fmt.Errorf("统计周期不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("report_type = ? AND period = ?", reportType, period)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByDateRange 根据日期范围查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByDateRange(ctx context.Context, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("created_at >= ? AND created_at < ?", startDate, endDate)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// FindByUserAndDateRange 根据用户和日期范围查找统计报告
|
||||
func (r *GormStatisticsReportRepository) FindByUserAndDateRange(ctx context.Context, userID string, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
|
||||
if userID == "" {
|
||||
return nil, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var reports []*entities.StatisticsReport
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("generated_by = ? AND created_at >= ? AND created_at < ?", userID, startDate, endDate)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&reports)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return reports, nil
|
||||
}
|
||||
|
||||
// CountByUser 根据用户统计数量
|
||||
func (r *GormStatisticsReportRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
|
||||
if userID == "" {
|
||||
return 0, fmt.Errorf("用户ID不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsReport{}).
|
||||
Where("generated_by = ?", userID).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByType 根据类型统计数量
|
||||
func (r *GormStatisticsReportRepository) CountByType(ctx context.Context, reportType string) (int64, error) {
|
||||
if reportType == "" {
|
||||
return 0, fmt.Errorf("报告类型不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsReport{}).
|
||||
Where("report_type = ?", reportType).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByStatus 根据状态统计数量
|
||||
func (r *GormStatisticsReportRepository) CountByStatus(ctx context.Context, status string) (int64, error) {
|
||||
if status == "" {
|
||||
return 0, fmt.Errorf("报告状态不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsReport{}).
|
||||
Where("status = ?", status).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BatchSave 批量保存统计报告
|
||||
func (r *GormStatisticsReportRepository) BatchSave(ctx context.Context, reports []*entities.StatisticsReport) error {
|
||||
if len(reports) == 0 {
|
||||
return fmt.Errorf("统计报告列表不能为空")
|
||||
}
|
||||
|
||||
// 验证所有报告
|
||||
for _, report := range reports {
|
||||
if err := report.Validate(); err != nil {
|
||||
return fmt.Errorf("统计报告验证失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存
|
||||
result := r.db.WithContext(ctx).CreateInBatches(reports, 100)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量保存统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除统计报告
|
||||
func (r *GormStatisticsReportRepository) BatchDelete(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("报告ID列表不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id IN ?", ids)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量删除统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteExpiredReports 删除过期报告
|
||||
func (r *GormStatisticsReportRepository) DeleteExpiredReports(ctx context.Context, expiredBefore time.Time) error {
|
||||
if expiredBefore.IsZero() {
|
||||
return fmt.Errorf("过期时间不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsReport{}, "expires_at IS NOT NULL AND expires_at < ?", expiredBefore)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除过期报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByStatus 根据状态删除统计报告
|
||||
func (r *GormStatisticsReportRepository) DeleteByStatus(ctx context.Context, status string) error {
|
||||
if status == "" {
|
||||
return fmt.Errorf("报告状态不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsReport{}, "status = ?", status)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("根据状态删除统计报告失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,381 @@
|
||||
package statistics
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/statistics/entities"
|
||||
"hyapi-server/internal/domains/statistics/repositories"
|
||||
)
|
||||
|
||||
// GormStatisticsRepository GORM统计指标仓储实现
|
||||
type GormStatisticsRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewGormStatisticsRepository 创建GORM统计指标仓储
|
||||
func NewGormStatisticsRepository(db *gorm.DB) repositories.StatisticsRepository {
|
||||
return &GormStatisticsRepository{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Save 保存统计指标
|
||||
func (r *GormStatisticsRepository) Save(ctx context.Context, metric *entities.StatisticsMetric) error {
|
||||
if metric == nil {
|
||||
return fmt.Errorf("统计指标不能为空")
|
||||
}
|
||||
|
||||
// 验证指标
|
||||
if err := metric.Validate(); err != nil {
|
||||
return fmt.Errorf("统计指标验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 保存到数据库
|
||||
result := r.db.WithContext(ctx).Create(metric)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("保存统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByID 根据ID查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsMetric, error) {
|
||||
if id == "" {
|
||||
return nil, fmt.Errorf("指标ID不能为空")
|
||||
}
|
||||
|
||||
var metric entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).Where("id = ?", id).First(&metric)
|
||||
if result.Error != nil {
|
||||
if result.Error == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("统计指标不存在")
|
||||
}
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return &metric, nil
|
||||
}
|
||||
|
||||
// FindByType 根据类型查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByType(ctx context.Context, metricType string, limit, offset int) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
query := r.db.WithContext(ctx).Where("metric_type = ?", metricType)
|
||||
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
|
||||
result := query.Order("created_at DESC").Find(&metrics)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// Update 更新统计指标
|
||||
func (r *GormStatisticsRepository) Update(ctx context.Context, metric *entities.StatisticsMetric) error {
|
||||
if metric == nil {
|
||||
return fmt.Errorf("统计指标不能为空")
|
||||
}
|
||||
|
||||
if metric.ID == "" {
|
||||
return fmt.Errorf("指标ID不能为空")
|
||||
}
|
||||
|
||||
// 验证指标
|
||||
if err := metric.Validate(); err != nil {
|
||||
return fmt.Errorf("统计指标验证失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新数据库
|
||||
result := r.db.WithContext(ctx).Save(metric)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("更新统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除统计指标
|
||||
func (r *GormStatisticsRepository) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("指标ID不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id = ?", id)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return fmt.Errorf("统计指标不存在")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindByTypeAndDateRange 根据类型和日期范围查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
|
||||
Order("date ASC").
|
||||
Find(&metrics)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// FindByTypeDimensionAndDateRange 根据类型、维度和日期范围查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByTypeDimensionAndDateRange(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
query := r.db.WithContext(ctx).
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate)
|
||||
|
||||
if dimension != "" {
|
||||
query = query.Where("dimension = ?", dimension)
|
||||
}
|
||||
|
||||
result := query.Order("date ASC").Find(&metrics)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// FindByTypeNameAndDateRange 根据类型、名称和日期范围查找统计指标
|
||||
func (r *GormStatisticsRepository) FindByTypeNameAndDateRange(ctx context.Context, metricType, metricName string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if metricName == "" {
|
||||
return nil, fmt.Errorf("指标名称不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("metric_type = ? AND metric_name = ? AND date >= ? AND date < ?",
|
||||
metricType, metricName, startDate, endDate).
|
||||
Order("date ASC").
|
||||
Find(&metrics)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// GetAggregatedMetrics 获取聚合指标
|
||||
func (r *GormStatisticsRepository) GetAggregatedMetrics(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) (map[string]float64, error) {
|
||||
if metricType == "" {
|
||||
return nil, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
type AggregatedResult struct {
|
||||
MetricName string `json:"metric_name"`
|
||||
TotalValue float64 `json:"total_value"`
|
||||
}
|
||||
|
||||
var results []AggregatedResult
|
||||
query := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsMetric{}).
|
||||
Select("metric_name, SUM(value) as total_value").
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
|
||||
Group("metric_name")
|
||||
|
||||
if dimension != "" {
|
||||
query = query.Where("dimension = ?", dimension)
|
||||
}
|
||||
|
||||
result := query.Find(&results)
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询聚合指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
// 转换为map
|
||||
aggregated := make(map[string]float64)
|
||||
for _, res := range results {
|
||||
aggregated[res.MetricName] = res.TotalValue
|
||||
}
|
||||
|
||||
return aggregated, nil
|
||||
}
|
||||
|
||||
// GetMetricsByDimension 根据维度获取指标
|
||||
func (r *GormStatisticsRepository) GetMetricsByDimension(ctx context.Context, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
|
||||
if dimension == "" {
|
||||
return nil, fmt.Errorf("统计维度不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return nil, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var metrics []*entities.StatisticsMetric
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("dimension = ? AND date >= ? AND date < ?", dimension, startDate, endDate).
|
||||
Order("date ASC").
|
||||
Find(&metrics)
|
||||
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return metrics, nil
|
||||
}
|
||||
|
||||
// CountByType 根据类型统计数量
|
||||
func (r *GormStatisticsRepository) CountByType(ctx context.Context, metricType string) (int64, error) {
|
||||
if metricType == "" {
|
||||
return 0, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsMetric{}).
|
||||
Where("metric_type = ?", metricType).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// CountByTypeAndDateRange 根据类型和日期范围统计数量
|
||||
func (r *GormStatisticsRepository) CountByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) (int64, error) {
|
||||
if metricType == "" {
|
||||
return 0, fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return 0, fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
var count int64
|
||||
result := r.db.WithContext(ctx).
|
||||
Model(&entities.StatisticsMetric{}).
|
||||
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
|
||||
Count(&count)
|
||||
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// BatchSave 批量保存统计指标
|
||||
func (r *GormStatisticsRepository) BatchSave(ctx context.Context, metrics []*entities.StatisticsMetric) error {
|
||||
if len(metrics) == 0 {
|
||||
return fmt.Errorf("统计指标列表不能为空")
|
||||
}
|
||||
|
||||
// 验证所有指标
|
||||
for _, metric := range metrics {
|
||||
if err := metric.Validate(); err != nil {
|
||||
return fmt.Errorf("统计指标验证失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 批量保存
|
||||
result := r.db.WithContext(ctx).CreateInBatches(metrics, 100)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量保存统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除统计指标
|
||||
func (r *GormStatisticsRepository) BatchDelete(ctx context.Context, ids []string) error {
|
||||
if len(ids) == 0 {
|
||||
return fmt.Errorf("指标ID列表不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id IN ?", ids)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("批量删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByDateRange 根据日期范围删除统计指标
|
||||
func (r *GormStatisticsRepository) DeleteByDateRange(ctx context.Context, startDate, endDate time.Time) error {
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsMetric{}, "date >= ? AND date < ?", startDate, endDate)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("根据日期范围删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByTypeAndDateRange 根据类型和日期范围删除统计指标
|
||||
func (r *GormStatisticsRepository) DeleteByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) error {
|
||||
if metricType == "" {
|
||||
return fmt.Errorf("指标类型不能为空")
|
||||
}
|
||||
|
||||
if startDate.IsZero() || endDate.IsZero() {
|
||||
return fmt.Errorf("开始日期和结束日期不能为空")
|
||||
}
|
||||
|
||||
result := r.db.WithContext(ctx).
|
||||
Delete(&entities.StatisticsMetric{}, "metric_type = ? AND date >= ? AND date < ?",
|
||||
metricType, startDate, endDate)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("根据类型和日期范围删除统计指标失败: %w", result.Error)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
// internal/infrastructure/database/repositories/user/gorm_contract_info_repository.go
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/shared/database"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const (
|
||||
ContractInfosTable = "contract_infos"
|
||||
)
|
||||
|
||||
type GormContractInfoRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
func NewGormContractInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractInfoRepository {
|
||||
return &GormContractInfoRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ContractInfosTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Save(ctx context.Context, contract *entities.ContractInfo) error {
|
||||
return r.CreateEntity(ctx, contract)
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByID(ctx context.Context, contractID string) (*entities.ContractInfo, error) {
|
||||
var contract entities.ContractInfo
|
||||
err := r.SmartGetByID(ctx, contractID, &contract)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &contract, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) Delete(ctx context.Context, contractID string) error {
|
||||
return r.DeleteEntity(ctx, contractID, &entities.ContractInfo{})
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByEnterpriseInfoID(ctx context.Context, enterpriseInfoID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ?", enterpriseInfoID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("user_id = ?", userID).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) FindByContractType(ctx context.Context, enterpriseInfoID string, contractType entities.ContractType) ([]*entities.ContractInfo, error) {
|
||||
var contracts []entities.ContractInfo
|
||||
err := r.GetDB(ctx).Where("enterprise_info_id = ? AND contract_type = ?", enterpriseInfoID, contractType).Find(&contracts).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*entities.ContractInfo, len(contracts))
|
||||
for i := range contracts {
|
||||
result[i] = &contracts[i]
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileID(ctx context.Context, contractFileID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ?", contractFileID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormContractInfoRepository) ExistsByContractFileIDExcludeID(ctx context.Context, contractFileID, excludeID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ? AND id != ?", contractFileID, excludeID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormEnterpriseInfoRepository 企业信息GORM仓储实现
|
||||
type GormEnterpriseInfoRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewGormEnterpriseInfoRepository 创建企业信息GORM仓储
|
||||
func NewGormEnterpriseInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.EnterpriseInfoRepository {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Create(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) (entities.EnterpriseInfo, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("创建企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("创建企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// Update 更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Update(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) error {
|
||||
if err := r.db.WithContext(ctx).Save(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("更新企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("更新企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("软删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("软删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("恢复企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByUnifiedSocialCode 根据统一社会信用代码获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("unified_social_code = ?", unifiedSocialCode).First(&enterpriseInfo).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// CheckUnifiedSocialCodeExists 检查统一社会信用代码是否已存在
|
||||
func (r *GormEnterpriseInfoRepository) CheckUnifiedSocialCodeExists(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("user_id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
r.logger.Error("检查统一社会信用代码失败", zap.Error(err))
|
||||
return false, fmt.Errorf("检查统一社会信用代码失败: %w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// UpdateVerificationStatus 更新验证状态
|
||||
func (r *GormEnterpriseInfoRepository) UpdateVerificationStatus(ctx context.Context, userID string, isOCRVerified, isFaceVerified, isCertified bool) error {
|
||||
updates := map[string]interface{}{
|
||||
"is_ocr_verified": isOCRVerified,
|
||||
"is_face_verified": isFaceVerified,
|
||||
"is_certified": isCertified,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新验证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新验证状态失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOCRData 更新OCR数据
|
||||
func (r *GormEnterpriseInfoRepository) UpdateOCRData(ctx context.Context, userID string, rawData string, confidence float64) error {
|
||||
updates := map[string]interface{}{
|
||||
"ocr_raw_data": rawData,
|
||||
"ocr_confidence": confidence,
|
||||
"is_ocr_verified": true,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新OCR数据失败", zap.Error(err))
|
||||
return fmt.Errorf("更新OCR数据失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompleteCertification 完成认证
|
||||
func (r *GormEnterpriseInfoRepository) CompleteCertification(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
updates := map[string]interface{}{
|
||||
"is_certified": true,
|
||||
"certified_at": &now,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("完成认证失败", zap.Error(err))
|
||||
return fmt.Errorf("完成认证失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计企业信息数量
|
||||
func (r *GormEnterpriseInfoRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查企业信息是否存在
|
||||
func (r *GormEnterpriseInfoRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) CreateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Create(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) UpdateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Save(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取企业信息列表
|
||||
func (r *GormEnterpriseInfoRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
// 默认按创建时间倒序
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormEnterpriseInfoRepository) WithTx(tx interface{}) interfaces.Repository[entities.EnterpriseInfo] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/domains/user/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
SMSCodesTable = "sms_codes"
|
||||
)
|
||||
|
||||
// GormSMSCodeRepository 短信验证码GORM仓储实现(无缓存,确保安全性)
|
||||
type GormSMSCodeRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
// NewGormSMSCodeRepository 创建短信验证码仓储
|
||||
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
|
||||
return &GormSMSCodeRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SMSCodesTable),
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 GormSMSCodeRepository 实现了 SMSCodeRepository 接口
|
||||
var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
|
||||
|
||||
// ================ Repository[T] 接口实现 ================
|
||||
|
||||
// Create 创建短信验证码记录(不缓存,确保安全性)
|
||||
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
|
||||
err := r.GetDB(ctx).Create(&smsCode).Error
|
||||
return smsCode, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
err := r.GetDB(ctx).Where("id = ?", id).First(&smsCode).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// Update 更新验证码记录
|
||||
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Save(&smsCode).Error
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建短信验证码
|
||||
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Create(&smsCodes).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&smsCodes).Error
|
||||
return smsCodes, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新短信验证码
|
||||
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
return r.GetDB(ctx).Save(&smsCodes).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除短信验证码
|
||||
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取短信验证码列表
|
||||
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用预加载
|
||||
for _, include := range options.Include {
|
||||
query = query.Preload(include)
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order == "desc" || options.Order == "DESC" {
|
||||
order = "DESC"
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
} else {
|
||||
query = query.Order("created_at DESC")
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return smsCodes, query.Find(&smsCodes).Error
|
||||
}
|
||||
|
||||
// ================ BaseRepository 接口实现 ================
|
||||
|
||||
// Delete 删除短信验证码
|
||||
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Exists 检查短信验证码是否存在
|
||||
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// Count 统计短信验证码数量
|
||||
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// SoftDelete 软删除短信验证码
|
||||
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复短信验证码
|
||||
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
|
||||
// GetByPhone 根据手机号获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetLatestByPhone 根据手机号获取最新短信验证码
|
||||
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetValidByPhone 根据手机号获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("有效的短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.GetDB(ctx).
|
||||
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("有效的短信验证码不存在")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// ListSMSCodes 获取短信验证码列表(带分页和筛选)
|
||||
func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries.ListSMSCodesQuery) ([]*entities.SMSCode, int64, error) {
|
||||
var smsCodes []*entities.SMSCode
|
||||
var total int64
|
||||
|
||||
// 构建查询条件
|
||||
db := r.GetDB(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
db = db.Where("phone = ?", query.Phone)
|
||||
}
|
||||
if query.Purpose != "" {
|
||||
db = db.Where("scene = ?", query.Purpose)
|
||||
}
|
||||
if query.Status != "" {
|
||||
db = db.Where("used = ?", query.Status == "used")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
db = db.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
db = db.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
if err := db.Offset(offset).Limit(query.PageSize).Order("created_at DESC").Find(&smsCodes).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return smsCodes, total, nil
|
||||
}
|
||||
|
||||
// CreateCode 创建验证码
|
||||
func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, code string, purpose string) (entities.SMSCode, error) {
|
||||
smsCode := entities.SMSCode{
|
||||
Phone: phone,
|
||||
Code: code,
|
||||
Scene: entities.SMSScene(purpose), // 使用Scene字段
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效期
|
||||
}
|
||||
|
||||
if err := r.GetDB(ctx).Create(&smsCode).Error; err != nil {
|
||||
r.GetLogger().Error("创建短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// ValidateCode 验证验证码
|
||||
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// InvalidateCode 使验证码失效
|
||||
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND used_at IS NULL", phone).
|
||||
Update("used_at", &now).Error
|
||||
}
|
||||
|
||||
// CheckSendFrequency 检查发送频率
|
||||
func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone string, purpose string) (bool, error) {
|
||||
// 检查1分钟内是否已发送
|
||||
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
|
||||
var count int64
|
||||
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
|
||||
Count(&count).Error
|
||||
|
||||
// 如果1分钟内已发送,则返回false(不允许发送)
|
||||
return count == 0, err
|
||||
}
|
||||
|
||||
// GetTodaySendCount 获取今日发送数量
|
||||
func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone string) (int64, error) {
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
var count int64
|
||||
|
||||
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&count).Error
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetCodeStats 获取验证码统计
|
||||
func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string, days int) (*repositories.SMSCodeStats, error) {
|
||||
var stats repositories.SMSCodeStats
|
||||
|
||||
// 计算指定天数前的日期
|
||||
startDate := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
// 总发送数
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, startDate).
|
||||
Count(&stats.TotalSent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总验证数
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
|
||||
Count(&stats.TotalValidated).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 成功率
|
||||
if stats.TotalSent > 0 {
|
||||
stats.SuccessRate = float64(stats.TotalValidated) / float64(stats.TotalSent) * 100
|
||||
}
|
||||
|
||||
// 今日发送数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.GetDB(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&stats.TodaySent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
@@ -0,0 +1,720 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"hyapi-server/internal/domains/user/entities"
|
||||
"hyapi-server/internal/domains/user/repositories"
|
||||
"hyapi-server/internal/domains/user/repositories/queries"
|
||||
"hyapi-server/internal/shared/database"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
const (
|
||||
UsersTable = "users"
|
||||
UserCacheTTL = 30 * 60 // 30分钟
|
||||
)
|
||||
|
||||
// 定义错误常量
|
||||
var (
|
||||
// ErrUserNotFound 用户不存在错误
|
||||
ErrUserNotFound = errors.New("用户不存在")
|
||||
)
|
||||
|
||||
type GormUserRepository struct {
|
||||
*database.CachedBaseRepositoryImpl
|
||||
}
|
||||
|
||||
var _ repositories.UserRepository = (*GormUserRepository)(nil)
|
||||
|
||||
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
|
||||
return &GormUserRepository{
|
||||
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, UsersTable),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
|
||||
err := r.CreateEntity(ctx, &user)
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.SmartGetByID(ctx, id, &user)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, errors.New("用户不存在")
|
||||
}
|
||||
return entities.User{}, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id string) (entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据ID查询用户失败", zap.Error(err))
|
||||
return entities.User{}, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) BatchGetByIDsWithEnterpriseInfo(ctx context.Context, ids []string) ([]*entities.User, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*entities.User{}, nil
|
||||
}
|
||||
|
||||
var users []*entities.User
|
||||
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id IN ?", ids).Find(&users).Error; err != nil {
|
||||
r.GetLogger().Error("批量查询用户失败", zap.Error(err), zap.Strings("ids", ids))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
query := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Joins("JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
// 如果指定了排除的用户ID,则排除该用户的记录
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("users.id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
if err != nil {
|
||||
r.GetLogger().Error("检查统一社会信用代码是否存在失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
|
||||
return r.UpdateEntity(ctx, &user)
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量创建用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Create(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.GetLogger().Info("批量更新用户", zap.Int("count", len(users)))
|
||||
return r.GetDB(ctx).Save(&users).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.GetLogger().Info("批量删除用户", zap.Strings("ids", ids))
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.SmartList(ctx, &users, options)
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
|
||||
return r.DeleteEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
return r.ExistsEntity(ctx, id, &entities.User{})
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Delete(&entities.User{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
|
||||
return r.GetDB(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// ================ 业务专用方法 ================
|
||||
|
||||
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据手机号查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.GetDB(ctx).Where("username = ?", username).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.GetLogger().Error("根据用户名查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
|
||||
var users []*entities.User
|
||||
err := r.GetDB(ctx).Where("user_type = ?", userType).Order("created_at DESC").Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
|
||||
var users []*entities.User
|
||||
var total int64
|
||||
|
||||
// 构建查询条件,预加载企业信息
|
||||
db := r.GetDB(ctx).Model(&entities.User{}).Preload("EnterpriseInfo")
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
db = db.Where("users.phone LIKE ?", "%"+query.Phone+"%")
|
||||
}
|
||||
if query.UserType != "" {
|
||||
db = db.Where("users.user_type = ?", query.UserType)
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
db = db.Where("users.active = ?", *query.IsActive)
|
||||
}
|
||||
if query.IsCertified != nil {
|
||||
db = db.Where("users.is_certified = ?", *query.IsCertified)
|
||||
}
|
||||
if query.CompanyName != "" {
|
||||
db = db.Joins("LEFT JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
|
||||
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
db = db.Where("users.created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
db = db.Where("users.created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := db.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用排序(默认按创建时间倒序)
|
||||
db = db.Order("users.created_at DESC")
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
if err := db.Offset(offset).Limit(query.PageSize).Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
err := r.GetDB(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_login_at": &now,
|
||||
"updated_at": now,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("password", newPassword).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
|
||||
var count int64
|
||||
err := r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ? AND password = ?", userID, password).
|
||||
Count(&count).Error
|
||||
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", true).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("active", false).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
|
||||
return r.GetDB(ctx).Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Updates(map[string]interface{}{
|
||||
"login_count": gorm.Expr("login_count + 1"),
|
||||
"last_login_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 总用户数
|
||||
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数
|
||||
if err := db.Model(&entities.User{}).Where("active = ?", true).Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已认证用户数
|
||||
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 指定时间范围内的注册数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 指定时间范围内的登录数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemUserStats 获取系统用户统计信息
|
||||
func (r *GormUserRepository) GetSystemUserStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 总用户数
|
||||
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数(最近30天有登录)
|
||||
thirtyDaysAgo := time.Now().AddDate(0, 0, -30)
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", thirtyDaysAgo).Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已认证用户数
|
||||
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemUserStatsByDateRange 获取系统指定时间范围内的用户统计信息
|
||||
func (r *GormUserRepository) GetSystemUserStatsByDateRange(ctx context.Context, startDate, endDate time.Time) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
db := r.GetDB(ctx)
|
||||
|
||||
// 指定时间范围内的注册数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 指定时间范围内的登录数
|
||||
if err := db.Model(&entities.User{}).
|
||||
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
|
||||
Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetSystemDailyUserStats 获取系统每日用户统计
|
||||
func (r *GormUserRepository) GetSystemDailyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(created_at) as date,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE DATE(created_at) >= $1
|
||||
AND DATE(created_at) <= $2
|
||||
GROUP BY DATE(created_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyUserStats 获取系统每月用户统计
|
||||
func (r *GormUserRepository) GetSystemMonthlyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(created_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE created_at >= $1
|
||||
AND created_at <= $2
|
||||
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemDailyCertificationStats 获取系统每日认证用户统计(基于is_certified字段)
|
||||
func (r *GormUserRepository) GetSystemDailyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
DATE(updated_at) as date,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_certified = true
|
||||
AND DATE(updated_at) >= $1
|
||||
AND DATE(updated_at) <= $2
|
||||
GROUP BY DATE(updated_at)
|
||||
ORDER BY date ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetSystemMonthlyCertificationStats 获取系统每月认证用户统计(基于is_certified字段)
|
||||
func (r *GormUserRepository) GetSystemMonthlyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
|
||||
var results []map[string]interface{}
|
||||
|
||||
sql := `
|
||||
SELECT
|
||||
TO_CHAR(updated_at, 'YYYY-MM') as month,
|
||||
COUNT(*) as count
|
||||
FROM users
|
||||
WHERE is_certified = true
|
||||
AND updated_at >= $1
|
||||
AND updated_at <= $2
|
||||
GROUP BY TO_CHAR(updated_at, 'YYYY-MM')
|
||||
ORDER BY month ASC
|
||||
`
|
||||
|
||||
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserCallRankingByCalls 按调用次数获取用户排行
|
||||
func (r *GormUserRepository) GetUserCallRankingByCalls(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
AND DATE(ac.created_at) = CURRENT_DATE
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COUNT(ac.id) as calls
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN api_calls ac ON u.id = ac.user_id
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COUNT(ac.id) > 0
|
||||
ORDER BY calls DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetUserCallRankingByConsumption 按消费金额获取用户排行
|
||||
func (r *GormUserRepository) GetUserCallRankingByConsumption(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
AND DATE(wt.created_at) = CURRENT_DATE
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
AND DATE_TRUNC('month', wt.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(wt.amount), 0) as consumption
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(wt.amount), 0) > 0
|
||||
ORDER BY consumption DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetRechargeRanking 获取充值排行(排除赠送,只统计成功状态)
|
||||
func (r *GormUserRepository) GetRechargeRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
|
||||
var sql string
|
||||
var args []interface{}
|
||||
|
||||
switch period {
|
||||
case "today":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND DATE(rr.created_at) = CURRENT_DATE
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "month":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND DATE_TRUNC('month', rr.created_at) = DATE_TRUNC('month', CURRENT_DATE)
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
case "total":
|
||||
sql = `
|
||||
SELECT
|
||||
u.id as user_id,
|
||||
COALESCE(ei.company_name, u.username, u.phone) as username,
|
||||
COALESCE(SUM(rr.amount), 0) as amount
|
||||
FROM users u
|
||||
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
|
||||
LEFT JOIN recharge_records rr ON u.id = rr.user_id
|
||||
AND rr.status = 'success'
|
||||
AND rr.recharge_type != 'gift'
|
||||
WHERE u.deleted_at IS NULL
|
||||
GROUP BY u.id, ei.company_name, u.username, u.phone
|
||||
HAVING COALESCE(SUM(rr.amount), 0) > 0
|
||||
ORDER BY amount DESC
|
||||
LIMIT $1
|
||||
`
|
||||
args = []interface{}{limit}
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的时间周期: %s", period)
|
||||
}
|
||||
|
||||
var results []map[string]interface{}
|
||||
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
291
internal/infrastructure/events/certification_event_publisher.go
Normal file
291
internal/infrastructure/events/certification_event_publisher.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// ================ 常量定义 ================
|
||||
|
||||
const (
|
||||
// 事件类型
|
||||
EventTypeCertificationCreated = "certification.created"
|
||||
EventTypeEnterpriseInfoSubmitted = "certification.enterprise_info_submitted"
|
||||
EventTypeEnterpriseVerificationCompleted = "certification.enterprise_verification_completed"
|
||||
EventTypeContractGenerated = "certification.contract_generated"
|
||||
EventTypeContractSigned = "certification.contract_signed"
|
||||
EventTypeCertificationCompleted = "certification.completed"
|
||||
EventTypeCertificationFailed = "certification.failed"
|
||||
EventTypeStatusTransitioned = "certification.status_transitioned"
|
||||
|
||||
// 重试配置
|
||||
MaxRetries = 3
|
||||
RetryDelay = 5 * time.Second
|
||||
)
|
||||
|
||||
// ================ 事件结构 ================
|
||||
|
||||
// CertificationEventData 认证事件数据结构
|
||||
type CertificationEventData struct {
|
||||
EventType string `json:"event_type"`
|
||||
CertificationID string `json:"certification_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// ================ 事件发布器实现 ================
|
||||
|
||||
// CertificationEventPublisher 认证事件发布器实现
|
||||
//
|
||||
// 职责:
|
||||
// - 发布认证域相关的事件
|
||||
// - 支持异步发布和重试机制
|
||||
// - 提供事件持久化能力
|
||||
// - 集成监控和日志
|
||||
type CertificationEventPublisher struct {
|
||||
eventBus interfaces.EventBus
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCertificationEventPublisher 创建认证事件发布器
|
||||
func NewCertificationEventPublisher(
|
||||
eventBus interfaces.EventBus,
|
||||
logger *zap.Logger,
|
||||
) *CertificationEventPublisher {
|
||||
return &CertificationEventPublisher{
|
||||
eventBus: eventBus,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 事件发布方法 ================
|
||||
|
||||
// PublishCertificationCreated 发布认证创建事件
|
||||
func (p *CertificationEventPublisher) PublishCertificationCreated(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeCertificationCreated,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// PublishEnterpriseInfoSubmitted 发布企业信息提交事件
|
||||
func (p *CertificationEventPublisher) PublishEnterpriseInfoSubmitted(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeEnterpriseInfoSubmitted,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// PublishEnterpriseVerificationCompleted 发布企业认证完成事件
|
||||
func (p *CertificationEventPublisher) PublishEnterpriseVerificationCompleted(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeEnterpriseVerificationCompleted,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// PublishContractGenerated 发布合同生成事件
|
||||
func (p *CertificationEventPublisher) PublishContractGenerated(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeContractGenerated,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// PublishContractSigned 发布合同签署事件
|
||||
func (p *CertificationEventPublisher) PublishContractSigned(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeContractSigned,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// PublishCertificationCompleted 发布认证完成事件
|
||||
func (p *CertificationEventPublisher) PublishCertificationCompleted(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeCertificationCompleted,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// PublishCertificationFailed 发布认证失败事件
|
||||
func (p *CertificationEventPublisher) PublishCertificationFailed(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeCertificationFailed,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// PublishStatusTransitioned 发布状态转换事件
|
||||
func (p *CertificationEventPublisher) PublishStatusTransitioned(
|
||||
ctx context.Context,
|
||||
certificationID, userID string,
|
||||
data map[string]interface{},
|
||||
) error {
|
||||
eventData := &CertificationEventData{
|
||||
EventType: EventTypeStatusTransitioned,
|
||||
CertificationID: certificationID,
|
||||
UserID: userID,
|
||||
Data: data,
|
||||
Timestamp: time.Now(),
|
||||
Version: "1.0",
|
||||
}
|
||||
|
||||
return p.publishEventData(ctx, eventData)
|
||||
}
|
||||
|
||||
// ================ 内部实现 ================
|
||||
|
||||
// publishEventData 发布事件数据(带重试机制)
|
||||
func (p *CertificationEventPublisher) publishEventData(ctx context.Context, eventData *CertificationEventData) error {
|
||||
p.logger.Info("发布认证事件",
|
||||
zap.String("event_type", eventData.EventType),
|
||||
zap.String("certification_id", eventData.CertificationID),
|
||||
zap.Time("timestamp", eventData.Timestamp))
|
||||
|
||||
// 尝试发布事件,带重试机制
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// 指数退避重试
|
||||
delay := time.Duration(attempt) * RetryDelay
|
||||
p.logger.Warn("事件发布重试",
|
||||
zap.String("event_type", eventData.EventType),
|
||||
zap.Int("attempt", attempt),
|
||||
zap.Duration("delay", delay))
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(delay):
|
||||
// 继续重试
|
||||
}
|
||||
}
|
||||
|
||||
// 简化的事件发布:直接记录日志
|
||||
p.logger.Info("模拟事件发布",
|
||||
zap.String("event_type", eventData.EventType),
|
||||
zap.String("certification_id", eventData.CertificationID),
|
||||
zap.Any("data", eventData.Data))
|
||||
|
||||
// TODO: 这里可以集成真正的事件总线
|
||||
// if err := p.eventBus.Publish(ctx, eventData); err != nil {
|
||||
// lastErr = err
|
||||
// continue
|
||||
// }
|
||||
|
||||
// 发布成功
|
||||
p.logger.Info("事件发布成功",
|
||||
zap.String("event_type", eventData.EventType),
|
||||
zap.String("certification_id", eventData.CertificationID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// 理论上不会到达这里,因为简化实现总是成功
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// ================ 事件处理器注册 ================
|
||||
|
||||
// RegisterEventHandlers 注册事件处理器
|
||||
func (p *CertificationEventPublisher) RegisterEventHandlers() error {
|
||||
// TODO: 注册具体的事件处理器
|
||||
// 例如:发送通知、更新统计数据、触发后续流程等
|
||||
|
||||
p.logger.Info("认证事件处理器已注册")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 工具方法 ================
|
||||
|
||||
// CreateEventData 创建事件数据
|
||||
func CreateEventData(eventType, certificationID, userID string, data map[string]interface{}) map[string]interface{} {
|
||||
if data == nil {
|
||||
data = make(map[string]interface{})
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"event_type": eventType,
|
||||
"certification_id": certificationID,
|
||||
"user_id": userID,
|
||||
"data": data,
|
||||
"timestamp": time.Now(),
|
||||
"version": "1.0",
|
||||
}
|
||||
}
|
||||
228
internal/infrastructure/events/invoice_event_handler.go
Normal file
228
internal/infrastructure/events/invoice_event_handler.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/domains/finance/events"
|
||||
"hyapi-server/internal/infrastructure/external/email"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// InvoiceEventHandler 发票事件处理器
|
||||
type InvoiceEventHandler struct {
|
||||
logger *zap.Logger
|
||||
emailService *email.QQEmailService
|
||||
name string
|
||||
eventTypes []string
|
||||
isAsync bool
|
||||
}
|
||||
|
||||
// NewInvoiceEventHandler 创建发票事件处理器
|
||||
func NewInvoiceEventHandler(logger *zap.Logger, emailService *email.QQEmailService) *InvoiceEventHandler {
|
||||
return &InvoiceEventHandler{
|
||||
logger: logger,
|
||||
emailService: emailService,
|
||||
name: "invoice-event-handler",
|
||||
eventTypes: []string{
|
||||
"InvoiceApplicationCreated",
|
||||
"InvoiceApplicationApproved",
|
||||
"InvoiceApplicationRejected",
|
||||
"InvoiceFileUploaded",
|
||||
},
|
||||
isAsync: true,
|
||||
}
|
||||
}
|
||||
|
||||
// GetName 获取处理器名称
|
||||
func (h *InvoiceEventHandler) GetName() string {
|
||||
return h.name
|
||||
}
|
||||
|
||||
// GetEventTypes 获取支持的事件类型
|
||||
func (h *InvoiceEventHandler) GetEventTypes() []string {
|
||||
return h.eventTypes
|
||||
}
|
||||
|
||||
// IsAsync 是否为异步处理器
|
||||
func (h *InvoiceEventHandler) IsAsync() bool {
|
||||
return h.isAsync
|
||||
}
|
||||
|
||||
// GetRetryConfig 获取重试配置
|
||||
func (h *InvoiceEventHandler) GetRetryConfig() interfaces.RetryConfig {
|
||||
return interfaces.RetryConfig{
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 5 * time.Second,
|
||||
BackoffFactor: 2.0,
|
||||
MaxDelay: 30 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
// Handle 处理事件
|
||||
func (h *InvoiceEventHandler) Handle(ctx context.Context, event interfaces.Event) error {
|
||||
h.logger.Info("🔄 开始处理发票事件",
|
||||
zap.String("event_type", event.GetType()),
|
||||
zap.String("event_id", event.GetID()),
|
||||
zap.String("aggregate_id", event.GetAggregateID()),
|
||||
zap.String("handler_name", h.GetName()),
|
||||
zap.Time("event_timestamp", event.GetTimestamp()),
|
||||
)
|
||||
|
||||
switch event.GetType() {
|
||||
case "InvoiceApplicationCreated":
|
||||
h.logger.Info("📝 处理发票申请创建事件")
|
||||
return h.handleInvoiceApplicationCreated(ctx, event)
|
||||
case "InvoiceApplicationApproved":
|
||||
h.logger.Info("✅ 处理发票申请通过事件")
|
||||
return h.handleInvoiceApplicationApproved(ctx, event)
|
||||
case "InvoiceApplicationRejected":
|
||||
h.logger.Info("❌ 处理发票申请拒绝事件")
|
||||
return h.handleInvoiceApplicationRejected(ctx, event)
|
||||
case "InvoiceFileUploaded":
|
||||
h.logger.Info("📎 处理发票文件上传事件")
|
||||
return h.handleInvoiceFileUploaded(ctx, event)
|
||||
default:
|
||||
h.logger.Warn("⚠️ 未知的发票事件类型", zap.String("event_type", event.GetType()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// handleInvoiceApplicationCreated 处理发票申请创建事件
|
||||
func (h *InvoiceEventHandler) handleInvoiceApplicationCreated(ctx context.Context, event interfaces.Event) error {
|
||||
h.logger.Info("发票申请已创建",
|
||||
zap.String("application_id", event.GetAggregateID()),
|
||||
)
|
||||
|
||||
// 这里可以发送通知给管理员,告知有新的发票申请
|
||||
// 暂时只记录日志
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoiceApplicationApproved 处理发票申请通过事件
|
||||
func (h *InvoiceEventHandler) handleInvoiceApplicationApproved(ctx context.Context, event interfaces.Event) error {
|
||||
h.logger.Info("发票申请已通过",
|
||||
zap.String("application_id", event.GetAggregateID()),
|
||||
)
|
||||
|
||||
// 这里可以发送通知给用户,告知发票申请已通过
|
||||
// 暂时只记录日志
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoiceApplicationRejected 处理发票申请拒绝事件
|
||||
func (h *InvoiceEventHandler) handleInvoiceApplicationRejected(ctx context.Context, event interfaces.Event) error {
|
||||
h.logger.Info("发票申请被拒绝",
|
||||
zap.String("application_id", event.GetAggregateID()),
|
||||
)
|
||||
|
||||
// 这里可以发送邮件通知用户,告知发票申请被拒绝
|
||||
// 暂时只记录日志
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInvoiceFileUploaded 处理发票文件上传事件
|
||||
func (h *InvoiceEventHandler) handleInvoiceFileUploaded(ctx context.Context, event interfaces.Event) error {
|
||||
h.logger.Info("📎 发票文件已上传事件开始处理",
|
||||
zap.String("invoice_id", event.GetAggregateID()),
|
||||
zap.String("event_id", event.GetID()),
|
||||
)
|
||||
|
||||
// 解析事件数据
|
||||
payload := event.GetPayload()
|
||||
if payload == nil {
|
||||
h.logger.Error("❌ 事件数据为空")
|
||||
return fmt.Errorf("事件数据为空")
|
||||
}
|
||||
|
||||
h.logger.Info("📋 事件数据解析开始",
|
||||
zap.Any("payload_type", fmt.Sprintf("%T", payload)),
|
||||
)
|
||||
|
||||
// 将payload转换为JSON,然后解析为InvoiceFileUploadedEvent
|
||||
payloadBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
h.logger.Error("❌ 序列化事件数据失败", zap.Error(err))
|
||||
return fmt.Errorf("序列化事件数据失败: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("📄 事件数据序列化成功",
|
||||
zap.String("payload_json", string(payloadBytes)),
|
||||
)
|
||||
|
||||
var fileUploadedEvent events.InvoiceFileUploadedEvent
|
||||
err = json.Unmarshal(payloadBytes, &fileUploadedEvent)
|
||||
if err != nil {
|
||||
h.logger.Error("❌ 解析发票文件上传事件失败", zap.Error(err))
|
||||
return fmt.Errorf("解析发票文件上传事件失败: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("✅ 事件数据解析成功",
|
||||
zap.String("invoice_id", fileUploadedEvent.InvoiceID),
|
||||
zap.String("user_id", fileUploadedEvent.UserID),
|
||||
zap.String("receiving_email", fileUploadedEvent.ReceivingEmail),
|
||||
zap.String("file_name", fileUploadedEvent.FileName),
|
||||
zap.String("file_url", fileUploadedEvent.FileURL),
|
||||
zap.String("company_name", fileUploadedEvent.CompanyName),
|
||||
zap.String("amount", fileUploadedEvent.Amount.String()),
|
||||
zap.String("invoice_type", string(fileUploadedEvent.InvoiceType)),
|
||||
)
|
||||
|
||||
// 发送发票邮件给用户
|
||||
return h.sendInvoiceEmail(ctx, &fileUploadedEvent)
|
||||
}
|
||||
|
||||
// sendInvoiceEmail 发送发票邮件
|
||||
func (h *InvoiceEventHandler) sendInvoiceEmail(ctx context.Context, event *events.InvoiceFileUploadedEvent) error {
|
||||
h.logger.Info("📧 开始发送发票邮件",
|
||||
zap.String("invoice_id", event.InvoiceID),
|
||||
zap.String("user_id", event.UserID),
|
||||
zap.String("receiving_email", event.ReceivingEmail),
|
||||
zap.String("file_name", event.FileName),
|
||||
zap.String("file_url", event.FileURL),
|
||||
)
|
||||
|
||||
// 构建邮件数据
|
||||
emailData := &email.InvoiceEmailData{
|
||||
CompanyName: event.CompanyName,
|
||||
Amount: event.Amount.String(),
|
||||
InvoiceType: event.InvoiceType.GetDisplayName(),
|
||||
FileURL: event.FileURL,
|
||||
FileName: event.FileName,
|
||||
ReceivingEmail: event.ReceivingEmail,
|
||||
ApprovedAt: event.UploadedAt.Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
|
||||
h.logger.Info("📋 邮件数据构建完成",
|
||||
zap.String("company_name", emailData.CompanyName),
|
||||
zap.String("amount", emailData.Amount),
|
||||
zap.String("invoice_type", emailData.InvoiceType),
|
||||
zap.String("file_url", emailData.FileURL),
|
||||
zap.String("file_name", emailData.FileName),
|
||||
zap.String("receiving_email", emailData.ReceivingEmail),
|
||||
zap.String("approved_at", emailData.ApprovedAt),
|
||||
)
|
||||
|
||||
// 发送邮件
|
||||
h.logger.Info("🚀 开始调用邮件服务发送邮件")
|
||||
err := h.emailService.SendInvoiceEmail(ctx, emailData)
|
||||
if err != nil {
|
||||
h.logger.Error("❌ 发送发票邮件失败",
|
||||
zap.String("invoice_id", event.InvoiceID),
|
||||
zap.String("receiving_email", event.ReceivingEmail),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("发送发票邮件失败: %w", err)
|
||||
}
|
||||
|
||||
h.logger.Info("✅ 发票邮件发送成功",
|
||||
zap.String("invoice_id", event.InvoiceID),
|
||||
zap.String("receiving_email", event.ReceivingEmail),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
115
internal/infrastructure/events/invoice_event_publisher.go
Normal file
115
internal/infrastructure/events/invoice_event_publisher.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package events
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/domains/finance/events"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// InvoiceEventPublisher 发票事件发布器实现
|
||||
type InvoiceEventPublisher struct {
|
||||
logger *zap.Logger
|
||||
eventBus interfaces.EventBus
|
||||
}
|
||||
|
||||
// NewInvoiceEventPublisher 创建发票事件发布器
|
||||
func NewInvoiceEventPublisher(logger *zap.Logger, eventBus interfaces.EventBus) *InvoiceEventPublisher {
|
||||
return &InvoiceEventPublisher{
|
||||
logger: logger,
|
||||
eventBus: eventBus,
|
||||
}
|
||||
}
|
||||
|
||||
// PublishInvoiceApplicationCreated 发布发票申请创建事件
|
||||
func (p *InvoiceEventPublisher) PublishInvoiceApplicationCreated(ctx context.Context, event *events.InvoiceApplicationCreatedEvent) error {
|
||||
p.logger.Info("发布发票申请创建事件",
|
||||
zap.String("application_id", event.ApplicationID),
|
||||
zap.String("user_id", event.UserID),
|
||||
zap.String("invoice_type", string(event.InvoiceType)),
|
||||
zap.String("amount", event.Amount.String()),
|
||||
zap.String("company_name", event.CompanyName),
|
||||
zap.String("receiving_email", event.ReceivingEmail),
|
||||
)
|
||||
|
||||
// TODO: 实现实际的事件发布逻辑
|
||||
// 例如:发送到消息队列、调用外部服务等
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishInvoiceApplicationApproved 发布发票申请通过事件
|
||||
func (p *InvoiceEventPublisher) PublishInvoiceApplicationApproved(ctx context.Context, event *events.InvoiceApplicationApprovedEvent) error {
|
||||
p.logger.Info("发布发票申请通过事件",
|
||||
zap.String("application_id", event.ApplicationID),
|
||||
zap.String("user_id", event.UserID),
|
||||
zap.String("amount", event.Amount.String()),
|
||||
zap.String("receiving_email", event.ReceivingEmail),
|
||||
zap.Time("approved_at", event.ApprovedAt),
|
||||
)
|
||||
|
||||
// TODO: 实现实际的事件发布逻辑
|
||||
// 例如:发送邮件通知用户、更新统计数据等
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishInvoiceApplicationRejected 发布发票申请拒绝事件
|
||||
func (p *InvoiceEventPublisher) PublishInvoiceApplicationRejected(ctx context.Context, event *events.InvoiceApplicationRejectedEvent) error {
|
||||
p.logger.Info("发布发票申请拒绝事件",
|
||||
zap.String("application_id", event.ApplicationID),
|
||||
zap.String("user_id", event.UserID),
|
||||
zap.String("reason", event.Reason),
|
||||
zap.String("receiving_email", event.ReceivingEmail),
|
||||
zap.Time("rejected_at", event.RejectedAt),
|
||||
)
|
||||
|
||||
// TODO: 实现实际的事件发布逻辑
|
||||
// 例如:发送邮件通知用户、记录拒绝原因等
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PublishInvoiceFileUploaded 发布发票文件上传事件
|
||||
func (p *InvoiceEventPublisher) PublishInvoiceFileUploaded(ctx context.Context, event *events.InvoiceFileUploadedEvent) error {
|
||||
p.logger.Info("📤 开始发布发票文件上传事件",
|
||||
zap.String("invoice_id", event.InvoiceID),
|
||||
zap.String("user_id", event.UserID),
|
||||
zap.String("file_id", event.FileID),
|
||||
zap.String("file_name", event.FileName),
|
||||
zap.String("file_url", event.FileURL),
|
||||
zap.String("receiving_email", event.ReceivingEmail),
|
||||
zap.Time("uploaded_at", event.UploadedAt),
|
||||
)
|
||||
|
||||
// 发布到事件总线
|
||||
if p.eventBus != nil {
|
||||
p.logger.Info("🚀 准备发布事件到事件总线",
|
||||
zap.String("event_type", event.GetType()),
|
||||
zap.String("event_id", event.GetID()),
|
||||
)
|
||||
|
||||
if err := p.eventBus.Publish(ctx, event); err != nil {
|
||||
p.logger.Error("❌ 发布发票文件上传事件到事件总线失败",
|
||||
zap.String("invoice_id", event.InvoiceID),
|
||||
zap.String("event_type", event.GetType()),
|
||||
zap.String("event_id", event.GetID()),
|
||||
zap.Error(err),
|
||||
)
|
||||
return err
|
||||
}
|
||||
p.logger.Info("✅ 发票文件上传事件已发布到事件总线",
|
||||
zap.String("invoice_id", event.InvoiceID),
|
||||
zap.String("event_type", event.GetType()),
|
||||
zap.String("event_id", event.GetID()),
|
||||
)
|
||||
} else {
|
||||
p.logger.Warn("⚠️ 事件总线未初始化,无法发布事件",
|
||||
zap.String("invoice_id", event.InvoiceID),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
123
internal/infrastructure/external/README.md
vendored
Normal file
123
internal/infrastructure/external/README.md
vendored
Normal file
@@ -0,0 +1,123 @@
|
||||
# 外部服务错误处理修复说明
|
||||
|
||||
## 问题描述
|
||||
|
||||
在外部服务(WestDex、Yushan、Zhicha)中,使用 `fmt.Errorf("%w: %s", ErrXXX, err)` 包装错误后,外层的 `errors.Is(err, ErrXXX)` 无法正确识别错误类型。
|
||||
|
||||
## 问题原因
|
||||
|
||||
`fmt.Errorf` 创建的包装错误虽然实现了 `Unwrap()` 接口,但没有实现 `Is()` 接口,因此 `errors.Is` 无法正确判断错误类型。
|
||||
|
||||
## 修复方案
|
||||
|
||||
统一使用 `errors.Join` 来组合错误,这是 Go 1.20+ 的标准做法,天然支持 `errors.Is` 判断。
|
||||
|
||||
## 修复内容
|
||||
|
||||
### 1. WestDex 服务 (`westdex_service.go`)
|
||||
|
||||
#### 修复前:
|
||||
```go
|
||||
// 无法被 errors.Is 识别的错误包装
|
||||
err = fmt.Errorf("%w: %s", ErrSystem, marshalErr.Error())
|
||||
err = fmt.Errorf("%w: %s", ErrDatasource, westDexResp.Message)
|
||||
```
|
||||
|
||||
#### 修复后:
|
||||
```go
|
||||
// 可以被 errors.Is 正确识别的错误组合
|
||||
err = errors.Join(ErrSystem, marshalErr)
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf(westDexResp.Message))
|
||||
```
|
||||
|
||||
### 2. Yushan 服务 (`yushan_service.go`)
|
||||
|
||||
#### 修复前:
|
||||
```go
|
||||
// 无法被 errors.Is 识别的错误包装
|
||||
err = fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
err = fmt.Errorf("%w: %s", ErrDatasource, "羽山请求retdata为空")
|
||||
```
|
||||
|
||||
#### 修复后:
|
||||
```go
|
||||
// 可以被 errors.Is 正确识别的错误组合
|
||||
err = errors.Join(ErrSystem, err)
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("羽山请求retdata为空"))
|
||||
```
|
||||
|
||||
### 3. Zhicha 服务 (`zhicha_service.go`)
|
||||
|
||||
#### 修复前:
|
||||
```go
|
||||
// 无法被 errors.Is 识别的错误包装
|
||||
err = fmt.Errorf("%w: %s", ErrSystem, marshalErr.Error())
|
||||
err = fmt.Errorf("%w: %s", ErrDatasource, "HTTP状态码 %d", response.StatusCode)
|
||||
```
|
||||
|
||||
#### 修复后:
|
||||
```go
|
||||
// 可以被 errors.Is 正确识别的错误组合
|
||||
err = errors.Join(ErrSystem, marshalErr)
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
|
||||
```
|
||||
|
||||
## 修复效果
|
||||
|
||||
### 修复前的问题:
|
||||
```go
|
||||
// 在应用服务层
|
||||
if errors.Is(err, westdex.ErrDatasource) {
|
||||
// 这里无法正确识别,因为 fmt.Errorf 包装的错误
|
||||
// 没有实现 Is() 接口
|
||||
return ErrDatasource
|
||||
}
|
||||
```
|
||||
|
||||
### 修复后的效果:
|
||||
```go
|
||||
// 在应用服务层
|
||||
if errors.Is(err, westdex.ErrDatasource) {
|
||||
// 现在可以正确识别了!
|
||||
return ErrDatasource
|
||||
}
|
||||
|
||||
if errors.Is(err, westdex.ErrSystem) {
|
||||
// 系统错误也能正确识别
|
||||
return ErrSystem
|
||||
}
|
||||
```
|
||||
|
||||
## 优势
|
||||
|
||||
1. **完全兼容**:`errors.Is` 现在可以正确识别所有错误类型
|
||||
2. **标准做法**:使用 Go 1.20+ 的 `errors.Join` 标准库功能
|
||||
3. **性能优秀**:标准库实现,性能优于自定义解决方案
|
||||
4. **维护简单**:无需自定义错误类型,代码更简洁
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **Go版本要求**:需要 Go 1.20 或更高版本(项目使用 Go 1.23.4,完全满足)
|
||||
2. **错误消息格式**:`errors.Join` 使用换行符分隔多个错误
|
||||
3. **向后兼容**:现有的错误处理代码无需修改
|
||||
|
||||
## 测试验证
|
||||
|
||||
所有修复后的外部服务都能正确编译:
|
||||
```bash
|
||||
go build ./internal/infrastructure/external/westdex/...
|
||||
go build ./internal/infrastructure/external/yushan/...
|
||||
go build ./internal/infrastructure/external/zhicha/...
|
||||
```
|
||||
|
||||
## 总结
|
||||
|
||||
通过统一使用 `errors.Join` 修复外部服务的错误处理,现在:
|
||||
|
||||
- ✅ `errors.Is(err, ErrDatasource)` 可以正确识别数据源异常
|
||||
- ✅ `errors.Is(err, ErrSystem)` 可以正确识别系统异常
|
||||
- ✅ `errors.Is(err, ErrNotFound)` 可以正确识别查询为空
|
||||
- ✅ 错误处理逻辑更加清晰和可靠
|
||||
- ✅ 符合 Go 1.20+ 的最佳实践
|
||||
|
||||
这个修复确保了整个系统的错误处理链路都能正确工作,提高了系统的可靠性和可维护性。
|
||||
194
internal/infrastructure/external/alicloud/README.md
vendored
Normal file
194
internal/infrastructure/external/alicloud/README.md
vendored
Normal file
@@ -0,0 +1,194 @@
|
||||
# 阿里云二要素验证服务
|
||||
|
||||
这个服务提供了调用阿里云身份证二要素验证API的功能,用于验证姓名和身份证号码是否匹配。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 身份证二要素验证(姓名 + 身份证号)
|
||||
- 支持详细验证结果返回
|
||||
- 支持简单布尔值判断
|
||||
- 错误处理和中文错误信息
|
||||
|
||||
## 配置说明
|
||||
|
||||
### 必需配置
|
||||
|
||||
- `Host`: 阿里云API的域名地址
|
||||
- `AppCode`: 阿里云市场应用的AppCode
|
||||
|
||||
### 配置示例
|
||||
|
||||
```go
|
||||
host := "https://kzidcardv1.market.alicloudapi.com"
|
||||
appCode := "您的AppCode"
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 1. 创建服务实例
|
||||
|
||||
```go
|
||||
service := NewAlicloudService(host, appCode)
|
||||
```
|
||||
|
||||
### 2. 调用API
|
||||
|
||||
#### 身份证二要素验证示例
|
||||
|
||||
```go
|
||||
// 构建请求参数
|
||||
params := map[string]interface{}{
|
||||
"name": "张三",
|
||||
"idcard": "110101199001011234",
|
||||
}
|
||||
|
||||
// 调用API
|
||||
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
|
||||
if err != nil {
|
||||
log.Printf("验证失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析完整响应结构
|
||||
var response struct {
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Birthday string `json:"birthday"`
|
||||
Result int `json:"result"`
|
||||
Address string `json:"address"`
|
||||
OrderNo string `json:"orderNo"`
|
||||
Sex string `json:"sex"`
|
||||
Desc string `json:"desc"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(responseBody, &response); err != nil {
|
||||
log.Printf("响应解析失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查响应状态
|
||||
if response.Code != 200 {
|
||||
log.Printf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
|
||||
return
|
||||
}
|
||||
|
||||
idCardData := response.Data
|
||||
|
||||
// 判断验证结果
|
||||
if idCardData.Result == 1 {
|
||||
fmt.Println("身份证信息验证通过")
|
||||
} else {
|
||||
fmt.Println("身份证信息验证失败")
|
||||
}
|
||||
```
|
||||
|
||||
#### 通用API调用
|
||||
|
||||
```go
|
||||
// 调用其他阿里云API
|
||||
params := map[string]interface{}{
|
||||
"param1": "value1",
|
||||
"param2": "value2",
|
||||
}
|
||||
|
||||
responseBody, err := service.CallAPI("your/api/path", params)
|
||||
if err != nil {
|
||||
log.Printf("API调用失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据具体API的响应结构进行解析
|
||||
// 每个API的响应结构可能不同,需要根据API文档定义相应的结构体
|
||||
var response struct {
|
||||
Msg string `json:"msg"`
|
||||
Code int `json:"code"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(responseBody, &response); err != nil {
|
||||
log.Printf("响应解析失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理响应数据
|
||||
fmt.Printf("响应数据: %s\n", string(responseBody))
|
||||
```
|
||||
|
||||
## 响应格式
|
||||
|
||||
### 通用响应结构
|
||||
|
||||
```json
|
||||
{
|
||||
"msg": "成功",
|
||||
"success": true,
|
||||
"code": 200,
|
||||
"data": {
|
||||
// 具体的业务数据
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 身份证验证响应示例
|
||||
|
||||
#### 成功响应 (code: 200)
|
||||
|
||||
```json
|
||||
{
|
||||
"msg": "成功",
|
||||
"success": true,
|
||||
"code": 200,
|
||||
"data": {
|
||||
"birthday": "19840816",
|
||||
"result": 1,
|
||||
"address": "浙江省杭州市淳安县",
|
||||
"orderNo": "202406271440416095174",
|
||||
"sex": "男",
|
||||
"desc": "不一致"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### 参数错误响应 (code: 400)
|
||||
|
||||
```json
|
||||
{
|
||||
"msg": "请输入有效的身份证号码",
|
||||
"code": 400,
|
||||
"data": null
|
||||
}
|
||||
```
|
||||
|
||||
### 错误响应
|
||||
|
||||
```json
|
||||
{
|
||||
"msg": "AppCode无效",
|
||||
"success": false,
|
||||
"code": 400
|
||||
}
|
||||
```
|
||||
|
||||
## 错误处理
|
||||
|
||||
服务定义了以下错误类型:
|
||||
|
||||
- `ErrDatasource`: 数据源异常
|
||||
- `ErrSystem`: 系统异常
|
||||
- `ErrInvalid`: 身份证信息不匹配
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. 请确保您的AppCode有效且有足够的调用额度
|
||||
2. 身份证号码必须是18位有效格式
|
||||
3. 姓名必须是真实有效的姓名
|
||||
4. 建议在生产环境中添加适当的重试机制和超时设置
|
||||
5. 请遵守阿里云API的使用规范和频率限制
|
||||
|
||||
## 依赖
|
||||
|
||||
- Go 1.16+
|
||||
- 标准库:`net/http`, `encoding/json`, `net/url`
|
||||
48
internal/infrastructure/external/alicloud/alicloud_factory.go
vendored
Normal file
48
internal/infrastructure/external/alicloud/alicloud_factory.go
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
package alicloud
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewAlicloudServiceWithConfig 使用配置创建阿里云服务,并启用外部服务调用日志
|
||||
func NewAlicloudServiceWithConfig(cfg *config.Config) (*AlicloudService, error) {
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: true,
|
||||
LogDir: "./logs/external_services",
|
||||
ServiceName: "alicloud",
|
||||
UseDaily: false,
|
||||
EnableLevelSeparation: true,
|
||||
LevelConfigs: map[string]external_logger.ExternalServiceLevelFileConfig{
|
||||
"info": {
|
||||
MaxSize: 100,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 28,
|
||||
Compress: true,
|
||||
},
|
||||
"error": {
|
||||
MaxSize: 100,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 28,
|
||||
Compress: true,
|
||||
},
|
||||
"warn": {
|
||||
MaxSize: 100,
|
||||
MaxBackups: 3,
|
||||
MaxAge: 28,
|
||||
Compress: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewAlicloudService(
|
||||
cfg.Alicloud.Host,
|
||||
cfg.Alicloud.AppCode,
|
||||
logger,
|
||||
), nil
|
||||
}
|
||||
142
internal/infrastructure/external/alicloud/alicloud_service.go
vendored
Normal file
142
internal/infrastructure/external/alicloud/alicloud_service.go
vendored
Normal file
@@ -0,0 +1,142 @@
|
||||
package alicloud
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
)
|
||||
|
||||
// AlicloudConfig 阿里云配置
|
||||
type AlicloudConfig struct {
|
||||
Host string
|
||||
AppCode string
|
||||
}
|
||||
|
||||
// AlicloudService 阿里云服务
|
||||
type AlicloudService struct {
|
||||
config AlicloudConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewAlicloudService 创建阿里云服务实例
|
||||
func NewAlicloudService(host, appCode string, logger ...*external_logger.ExternalServiceLogger) *AlicloudService {
|
||||
var serviceLogger *external_logger.ExternalServiceLogger
|
||||
if len(logger) > 0 {
|
||||
serviceLogger = logger[0]
|
||||
}
|
||||
return &AlicloudService{
|
||||
config: AlicloudConfig{
|
||||
Host: host,
|
||||
AppCode: appCode,
|
||||
},
|
||||
logger: serviceLogger,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (a *AlicloudService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, a.config.Host)))
|
||||
return fmt.Sprintf("alicloud_%x", hash[:8])
|
||||
}
|
||||
|
||||
// CallAPI 调用阿里云API的通用方法
|
||||
// path: API路径(如 "api-mall/api/id_card/check")
|
||||
// params: 请求参数
|
||||
func (a *AlicloudService) CallAPI(path string, params map[string]interface{}) (respBytes []byte, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := a.generateRequestID()
|
||||
transactionID := ""
|
||||
|
||||
// 构建请求URL
|
||||
reqURL := a.config.Host + "/" + path
|
||||
|
||||
// 记录请求日志
|
||||
if a.logger != nil {
|
||||
a.logger.LogRequest(requestID, transactionID, path, reqURL)
|
||||
}
|
||||
|
||||
// 构建请求参数
|
||||
formData := url.Values{}
|
||||
for key, value := range params {
|
||||
formData.Set(key, fmt.Sprintf("%v", value))
|
||||
}
|
||||
|
||||
// 创建HTTP请求
|
||||
req, err := http.NewRequest("POST", reqURL, strings.NewReader(formData.Encode()))
|
||||
if err != nil {
|
||||
if a.logger != nil {
|
||||
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrSystem, err), params)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8")
|
||||
req.Header.Set("Authorization", "APPCODE "+a.config.AppCode)
|
||||
|
||||
// 发送请求,超时时间设置为60秒
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
if a.logger != nil {
|
||||
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %s", err.Error())), params)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: API请求超时: %s", ErrDatasource, err.Error())
|
||||
}
|
||||
if a.logger != nil {
|
||||
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrSystem, err), params)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
if a.logger != nil {
|
||||
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrSystem, err), params)
|
||||
}
|
||||
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
|
||||
}
|
||||
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
if a.logger != nil {
|
||||
duration := time.Since(startTime)
|
||||
a.logger.LogResponse(requestID, transactionID, path, resp.StatusCode, duration)
|
||||
}
|
||||
|
||||
// 直接返回原始响应body,让调用方自己处理
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// GetConfig 获取配置信息
|
||||
func (a *AlicloudService) GetConfig() AlicloudConfig {
|
||||
return a.config
|
||||
}
|
||||
143
internal/infrastructure/external/alicloud/alicloud_service_test.go
vendored
Normal file
143
internal/infrastructure/external/alicloud/alicloud_service_test.go
vendored
Normal file
@@ -0,0 +1,143 @@
|
||||
package alicloud
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRealAlicloudAPI(t *testing.T) {
|
||||
// 使用真实的阿里云API配置
|
||||
host := "https://kzidcardv1.market.alicloudapi.com"
|
||||
appCode := "d55b58829efb41c8aa8e86769cba4844"
|
||||
|
||||
service := NewAlicloudService(host, appCode)
|
||||
|
||||
// 测试真实的身份证验证
|
||||
name := "张荣宏"
|
||||
idCard := "45212220000827423X"
|
||||
|
||||
fmt.Printf("开始测试阿里云二要素验证API...\n")
|
||||
fmt.Printf("姓名: %s\n", name)
|
||||
fmt.Printf("身份证: %s\n", idCard)
|
||||
|
||||
// 构建请求参数
|
||||
params := map[string]interface{}{
|
||||
"name": name,
|
||||
"idcard": idCard,
|
||||
}
|
||||
|
||||
// 调用真实API
|
||||
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
|
||||
if err != nil {
|
||||
t.Logf("API调用失败: %v", err)
|
||||
fmt.Printf("错误详情: %v\n", err)
|
||||
t.Fail()
|
||||
return
|
||||
}
|
||||
|
||||
// 打印原始响应数据
|
||||
fmt.Printf("API响应成功!\n")
|
||||
fmt.Printf("原始响应数据: %s\n", string(responseBody))
|
||||
|
||||
// 解析完整响应结构
|
||||
var response struct {
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Birthday string `json:"birthday"`
|
||||
Result int `json:"result"`
|
||||
Address string `json:"address"`
|
||||
OrderNo string `json:"orderNo"`
|
||||
Sex string `json:"sex"`
|
||||
Desc string `json:"desc"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(responseBody, &response); err != nil {
|
||||
t.Logf("响应数据解析失败: %v", err)
|
||||
t.Fail()
|
||||
return
|
||||
}
|
||||
|
||||
// 检查响应状态
|
||||
if response.Code != 200 {
|
||||
t.Logf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
|
||||
t.Fail()
|
||||
return
|
||||
}
|
||||
|
||||
idCardData := response.Data
|
||||
|
||||
// 打印详细响应结果
|
||||
fmt.Printf("验证结果: %d\n", idCardData.Result)
|
||||
fmt.Printf("描述: %s\n", idCardData.Desc)
|
||||
fmt.Printf("生日: %s\n", idCardData.Birthday)
|
||||
fmt.Printf("性别: %s\n", idCardData.Sex)
|
||||
fmt.Printf("地址: %s\n", idCardData.Address)
|
||||
fmt.Printf("订单号: %s\n", idCardData.OrderNo)
|
||||
|
||||
// 将完整响应转换为JSON并打印
|
||||
jsonResponse, _ := json.MarshalIndent(idCardData, "", " ")
|
||||
fmt.Printf("完整响应JSON:\n%s\n", string(jsonResponse))
|
||||
|
||||
// 判断验证结果
|
||||
if idCardData.Result == 1 {
|
||||
fmt.Printf("验证结果: 通过\n")
|
||||
} else {
|
||||
fmt.Printf("验证结果: 失败\n")
|
||||
}
|
||||
}
|
||||
|
||||
// TestAlicloudAPIError 测试错误响应
|
||||
func TestAlicloudAPIError(t *testing.T) {
|
||||
// 使用真实的阿里云API配置
|
||||
host := "https://kzidcardv1.market.alicloudapi.com"
|
||||
appCode := "d55b58829efb41c8aa8e86769cba4844"
|
||||
|
||||
service := NewAlicloudService(host, appCode)
|
||||
|
||||
// 测试无效的身份证号码
|
||||
name := "张三"
|
||||
invalidIdCard := "123456789"
|
||||
|
||||
fmt.Printf("测试错误响应 - 无效身份证号\n")
|
||||
fmt.Printf("姓名: %s\n", name)
|
||||
fmt.Printf("身份证: %s\n", invalidIdCard)
|
||||
|
||||
// 构建请求参数
|
||||
params := map[string]interface{}{
|
||||
"name": name,
|
||||
"idcard": invalidIdCard,
|
||||
}
|
||||
|
||||
// 调用真实API
|
||||
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
|
||||
if err != nil {
|
||||
fmt.Printf("网络请求错误: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var response struct {
|
||||
Msg string `json:"msg"`
|
||||
Code int `json:"code"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(responseBody, &response); err != nil {
|
||||
fmt.Printf("响应解析失败: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为错误响应
|
||||
if response.Code != 200 {
|
||||
fmt.Printf("预期的错误响应: code=%d, msg=%s\n", response.Code, response.Msg)
|
||||
fmt.Printf("错误处理正确: API返回错误状态\n")
|
||||
} else {
|
||||
t.Error("期望返回错误,但实际成功")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
76
internal/infrastructure/external/alicloud/example.go
vendored
Normal file
76
internal/infrastructure/external/alicloud/example.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package alicloud
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// ExampleUsage 使用示例
|
||||
func ExampleUsage() {
|
||||
// 创建阿里云服务实例
|
||||
// 请替换为您的实际配置
|
||||
host := "https://kzidcardv1.market.alicloudapi.com"
|
||||
appCode := "您的AppCode"
|
||||
|
||||
service := NewAlicloudService(host, appCode)
|
||||
|
||||
// 示例:验证身份证信息
|
||||
name := "张三"
|
||||
idCard := "110101199001011234"
|
||||
|
||||
// 构建请求参数
|
||||
params := map[string]interface{}{
|
||||
"name": name,
|
||||
"idcard": idCard,
|
||||
}
|
||||
|
||||
// 调用API
|
||||
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
|
||||
if err != nil {
|
||||
log.Printf("验证失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析完整响应结构
|
||||
var response struct {
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Birthday string `json:"birthday"`
|
||||
Result int `json:"result"`
|
||||
Address string `json:"address"`
|
||||
OrderNo string `json:"orderNo"`
|
||||
Sex string `json:"sex"`
|
||||
Desc string `json:"desc"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(responseBody, &response); err != nil {
|
||||
log.Printf("响应解析失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查响应状态
|
||||
if response.Code != 200 {
|
||||
log.Printf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
|
||||
return
|
||||
}
|
||||
|
||||
idCardData := response.Data
|
||||
|
||||
fmt.Printf("验证结果: %d\n", idCardData.Result)
|
||||
fmt.Printf("描述: %s\n", idCardData.Desc)
|
||||
fmt.Printf("生日: %s\n", idCardData.Birthday)
|
||||
fmt.Printf("性别: %s\n", idCardData.Sex)
|
||||
fmt.Printf("地址: %s\n", idCardData.Address)
|
||||
fmt.Printf("订单号: %s\n", idCardData.OrderNo)
|
||||
|
||||
// 判断验证结果
|
||||
if idCardData.Result == 1 {
|
||||
fmt.Println("身份证信息验证通过")
|
||||
} else {
|
||||
fmt.Println("身份证信息验证失败")
|
||||
}
|
||||
}
|
||||
162
internal/infrastructure/external/alicloud/example_advanced.go
vendored
Normal file
162
internal/infrastructure/external/alicloud/example_advanced.go
vendored
Normal file
@@ -0,0 +1,162 @@
|
||||
package alicloud
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
)
|
||||
|
||||
// ExampleAdvancedUsage 高级使用示例
|
||||
func ExampleAdvancedUsage() {
|
||||
// 创建阿里云服务实例
|
||||
host := "https://kzidcardv1.market.alicloudapi.com"
|
||||
appCode := "您的AppCode"
|
||||
|
||||
service := NewAlicloudService(host, appCode)
|
||||
|
||||
// 示例1: 身份证二要素验证
|
||||
fmt.Println("=== 示例1: 身份证二要素验证 ===")
|
||||
exampleIdCardCheck(service)
|
||||
|
||||
// 示例2: 其他API调用(假设)
|
||||
fmt.Println("\n=== 示例2: 其他API调用 ===")
|
||||
exampleOtherAPI(service)
|
||||
}
|
||||
|
||||
// exampleIdCardCheck 身份证验证示例
|
||||
func exampleIdCardCheck(service *AlicloudService) {
|
||||
// 构建请求参数
|
||||
params := map[string]interface{}{
|
||||
"name": "张三",
|
||||
"idcard": "110101199001011234",
|
||||
}
|
||||
|
||||
// 调用API
|
||||
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
|
||||
if err != nil {
|
||||
log.Printf("身份证验证失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析完整响应结构
|
||||
var response struct {
|
||||
Msg string `json:"msg"`
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Data struct {
|
||||
Birthday string `json:"birthday"`
|
||||
Result int `json:"result"`
|
||||
Address string `json:"address"`
|
||||
OrderNo string `json:"orderNo"`
|
||||
Sex string `json:"sex"`
|
||||
Desc string `json:"desc"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(responseBody, &response); err != nil {
|
||||
log.Printf("响应解析失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查响应状态
|
||||
if response.Code != 200 {
|
||||
log.Printf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
|
||||
return
|
||||
}
|
||||
|
||||
idCardData := response.Data
|
||||
|
||||
// 处理验证结果
|
||||
fmt.Printf("验证结果: %d (%s)\n", idCardData.Result, idCardData.Desc)
|
||||
fmt.Printf("生日: %s\n", idCardData.Birthday)
|
||||
fmt.Printf("性别: %s\n", idCardData.Sex)
|
||||
fmt.Printf("地址: %s\n", idCardData.Address)
|
||||
fmt.Printf("订单号: %s\n", idCardData.OrderNo)
|
||||
|
||||
if idCardData.Result == 1 {
|
||||
fmt.Println("✅ 身份证信息验证通过")
|
||||
} else {
|
||||
fmt.Println("❌ 身份证信息验证失败")
|
||||
}
|
||||
}
|
||||
|
||||
// exampleOtherAPI 其他API调用示例
|
||||
func exampleOtherAPI(service *AlicloudService) {
|
||||
// 假设调用其他API
|
||||
params := map[string]interface{}{
|
||||
"param1": "value1",
|
||||
"param2": "value2",
|
||||
}
|
||||
|
||||
// 调用API
|
||||
responseBody, err := service.CallAPI("other/api/path", params)
|
||||
if err != nil {
|
||||
log.Printf("API调用失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 根据具体API的响应结构进行解析
|
||||
// 这里只是示例,实际使用时需要根据API文档定义相应的结构体
|
||||
fmt.Printf("API响应数据: %s\n", string(responseBody))
|
||||
|
||||
// 示例:解析通用响应结构
|
||||
var genericData map[string]interface{}
|
||||
if err := json.Unmarshal(responseBody, &genericData); err != nil {
|
||||
log.Printf("响应解析失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("解析后的数据: %+v\n", genericData)
|
||||
}
|
||||
|
||||
// ExampleErrorHandling 错误处理示例
|
||||
func ExampleErrorHandling() {
|
||||
host := "https://kzidcardv1.market.alicloudapi.com"
|
||||
appCode := "您的AppCode"
|
||||
|
||||
service := NewAlicloudService(host, appCode)
|
||||
|
||||
// 测试各种错误情况
|
||||
testCases := []struct {
|
||||
name string
|
||||
idCard string
|
||||
desc string
|
||||
}{
|
||||
{"张三", "123456789", "无效身份证号"},
|
||||
{"", "110101199001011234", "空姓名"},
|
||||
{"张三", "", "空身份证号"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
fmt.Printf("\n测试: %s\n", tc.desc)
|
||||
|
||||
params := map[string]interface{}{
|
||||
"name": tc.name,
|
||||
"idcard": tc.idCard,
|
||||
}
|
||||
|
||||
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ 网络请求错误: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var response struct {
|
||||
Msg string `json:"msg"`
|
||||
Code int `json:"code"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(responseBody, &response); err != nil {
|
||||
fmt.Printf("❌ 响应解析失败: %v\n", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if response.Code != 200 {
|
||||
fmt.Printf("❌ 预期错误: code=%d, msg=%s\n", response.Code, response.Msg)
|
||||
} else {
|
||||
fmt.Printf("⚠️ 意外成功\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
134
internal/infrastructure/external/captcha/captcha_service.go
vendored
Normal file
134
internal/infrastructure/external/captcha/captcha_service.go
vendored
Normal file
@@ -0,0 +1,134 @@
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/alibabacloud-go/tea/tea"
|
||||
captcha20230305 "github.com/alibabacloud-go/captcha-20230305/client"
|
||||
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrCaptchaVerifyFailed = errors.New("图形验证码校验失败")
|
||||
ErrCaptchaConfig = errors.New("验证码配置错误")
|
||||
ErrCaptchaEncryptMissing = errors.New("加密模式需要配置 EncryptKey(控制台 ekey)")
|
||||
)
|
||||
|
||||
// CaptchaConfig 阿里云验证码配置
|
||||
type CaptchaConfig struct {
|
||||
AccessKeyID string
|
||||
AccessKeySecret string
|
||||
EndpointURL string
|
||||
SceneID string
|
||||
// EncryptKey 加密模式使用的密钥(控制台 ekey,Base64 编码的 32 字节),用于生成 EncryptedSceneId
|
||||
EncryptKey string
|
||||
}
|
||||
|
||||
// CaptchaService 阿里云验证码服务
|
||||
type CaptchaService struct {
|
||||
config CaptchaConfig
|
||||
}
|
||||
|
||||
// NewCaptchaService 创建验证码服务实例
|
||||
func NewCaptchaService(config CaptchaConfig) *CaptchaService {
|
||||
return &CaptchaService{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// Verify 验证滑块验证码
|
||||
func (s *CaptchaService) Verify(captchaVerifyParam string) error {
|
||||
if captchaVerifyParam == "" {
|
||||
return ErrCaptchaVerifyFailed
|
||||
}
|
||||
|
||||
if s.config.AccessKeyID == "" || s.config.AccessKeySecret == "" {
|
||||
return ErrCaptchaConfig
|
||||
}
|
||||
|
||||
clientCfg := &openapi.Config{
|
||||
AccessKeyId: tea.String(s.config.AccessKeyID),
|
||||
AccessKeySecret: tea.String(s.config.AccessKeySecret),
|
||||
}
|
||||
clientCfg.Endpoint = tea.String(s.config.EndpointURL)
|
||||
|
||||
client, err := captcha20230305.NewClient(clientCfg)
|
||||
if err != nil {
|
||||
return errors.Join(ErrCaptchaConfig, err)
|
||||
}
|
||||
|
||||
req := &captcha20230305.VerifyIntelligentCaptchaRequest{
|
||||
SceneId: tea.String(s.config.SceneID),
|
||||
CaptchaVerifyParam: tea.String(captchaVerifyParam),
|
||||
}
|
||||
|
||||
resp, err := client.VerifyIntelligentCaptcha(req)
|
||||
if err != nil {
|
||||
return errors.Join(ErrCaptchaVerifyFailed, err)
|
||||
}
|
||||
|
||||
if resp.Body == nil || !tea.BoolValue(resp.Body.Result.VerifyResult) {
|
||||
return ErrCaptchaVerifyFailed
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetEncryptedSceneId 生成加密场景 ID(EncryptedSceneId),供前端加密模式初始化验证码使用。
|
||||
// 算法:AES-256-CBC,明文 sceneId×tamp&expireTime,密钥为控制台 ekey(Base64 解码后 32 字节)。
|
||||
// expireTimeSec 有效期为 1~86400 秒。
|
||||
func (s *CaptchaService) GetEncryptedSceneId(expireTimeSec int) (string, error) {
|
||||
if expireTimeSec <= 0 || expireTimeSec > 86400 {
|
||||
return "", fmt.Errorf("expireTimeSec 必须在 1~86400 之间")
|
||||
}
|
||||
if s.config.EncryptKey == "" {
|
||||
return "", ErrCaptchaEncryptMissing
|
||||
}
|
||||
if s.config.SceneID == "" {
|
||||
return "", ErrCaptchaConfig
|
||||
}
|
||||
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(s.config.EncryptKey)
|
||||
if err != nil || len(keyBytes) != 32 {
|
||||
return "", errors.Join(ErrCaptchaConfig, fmt.Errorf("EncryptKey 必须为 Base64 编码的 32 字节"))
|
||||
}
|
||||
|
||||
plaintext := fmt.Sprintf("%s&%d&%d", s.config.SceneID, time.Now().Unix(), expireTimeSec)
|
||||
plainBytes := []byte(plaintext)
|
||||
plainBytes = pkcs7Pad(plainBytes, aes.BlockSize)
|
||||
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return "", errors.Join(ErrCaptchaConfig, err)
|
||||
}
|
||||
|
||||
iv := make([]byte, aes.BlockSize)
|
||||
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
ciphertext := make([]byte, len(plainBytes))
|
||||
mode.CryptBlocks(ciphertext, plainBytes)
|
||||
|
||||
result := make([]byte, len(iv)+len(ciphertext))
|
||||
copy(result, iv)
|
||||
copy(result[len(iv):], ciphertext)
|
||||
return base64.StdEncoding.EncodeToString(result), nil
|
||||
}
|
||||
|
||||
func pkcs7Pad(data []byte, blockSize int) []byte {
|
||||
n := blockSize - (len(data) % blockSize)
|
||||
pad := make([]byte, n)
|
||||
for i := range pad {
|
||||
pad[i] = byte(n)
|
||||
}
|
||||
return append(data, pad...)
|
||||
}
|
||||
712
internal/infrastructure/external/email/qq_email_service.go
vendored
Normal file
712
internal/infrastructure/external/email/qq_email_service.go
vendored
Normal file
@@ -0,0 +1,712 @@
|
||||
package email
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
)
|
||||
|
||||
// QQEmailService QQ邮箱服务
|
||||
type QQEmailService struct {
|
||||
config config.EmailConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// EmailData 邮件数据
|
||||
type EmailData struct {
|
||||
To string `json:"to"`
|
||||
Subject string `json:"subject"`
|
||||
Content string `json:"content"`
|
||||
Data map[string]interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// InvoiceEmailData 发票邮件数据
|
||||
type InvoiceEmailData struct {
|
||||
CompanyName string `json:"company_name"`
|
||||
Amount string `json:"amount"`
|
||||
InvoiceType string `json:"invoice_type"`
|
||||
FileURL string `json:"file_url"`
|
||||
FileName string `json:"file_name"`
|
||||
ReceivingEmail string `json:"receiving_email"`
|
||||
ApprovedAt string `json:"approved_at"`
|
||||
}
|
||||
|
||||
// NewQQEmailService 创建QQ邮箱服务
|
||||
func NewQQEmailService(config config.EmailConfig, logger *zap.Logger) *QQEmailService {
|
||||
return &QQEmailService{
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendEmail 发送邮件
|
||||
func (s *QQEmailService) SendEmail(ctx context.Context, data *EmailData) error {
|
||||
s.logger.Info("开始发送邮件",
|
||||
zap.String("to", data.To),
|
||||
zap.String("subject", data.Subject),
|
||||
)
|
||||
|
||||
// 构建邮件内容
|
||||
message := s.buildEmailMessage(data)
|
||||
|
||||
// 发送邮件
|
||||
err := s.sendSMTP(data.To, data.Subject, message)
|
||||
if err != nil {
|
||||
s.logger.Error("发送邮件失败",
|
||||
zap.String("to", data.To),
|
||||
zap.String("subject", data.Subject),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("发送邮件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("邮件发送成功",
|
||||
zap.String("to", data.To),
|
||||
zap.String("subject", data.Subject),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendInvoiceEmail 发送发票邮件
|
||||
func (s *QQEmailService) SendInvoiceEmail(ctx context.Context, data *InvoiceEmailData) error {
|
||||
s.logger.Info("开始发送发票邮件",
|
||||
zap.String("to", data.ReceivingEmail),
|
||||
zap.String("company_name", data.CompanyName),
|
||||
zap.String("amount", data.Amount),
|
||||
)
|
||||
|
||||
// 构建邮件内容
|
||||
subject := "您的发票已开具成功"
|
||||
content := s.buildInvoiceEmailContent(data)
|
||||
|
||||
emailData := &EmailData{
|
||||
To: data.ReceivingEmail,
|
||||
Subject: subject,
|
||||
Content: content,
|
||||
Data: map[string]interface{}{
|
||||
"company_name": data.CompanyName,
|
||||
"amount": data.Amount,
|
||||
"invoice_type": data.InvoiceType,
|
||||
"file_url": data.FileURL,
|
||||
"file_name": data.FileName,
|
||||
"approved_at": data.ApprovedAt,
|
||||
},
|
||||
}
|
||||
|
||||
return s.SendEmail(ctx, emailData)
|
||||
}
|
||||
|
||||
// buildEmailMessage 构建邮件消息
|
||||
func (s *QQEmailService) buildEmailMessage(data *EmailData) string {
|
||||
headers := make(map[string]string)
|
||||
headers["From"] = s.config.FromEmail
|
||||
headers["To"] = data.To
|
||||
headers["Subject"] = data.Subject
|
||||
headers["MIME-Version"] = "1.0"
|
||||
headers["Content-Type"] = "text/html; charset=UTF-8"
|
||||
|
||||
var message strings.Builder
|
||||
for key, value := range headers {
|
||||
message.WriteString(fmt.Sprintf("%s: %s\r\n", key, value))
|
||||
}
|
||||
message.WriteString("\r\n")
|
||||
message.WriteString(data.Content)
|
||||
|
||||
return message.String()
|
||||
}
|
||||
|
||||
// buildInvoiceEmailContent 构建发票邮件内容
|
||||
func (s *QQEmailService) buildInvoiceEmailContent(data *InvoiceEmailData) string {
|
||||
htmlTemplate := `
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>发票开具成功通知</title>
|
||||
<style>
|
||||
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
|
||||
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Inter', 'Microsoft YaHei', Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: #2d3748;
|
||||
background: linear-gradient(135deg, #f7fafc 0%, #edf2f7 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 650px;
|
||||
margin: 0 auto;
|
||||
background: #ffffff;
|
||||
border-radius: 24px;
|
||||
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.08);
|
||||
overflow: hidden;
|
||||
border: 1px solid #e2e8f0;
|
||||
}
|
||||
|
||||
.header {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
padding: 50px 40px 40px;
|
||||
text-align: center;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.header::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: -50%;
|
||||
left: -50%;
|
||||
width: 200%;
|
||||
height: 200%;
|
||||
background: radial-gradient(circle, rgba(255,255,255,0.08) 0%, transparent 70%);
|
||||
animation: float 6s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes float {
|
||||
0%, 100% { transform: translateY(0px) rotate(0deg); }
|
||||
50% { transform: translateY(-20px) rotate(180deg); }
|
||||
}
|
||||
|
||||
.success-icon {
|
||||
font-size: 48px;
|
||||
margin-bottom: 16px;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.header h1 {
|
||||
font-size: 24px;
|
||||
font-weight: 500;
|
||||
margin: 0;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.content {
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.greeting {
|
||||
padding: 40px 40px 20px;
|
||||
text-align: center;
|
||||
background: linear-gradient(135deg, #f8fafc 0%, #ffffff 100%);
|
||||
}
|
||||
|
||||
.greeting p {
|
||||
font-size: 16px;
|
||||
color: #4a5568;
|
||||
margin-bottom: 8px;
|
||||
font-weight: 400;
|
||||
}
|
||||
|
||||
.access-section {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
padding: 40px;
|
||||
text-align: center;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
margin: 0 20px 30px;
|
||||
border-radius: 20px;
|
||||
}
|
||||
|
||||
.access-section::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: -50%;
|
||||
left: -50%;
|
||||
width: 200%;
|
||||
height: 200%;
|
||||
background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%);
|
||||
animation: shimmer 8s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes shimmer {
|
||||
0%, 100% { transform: translateX(-100%) translateY(-100%) rotate(0deg); }
|
||||
50% { transform: translateX(100%) translateY(100%) rotate(180deg); }
|
||||
}
|
||||
|
||||
.access-section h3 {
|
||||
color: white;
|
||||
font-size: 22px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 12px;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.access-section p {
|
||||
color: rgba(255, 255, 255, 0.9);
|
||||
margin-bottom: 25px;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
font-size: 15px;
|
||||
}
|
||||
|
||||
.access-btn {
|
||||
display: inline-block;
|
||||
background: rgba(255, 255, 255, 0.15);
|
||||
color: white;
|
||||
padding: 16px 32px;
|
||||
text-decoration: none;
|
||||
border-radius: 50px;
|
||||
font-weight: 600;
|
||||
font-size: 15px;
|
||||
border: 2px solid rgba(255, 255, 255, 0.2);
|
||||
transition: all 0.3s ease;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
backdrop-filter: blur(10px);
|
||||
letter-spacing: 0.3px;
|
||||
}
|
||||
|
||||
.access-btn:hover {
|
||||
background: rgba(255, 255, 255, 0.25);
|
||||
transform: translateY(-3px);
|
||||
box-shadow: 0 15px 35px rgba(0, 0, 0, 0.15);
|
||||
border-color: rgba(255, 255, 255, 0.3);
|
||||
}
|
||||
|
||||
.info-section {
|
||||
padding: 0 40px 40px;
|
||||
}
|
||||
|
||||
.info-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
|
||||
gap: 20px;
|
||||
margin: 30px 0;
|
||||
}
|
||||
|
||||
.info-item {
|
||||
background: linear-gradient(135deg, #f8fafc 0%, #ffffff 100%);
|
||||
padding: 24px;
|
||||
border-radius: 16px;
|
||||
border: 1px solid #e2e8f0;
|
||||
position: relative;
|
||||
overflow: hidden;
|
||||
transition: all 0.3s ease;
|
||||
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.info-item:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 12px 25px -5px rgba(102, 126, 234, 0.1);
|
||||
border-color: #cbd5e0;
|
||||
}
|
||||
|
||||
.info-item::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 4px;
|
||||
height: 100%;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border-radius: 0 2px 2px 0;
|
||||
}
|
||||
|
||||
.info-label {
|
||||
font-weight: 600;
|
||||
color: #718096;
|
||||
display: block;
|
||||
margin-bottom: 8px;
|
||||
font-size: 13px;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.8px;
|
||||
}
|
||||
|
||||
.info-value {
|
||||
color: #2d3748;
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.notes-section {
|
||||
background: linear-gradient(135deg, #f0fff4 0%, #ffffff 100%);
|
||||
padding: 30px;
|
||||
border-radius: 16px;
|
||||
margin: 30px 0;
|
||||
border: 1px solid #c6f6d5;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.notes-section::before {
|
||||
content: '';
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 4px;
|
||||
height: 100%;
|
||||
background: linear-gradient(135deg, #48bb78 0%, #38a169 100%);
|
||||
border-radius: 0 2px 2px 0;
|
||||
}
|
||||
|
||||
.notes-section h4 {
|
||||
color: #2f855a;
|
||||
font-size: 16px;
|
||||
font-weight: 600;
|
||||
margin-bottom: 16px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
.notes-section h4::before {
|
||||
content: '📋';
|
||||
margin-right: 8px;
|
||||
font-size: 18px;
|
||||
}
|
||||
|
||||
.notes-section ul {
|
||||
list-style: none;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.notes-section li {
|
||||
color: #4a5568;
|
||||
margin-bottom: 10px;
|
||||
padding-left: 24px;
|
||||
position: relative;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.notes-section li::before {
|
||||
content: '✓';
|
||||
color: #48bb78;
|
||||
font-weight: bold;
|
||||
position: absolute;
|
||||
left: 0;
|
||||
font-size: 16px;
|
||||
}
|
||||
|
||||
.footer {
|
||||
background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%);
|
||||
color: rgba(255, 255, 255, 0.8);
|
||||
padding: 35px 40px;
|
||||
text-align: center;
|
||||
font-size: 14px;
|
||||
}
|
||||
|
||||
.footer p {
|
||||
margin-bottom: 8px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.footer-divider {
|
||||
width: 60px;
|
||||
height: 2px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
margin: 20px auto;
|
||||
border-radius: 1px;
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.container {
|
||||
margin: 10px;
|
||||
border-radius: 20px;
|
||||
}
|
||||
|
||||
.header {
|
||||
padding: 40px 30px 30px;
|
||||
}
|
||||
|
||||
.greeting {
|
||||
padding: 30px 30px 20px;
|
||||
}
|
||||
|
||||
.access-section {
|
||||
margin: 0 15px 25px;
|
||||
padding: 30px 25px;
|
||||
}
|
||||
|
||||
.info-section {
|
||||
padding: 0 30px 30px;
|
||||
}
|
||||
|
||||
.info-grid {
|
||||
grid-template-columns: 1fr;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
.footer {
|
||||
padding: 30px 30px;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<div class="success-icon">✓</div>
|
||||
<h1>发票已开具完成</h1>
|
||||
</div>
|
||||
|
||||
<div class="content">
|
||||
<div class="greeting">
|
||||
<p>尊敬的用户,您好!</p>
|
||||
<p>您的发票申请已审核通过,发票已成功开具。</p>
|
||||
</div>
|
||||
|
||||
<div class="access-section">
|
||||
<h3>📄 发票访问链接</h3>
|
||||
<p>您的发票已准备就绪,请点击下方按钮访问查看页面</p>
|
||||
<a href="{{.FileURL}}" class="access-btn" target="_blank">
|
||||
🔗 访问发票页面
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<div class="info-section">
|
||||
<div class="info-grid">
|
||||
<div class="info-item">
|
||||
<span class="info-label">公司名称</span>
|
||||
<span class="info-value">{{.CompanyName}}</span>
|
||||
</div>
|
||||
|
||||
<div class="info-item">
|
||||
<span class="info-label">发票金额</span>
|
||||
<span class="info-value">¥{{.Amount}}</span>
|
||||
</div>
|
||||
|
||||
<div class="info-item">
|
||||
<span class="info-label">发票类型</span>
|
||||
<span class="info-value">{{.InvoiceType}}</span>
|
||||
</div>
|
||||
|
||||
<div class="info-item">
|
||||
<span class="info-label">开具时间</span>
|
||||
<span class="info-value">{{.ApprovedAt}}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="notes-section">
|
||||
<h4>注意事项</h4>
|
||||
<ul>
|
||||
<li>访问页面后可在页面内下载发票文件</li>
|
||||
<li>请妥善保管发票文件,建议打印存档</li>
|
||||
<li>如有疑问,请回到我们平台进行下载</li>
|
||||
</ul>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="footer">
|
||||
<p>此邮件由系统自动发送,请勿回复</p>
|
||||
<div class="footer-divider"></div>
|
||||
<p>海宇数据 API 服务平台</p>
|
||||
<p>发送时间:{{.CurrentTime}}</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>`
|
||||
|
||||
// 解析模板
|
||||
tmpl, err := template.New("invoice_email").Parse(htmlTemplate)
|
||||
if err != nil {
|
||||
s.logger.Error("解析邮件模板失败", zap.Error(err))
|
||||
return s.buildSimpleInvoiceEmail(data)
|
||||
}
|
||||
|
||||
// 准备模板数据
|
||||
templateData := struct {
|
||||
CompanyName string
|
||||
Amount string
|
||||
InvoiceType string
|
||||
FileURL string
|
||||
FileName string
|
||||
ApprovedAt string
|
||||
CurrentTime string
|
||||
Domain string
|
||||
}{
|
||||
CompanyName: data.CompanyName,
|
||||
Amount: data.Amount,
|
||||
InvoiceType: data.InvoiceType,
|
||||
FileURL: data.FileURL,
|
||||
FileName: data.FileName,
|
||||
ApprovedAt: data.ApprovedAt,
|
||||
CurrentTime: time.Now().Format("2006-01-02 15:04:05"),
|
||||
Domain: s.config.Domain,
|
||||
}
|
||||
|
||||
// 执行模板
|
||||
var content strings.Builder
|
||||
err = tmpl.Execute(&content, templateData)
|
||||
if err != nil {
|
||||
s.logger.Error("执行邮件模板失败", zap.Error(err))
|
||||
return s.buildSimpleInvoiceEmail(data)
|
||||
}
|
||||
|
||||
return content.String()
|
||||
}
|
||||
|
||||
// buildSimpleInvoiceEmail 构建简单的发票邮件内容(备用方案)
|
||||
func (s *QQEmailService) buildSimpleInvoiceEmail(data *InvoiceEmailData) string {
|
||||
return fmt.Sprintf(`
|
||||
发票开具成功通知
|
||||
|
||||
尊敬的用户,您好!
|
||||
|
||||
您的发票申请已审核通过,发票已成功开具。
|
||||
|
||||
发票信息:
|
||||
- 公司名称:%s
|
||||
- 发票金额:¥%s
|
||||
- 发票类型:%s
|
||||
- 开具时间:%s
|
||||
|
||||
发票文件下载链接:%s
|
||||
文件名:%s
|
||||
|
||||
如有疑问,请访问控制台查看详细信息:https://%s
|
||||
|
||||
海宇数据 API 服务平台
|
||||
%s
|
||||
`, data.CompanyName, data.Amount, data.InvoiceType, data.ApprovedAt, data.FileURL, data.FileName, s.config.Domain, time.Now().Format("2006-01-02 15:04:05"))
|
||||
}
|
||||
|
||||
// sendSMTP 通过SMTP发送邮件
|
||||
func (s *QQEmailService) sendSMTP(to, subject, message string) error {
|
||||
// 构建认证信息
|
||||
auth := smtp.PlainAuth("", s.config.Username, s.config.Password, s.config.Host)
|
||||
|
||||
// 构建收件人列表
|
||||
toList := []string{to}
|
||||
|
||||
// 发送邮件
|
||||
if s.config.UseSSL {
|
||||
// QQ邮箱587端口使用STARTTLS,465端口使用直接SSL
|
||||
if s.config.Port == 587 {
|
||||
// 使用STARTTLS (587端口)
|
||||
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port))
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接SMTP服务器失败: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, s.config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建SMTP客户端失败: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 启用STARTTLS
|
||||
if err = client.StartTLS(&tls.Config{
|
||||
ServerName: s.config.Host,
|
||||
InsecureSkipVerify: false,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("启用STARTTLS失败: %w", err)
|
||||
}
|
||||
|
||||
// 认证
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP认证失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置发件人
|
||||
if err = client.Mail(s.config.FromEmail); err != nil {
|
||||
return fmt.Errorf("设置发件人失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置收件人
|
||||
for _, recipient := range toList {
|
||||
if err = client.Rcpt(recipient); err != nil {
|
||||
return fmt.Errorf("设置收件人失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 发送邮件内容
|
||||
writer, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("准备发送邮件内容失败: %w", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
_, err = writer.Write([]byte(message))
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送邮件内容失败: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 使用直接SSL连接 (465端口)
|
||||
tlsConfig := &tls.Config{
|
||||
ServerName: s.config.Host,
|
||||
InsecureSkipVerify: false,
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接SMTP服务器失败: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client, err := smtp.NewClient(conn, s.config.Host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建SMTP客户端失败: %w", err)
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
// 认证
|
||||
if err = client.Auth(auth); err != nil {
|
||||
return fmt.Errorf("SMTP认证失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置发件人
|
||||
if err = client.Mail(s.config.FromEmail); err != nil {
|
||||
return fmt.Errorf("设置发件人失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置收件人
|
||||
for _, recipient := range toList {
|
||||
if err = client.Rcpt(recipient); err != nil {
|
||||
return fmt.Errorf("设置收件人失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 发送邮件内容
|
||||
writer, err := client.Data()
|
||||
if err != nil {
|
||||
return fmt.Errorf("准备发送邮件内容失败: %w", err)
|
||||
}
|
||||
defer writer.Close()
|
||||
|
||||
_, err = writer.Write([]byte(message))
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送邮件内容失败: %w", err)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 使用普通连接
|
||||
err := smtp.SendMail(
|
||||
fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
|
||||
auth,
|
||||
s.config.FromEmail,
|
||||
toList,
|
||||
[]byte(message),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送邮件失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
301
internal/infrastructure/external/esign/certification_esign_service.go
vendored
Normal file
301
internal/infrastructure/external/esign/certification_esign_service.go
vendored
Normal file
@@ -0,0 +1,301 @@
|
||||
package esign
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/domains/certification/entities/value_objects"
|
||||
"hyapi-server/internal/domains/certification/enums"
|
||||
"hyapi-server/internal/domains/certification/repositories"
|
||||
"hyapi-server/internal/shared/esign"
|
||||
)
|
||||
|
||||
// ================ 常量定义 ================
|
||||
|
||||
const (
|
||||
// 企业认证超时时间
|
||||
EnterpriseAuthTimeout = 30 * time.Minute
|
||||
|
||||
// 合同签署超时时间
|
||||
ContractSignTimeout = 7 * 24 * time.Hour // 7天
|
||||
|
||||
// 回调重试次数
|
||||
MaxCallbackRetries = 3
|
||||
)
|
||||
|
||||
// ================ 服务实现 ================
|
||||
|
||||
// CertificationEsignService 认证e签宝服务实现
|
||||
//
|
||||
// 业务职责:
|
||||
// - 处理企业认证流程
|
||||
// - 处理合同生成和签署
|
||||
// - 处理e签宝回调
|
||||
// - 管理认证状态更新
|
||||
type CertificationEsignService struct {
|
||||
esignClient *esign.Client
|
||||
commandRepo repositories.CertificationCommandRepository
|
||||
queryRepo repositories.CertificationQueryRepository
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCertificationEsignService 创建认证e签宝服务
|
||||
func NewCertificationEsignService(
|
||||
esignClient *esign.Client,
|
||||
commandRepo repositories.CertificationCommandRepository,
|
||||
queryRepo repositories.CertificationQueryRepository,
|
||||
logger *zap.Logger,
|
||||
) *CertificationEsignService {
|
||||
return &CertificationEsignService{
|
||||
esignClient: esignClient,
|
||||
commandRepo: commandRepo,
|
||||
queryRepo: queryRepo,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 企业认证流程 ================
|
||||
|
||||
// StartEnterpriseAuth 开始企业认证
|
||||
//
|
||||
// 业务流程:
|
||||
// 1. 调用e签宝企业认证API
|
||||
// 2. 更新认证记录的auth_flow_id
|
||||
// 3. 更新状态为企业认证中
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - certificationID: 认证ID
|
||||
// - enterpriseInfo: 企业信息
|
||||
//
|
||||
// 返回:
|
||||
// - authURL: 认证URL
|
||||
// - error: 错误信息
|
||||
func (s *CertificationEsignService) StartEnterpriseAuth(
|
||||
ctx context.Context,
|
||||
certificationID string,
|
||||
enterpriseInfo *value_objects.EnterpriseInfo,
|
||||
) (string, error) {
|
||||
s.logger.Info("开始企业认证",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.String("company_name", enterpriseInfo.CompanyName))
|
||||
|
||||
// TODO: 实现e签宝企业认证API调用
|
||||
// 暂时使用模拟响应
|
||||
authFlowID := fmt.Sprintf("auth_%s_%d", certificationID, time.Now().Unix())
|
||||
authURL := fmt.Sprintf("https://esign.example.com/auth/%s", authFlowID)
|
||||
|
||||
s.logger.Info("模拟调用e签宝企业认证API",
|
||||
zap.String("auth_flow_id", authFlowID),
|
||||
zap.String("auth_url", authURL))
|
||||
|
||||
// 更新认证记录
|
||||
if err := s.commandRepo.UpdateAuthFlowID(ctx, certificationID, authFlowID); err != nil {
|
||||
s.logger.Error("更新认证流程ID失败", zap.Error(err))
|
||||
return "", fmt.Errorf("更新认证流程ID失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("企业认证启动成功",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.String("auth_flow_id", authFlowID))
|
||||
|
||||
return authURL, nil
|
||||
}
|
||||
|
||||
// HandleEnterpriseAuthCallback 处理企业认证回调
|
||||
//
|
||||
// 业务流程:
|
||||
// 1. 根据回调信息查找认证记录
|
||||
// 2. 根据回调状态更新认证状态
|
||||
// 3. 如果成功,继续合同生成流程
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - authFlowID: 认证流程ID
|
||||
// - success: 是否成功
|
||||
// - message: 回调消息
|
||||
//
|
||||
// 返回:
|
||||
// - error: 错误信息
|
||||
func (s *CertificationEsignService) HandleEnterpriseAuthCallback(
|
||||
ctx context.Context,
|
||||
authFlowID string,
|
||||
success bool,
|
||||
message string,
|
||||
) error {
|
||||
s.logger.Info("处理企业认证回调",
|
||||
zap.String("auth_flow_id", authFlowID),
|
||||
zap.Bool("success", success))
|
||||
|
||||
// 查找认证记录
|
||||
cert, err := s.queryRepo.FindByAuthFlowID(ctx, authFlowID)
|
||||
if err != nil {
|
||||
s.logger.Error("根据认证流程ID查找认证记录失败", zap.Error(err))
|
||||
return fmt.Errorf("查找认证记录失败: %w", err)
|
||||
}
|
||||
|
||||
if success {
|
||||
// 企业认证成功,更新状态
|
||||
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusEnterpriseVerified); err != nil {
|
||||
s.logger.Error("更新认证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新认证状态失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("企业认证成功", zap.String("certification_id", cert.ID))
|
||||
} else {
|
||||
// 企业认证失败,更新状态
|
||||
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusInfoRejected); err != nil {
|
||||
s.logger.Error("更新认证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新认证状态失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("企业认证失败", zap.String("certification_id", cert.ID), zap.String("reason", message))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 合同管理流程 ================
|
||||
|
||||
// GenerateContract 生成认证合同
|
||||
//
|
||||
// 业务流程:
|
||||
// 1. 调用e签宝合同生成API
|
||||
// 2. 更新认证记录的合同信息
|
||||
// 3. 更新状态为合同已生成
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - certificationID: 认证ID
|
||||
//
|
||||
// 返回:
|
||||
// - contractSignURL: 合同签署URL
|
||||
// - error: 错误信息
|
||||
func (s *CertificationEsignService) GenerateContract(
|
||||
ctx context.Context,
|
||||
certificationID string,
|
||||
) (string, error) {
|
||||
s.logger.Info("生成认证合同", zap.String("certification_id", certificationID))
|
||||
|
||||
// TODO: 实现e签宝合同生成API调用
|
||||
// 暂时使用模拟响应
|
||||
contractFileID := fmt.Sprintf("contract_%s_%d", certificationID, time.Now().Unix())
|
||||
esignFlowID := fmt.Sprintf("flow_%s_%d", certificationID, time.Now().Unix())
|
||||
contractURL := fmt.Sprintf("https://esign.example.com/contract/%s", contractFileID)
|
||||
contractSignURL := fmt.Sprintf("https://esign.example.com/sign/%s", esignFlowID)
|
||||
|
||||
s.logger.Info("模拟调用e签宝合同生成API",
|
||||
zap.String("contract_file_id", contractFileID),
|
||||
zap.String("esign_flow_id", esignFlowID))
|
||||
|
||||
// 更新认证记录
|
||||
if err := s.commandRepo.UpdateContractInfo(
|
||||
ctx,
|
||||
certificationID,
|
||||
contractFileID,
|
||||
esignFlowID,
|
||||
contractURL,
|
||||
contractSignURL,
|
||||
); err != nil {
|
||||
s.logger.Error("更新合同信息失败", zap.Error(err))
|
||||
return "", fmt.Errorf("更新合同信息失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新状态
|
||||
if err := s.commandRepo.UpdateStatus(ctx, certificationID, enums.StatusContractApplied); err != nil {
|
||||
s.logger.Error("更新认证状态失败", zap.Error(err))
|
||||
return "", fmt.Errorf("更新认证状态失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("认证合同生成成功",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.String("contract_file_id", contractFileID))
|
||||
|
||||
return contractSignURL, nil
|
||||
}
|
||||
|
||||
// HandleContractSignCallback 处理合同签署回调
|
||||
//
|
||||
// 业务流程:
|
||||
// 1. 根据回调信息查找认证记录
|
||||
// 2. 根据回调状态更新认证状态
|
||||
// 3. 如果成功,认证流程完成
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - esignFlowID: e签宝流程ID
|
||||
// - success: 是否成功
|
||||
// - signedFileURL: 已签署文件URL
|
||||
//
|
||||
// 返回:
|
||||
// - error: 错误信息
|
||||
func (s *CertificationEsignService) HandleContractSignCallback(
|
||||
ctx context.Context,
|
||||
esignFlowID string,
|
||||
success bool,
|
||||
signedFileURL string,
|
||||
) error {
|
||||
s.logger.Info("处理合同签署回调",
|
||||
zap.String("esign_flow_id", esignFlowID),
|
||||
zap.Bool("success", success))
|
||||
|
||||
// 查找认证记录
|
||||
cert, err := s.queryRepo.FindByEsignFlowID(ctx, esignFlowID)
|
||||
if err != nil {
|
||||
s.logger.Error("根据e签宝流程ID查找认证记录失败", zap.Error(err))
|
||||
return fmt.Errorf("查找认证记录失败: %w", err)
|
||||
}
|
||||
|
||||
if success {
|
||||
// 合同签署成功,更新合同URL
|
||||
if err := s.commandRepo.UpdateContractInfo(ctx, cert.ID, cert.ContractFileID, cert.EsignFlowID, signedFileURL, cert.ContractSignURL); err != nil {
|
||||
s.logger.Error("更新合同URL失败", zap.Error(err))
|
||||
return fmt.Errorf("更新合同URL失败: %w", err)
|
||||
}
|
||||
|
||||
// 更新状态到合同已签署
|
||||
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusContractSigned); err != nil {
|
||||
s.logger.Error("更新认证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新认证状态失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("合同签署成功", zap.String("certification_id", cert.ID))
|
||||
} else {
|
||||
// 合同签署失败
|
||||
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusContractRejected); err != nil {
|
||||
s.logger.Error("更新认证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新认证状态失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("合同签署失败", zap.String("certification_id", cert.ID))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ================ 辅助方法 ================
|
||||
|
||||
// GetContractSignURL 获取合同签署URL
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - certificationID: 认证ID
|
||||
//
|
||||
// 返回:
|
||||
// - signURL: 签署URL
|
||||
// - error: 错误信息
|
||||
func (s *CertificationEsignService) GetContractSignURL(ctx context.Context, certificationID string) (string, error) {
|
||||
cert, err := s.queryRepo.GetByID(ctx, certificationID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取认证信息失败: %w", err)
|
||||
}
|
||||
|
||||
if cert.ContractSignURL == "" {
|
||||
return "", fmt.Errorf("合同签署URL尚未生成")
|
||||
}
|
||||
|
||||
return cert.ContractSignURL, nil
|
||||
}
|
||||
48
internal/infrastructure/external/jiguang/crypto.go
vendored
Normal file
48
internal/infrastructure/external/jiguang/crypto.go
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
package jiguang
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SignMethod 签名方法类型
|
||||
type SignMethod string
|
||||
|
||||
const (
|
||||
SignMethodMD5 SignMethod = "md5"
|
||||
SignMethodHMACMD5 SignMethod = "hmac"
|
||||
)
|
||||
|
||||
// GenerateSign 生成签名
|
||||
// 根据 signMethod 参数选择使用 MD5 或 HMAC-MD5 算法
|
||||
// MD5: md5(timestamp + "&appSecret=" + appSecret),然后转大写十六进制
|
||||
// HMAC-MD5: hmac_md5(timestamp, appSecret),然后转大写十六进制
|
||||
func GenerateSign(timestamp string, appSecret string, signMethod SignMethod) (string, error) {
|
||||
var hashBytes []byte
|
||||
|
||||
switch signMethod {
|
||||
case SignMethodMD5:
|
||||
// MD5算法:在待签名字符串后面加上 &appSecret=xxx 再进行计算
|
||||
signStr := timestamp + "&appSecret=" + appSecret
|
||||
hash := md5.Sum([]byte(signStr))
|
||||
hashBytes = hash[:]
|
||||
case SignMethodHMACMD5:
|
||||
// HMAC-MD5算法:使用 appSecret 初始化摘要算法再进行计算
|
||||
mac := hmac.New(md5.New, []byte(appSecret))
|
||||
mac.Write([]byte(timestamp))
|
||||
hashBytes = mac.Sum(nil)
|
||||
default:
|
||||
return "", fmt.Errorf("不支持的签名方法: %s", signMethod)
|
||||
}
|
||||
|
||||
// 将二进制转化为大写的十六进制(正确签名应该为32大写字符串)
|
||||
return strings.ToUpper(hex.EncodeToString(hashBytes)), nil
|
||||
}
|
||||
|
||||
// GenerateSignWithDefault 使用默认的 HMAC-MD5 方法生成签名
|
||||
func GenerateSignWithDefault(timestamp string, appSecret string) (string, error) {
|
||||
return GenerateSign(timestamp, appSecret, SignMethodHMACMD5)
|
||||
}
|
||||
149
internal/infrastructure/external/jiguang/jiguang_errors.go
vendored
Normal file
149
internal/infrastructure/external/jiguang/jiguang_errors.go
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
package jiguang
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// JiguangError 极光服务错误
|
||||
type JiguangError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (e *JiguangError) Error() string {
|
||||
return fmt.Sprintf("极光错误 [%d]: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// IsSuccess 检查是否成功
|
||||
func (e *JiguangError) IsSuccess() bool {
|
||||
return e.Code == 0
|
||||
}
|
||||
|
||||
// IsQueryFailed 检查是否查询失败
|
||||
func (e *JiguangError) IsQueryFailed() bool {
|
||||
return e.Code == 922
|
||||
}
|
||||
|
||||
// IsNoRecord 检查是否查无记录
|
||||
func (e *JiguangError) IsNoRecord() bool {
|
||||
return e.Code == 921
|
||||
}
|
||||
|
||||
// IsParamError 检查是否是参数相关错误
|
||||
func (e *JiguangError) IsParamError() bool {
|
||||
return e.Code == 400 || e.Code == 906 || e.Code == 914 || e.Code == 918
|
||||
}
|
||||
|
||||
// IsAuthError 检查是否是认证相关错误
|
||||
func (e *JiguangError) IsAuthError() bool {
|
||||
return e.Code == 902 || e.Code == 903 || e.Code == 904 || e.Code == 905
|
||||
}
|
||||
|
||||
// IsSystemError 检查是否是系统错误
|
||||
func (e *JiguangError) IsSystemError() bool {
|
||||
return e.Code == 405 || e.Code == 911 || e.Code == 912 || e.Code == 915 || e.Code == 916 || e.Code == 917 || e.Code == 919 || e.Code == 923
|
||||
}
|
||||
|
||||
// 预定义错误常量
|
||||
var (
|
||||
// 成功状态
|
||||
ErrSuccess = &JiguangError{Code: 0, Message: "请求成功"}
|
||||
|
||||
// 参数错误
|
||||
ErrParamInvalid = &JiguangError{Code: 400, Message: "请求参数不正确,QCXGGB2Q查询为空"}
|
||||
ErrMethodInvalid = &JiguangError{Code: 405, Message: "请求方法不正确"}
|
||||
ErrParamFormInvalid = &JiguangError{Code: 906, Message: "请求参数形式不正确"}
|
||||
ErrBodyIncomplete = &JiguangError{Code: 914, Message: "Body 请求参数不完整"}
|
||||
ErrBodyNotSupported = &JiguangError{Code: 918, Message: "Body 请求参数不支持"}
|
||||
|
||||
// 认证错误
|
||||
ErrAppIDInvalid = &JiguangError{Code: 902, Message: "错误的 appId/账户已删除"}
|
||||
ErrTimestampInvalid = &JiguangError{Code: 903, Message: "错误的时间戳/时间误差大于 10 分钟"}
|
||||
ErrSignMethodInvalid = &JiguangError{Code: 904, Message: "无法识别的签名方法"}
|
||||
ErrSignInvalid = &JiguangError{Code: 905, Message: "签名不合法"}
|
||||
|
||||
// 系统错误
|
||||
ErrAccountStatusError = &JiguangError{Code: 911, Message: "账户状态异常"}
|
||||
ErrInterfaceDisabled = &JiguangError{Code: 912, Message: "接口状态不可用"}
|
||||
ErrAPICallError = &JiguangError{Code: 915, Message: "API 接口调用有误"}
|
||||
ErrInternalError = &JiguangError{Code: 916, Message: "内部接口调用错误,请联系相关人员"}
|
||||
ErrTimeout = &JiguangError{Code: 917, Message: "请求超时"}
|
||||
ErrBusinessDisabled = &JiguangError{Code: 919, Message: "业务状态不可用"}
|
||||
ErrInterfaceException = &JiguangError{Code: 923, Message: "接口异常"}
|
||||
|
||||
// 业务错误
|
||||
ErrNoRecord = &JiguangError{Code: 921, Message: "查无记录"}
|
||||
ErrQueryFailed = &JiguangError{Code: 922, Message: "查询失败"}
|
||||
)
|
||||
|
||||
// NewJiguangError 创建新的极光错误
|
||||
func NewJiguangError(code int, message string) *JiguangError {
|
||||
return &JiguangError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// NewJiguangErrorFromCode 根据状态码创建错误
|
||||
func NewJiguangErrorFromCode(code int) *JiguangError {
|
||||
switch code {
|
||||
case 0:
|
||||
return ErrSuccess
|
||||
case 400:
|
||||
return ErrParamInvalid
|
||||
case 405:
|
||||
return ErrMethodInvalid
|
||||
case 902:
|
||||
return ErrAppIDInvalid
|
||||
case 903:
|
||||
return ErrTimestampInvalid
|
||||
case 904:
|
||||
return ErrSignMethodInvalid
|
||||
case 905:
|
||||
return ErrSignInvalid
|
||||
case 906:
|
||||
return ErrParamFormInvalid
|
||||
case 911:
|
||||
return ErrAccountStatusError
|
||||
case 912:
|
||||
return ErrInterfaceDisabled
|
||||
case 914:
|
||||
return ErrBodyIncomplete
|
||||
case 915:
|
||||
return ErrAPICallError
|
||||
case 916:
|
||||
return ErrInternalError
|
||||
case 917:
|
||||
return ErrTimeout
|
||||
case 918:
|
||||
return ErrBodyNotSupported
|
||||
case 919:
|
||||
return ErrBusinessDisabled
|
||||
case 921:
|
||||
return ErrNoRecord
|
||||
case 922:
|
||||
return ErrQueryFailed
|
||||
case 923:
|
||||
return ErrInterfaceException
|
||||
default:
|
||||
return &JiguangError{
|
||||
Code: code,
|
||||
Message: fmt.Sprintf("未知错误码: %d", code),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsJiguangError 检查是否是极光错误
|
||||
func IsJiguangError(err error) bool {
|
||||
_, ok := err.(*JiguangError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetJiguangError 获取极光错误
|
||||
func GetJiguangError(err error) *JiguangError {
|
||||
if jiguangErr, ok := err.(*JiguangError); ok {
|
||||
return jiguangErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
85
internal/infrastructure/external/jiguang/jiguang_factory.go
vendored
Normal file
85
internal/infrastructure/external/jiguang/jiguang_factory.go
vendored
Normal file
@@ -0,0 +1,85 @@
|
||||
package jiguang
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewJiguangServiceWithConfig 使用配置创建极光服务
|
||||
func NewJiguangServiceWithConfig(cfg *config.Config) (*JiguangService, error) {
|
||||
// 将配置类型转换为通用外部服务日志配置
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.Jiguang.Logging.Enabled,
|
||||
LogDir: cfg.Jiguang.Logging.LogDir,
|
||||
ServiceName: "jiguang",
|
||||
UseDaily: cfg.Jiguang.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.Jiguang.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
|
||||
// 转换级别配置
|
||||
for key, value := range cfg.Jiguang.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: value.MaxSize,
|
||||
MaxBackups: value.MaxBackups,
|
||||
MaxAge: value.MaxAge,
|
||||
Compress: value.Compress,
|
||||
}
|
||||
}
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析签名方法
|
||||
var signMethod SignMethod
|
||||
if cfg.Jiguang.SignMethod == "md5" {
|
||||
signMethod = SignMethodMD5
|
||||
} else {
|
||||
signMethod = SignMethodHMACMD5 // 默认使用 HMAC-MD5
|
||||
}
|
||||
|
||||
// 解析超时时间
|
||||
timeout := 60 * time.Second
|
||||
if cfg.Jiguang.Timeout > 0 {
|
||||
timeout = cfg.Jiguang.Timeout
|
||||
}
|
||||
|
||||
// 创建极光服务
|
||||
service := NewJiguangService(
|
||||
cfg.Jiguang.URL,
|
||||
cfg.Jiguang.AppID,
|
||||
cfg.Jiguang.AppSecret,
|
||||
signMethod,
|
||||
timeout,
|
||||
logger,
|
||||
)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewJiguangServiceWithLogging 使用自定义日志配置创建极光服务
|
||||
func NewJiguangServiceWithLogging(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig) (*JiguangService, error) {
|
||||
// 设置服务名称
|
||||
loggingConfig.ServiceName = "jiguang"
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建极光服务
|
||||
service := NewJiguangService(url, appID, appSecret, signMethod, timeout, logger)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewJiguangServiceSimple 创建简单的极光服务(无日志)
|
||||
func NewJiguangServiceSimple(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration) *JiguangService {
|
||||
return NewJiguangService(url, appID, appSecret, signMethod, timeout, nil)
|
||||
}
|
||||
316
internal/infrastructure/external/jiguang/jiguang_service.go
vendored
Normal file
316
internal/infrastructure/external/jiguang/jiguang_service.go
vendored
Normal file
@@ -0,0 +1,316 @@
|
||||
package jiguang
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
ErrNotFound = errors.New("查询为空")
|
||||
)
|
||||
|
||||
// JiguangResponse 极光API响应结构(兼容两套字段命名)
|
||||
//
|
||||
// 格式一:ordernum、message、result(定位/查询类接口常见)
|
||||
// 格式二:order_id、msg、data(文档中的 code/msg/order_id/data)
|
||||
type JiguangResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Message string `json:"message"`
|
||||
OrderID string `json:"order_id"`
|
||||
OrderNum string `json:"ordernum"`
|
||||
Data interface{} `json:"data"`
|
||||
Result interface{} `json:"result"`
|
||||
}
|
||||
|
||||
// normalize 将异名字段合并到 OrderID、Msg,便于后续统一分支使用
|
||||
func (r *JiguangResponse) normalize() {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
if r.OrderID == "" && r.OrderNum != "" {
|
||||
r.OrderID = r.OrderNum
|
||||
}
|
||||
if r.Msg == "" && r.Message != "" {
|
||||
r.Msg = r.Message
|
||||
}
|
||||
}
|
||||
|
||||
// JiguangConfig 极光服务配置
|
||||
type JiguangConfig struct {
|
||||
URL string
|
||||
AppID string
|
||||
AppSecret string
|
||||
SignMethod SignMethod // 签名方法:md5 或 hmac
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// JiguangService 极光服务
|
||||
type JiguangService struct {
|
||||
config JiguangConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewJiguangService 创建一个新的极光服务实例
|
||||
func NewJiguangService(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *JiguangService {
|
||||
// 如果没有指定签名方法,默认使用 HMAC-MD5
|
||||
if signMethod == "" {
|
||||
signMethod = SignMethodHMACMD5
|
||||
}
|
||||
|
||||
// 如果没有指定超时时间,默认使用 60 秒
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
|
||||
return &JiguangService{
|
||||
config: JiguangConfig{
|
||||
URL: url,
|
||||
AppID: appID,
|
||||
AppSecret: appSecret,
|
||||
SignMethod: signMethod,
|
||||
Timeout: timeout,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (j *JiguangService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, j.config.AppID)))
|
||||
return fmt.Sprintf("jiguang_%x", hash[:8])
|
||||
}
|
||||
|
||||
// CallAPI 调用极光API
|
||||
// apiCode: API服务编码(如 marriage-single-v2),用于请求头
|
||||
// apiPath: API路径(如 marriage/single-v2),用于URL路径
|
||||
// params: 请求参数(会作为JSON body发送)
|
||||
func (j *JiguangService) CallAPI(ctx context.Context, apiCode string, apiPath string, params map[string]interface{}) (resp []byte, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := j.generateRequestID()
|
||||
|
||||
// 生成时间戳(毫秒)
|
||||
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||
|
||||
// 从ctx中获取transactionId
|
||||
var transactionID string
|
||||
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = ctxTransactionID
|
||||
}
|
||||
|
||||
// 生成签名
|
||||
sign, signErr := GenerateSign(timestamp, j.config.AppSecret, j.config.SignMethod)
|
||||
if signErr != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("生成签名失败: %w", signErr))
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建完整的请求URL,使用apiPath作为路径
|
||||
requestURL := strings.TrimSuffix(j.config.URL, "/") + "/" + strings.TrimPrefix(apiPath, "/")
|
||||
|
||||
// 记录请求日志
|
||||
if j.logger != nil {
|
||||
j.logger.LogRequest(requestID, transactionID, apiCode, requestURL)
|
||||
}
|
||||
|
||||
// 将请求参数转换为JSON
|
||||
jsonData, marshalErr := json.Marshal(params)
|
||||
if marshalErr != nil {
|
||||
err = errors.Join(ErrSystem, marshalErr)
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建HTTP POST请求
|
||||
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewBuffer(jsonData))
|
||||
if newRequestErr != nil {
|
||||
err = errors.Join(ErrSystem, newRequestErr)
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("appId", j.config.AppID)
|
||||
req.Header.Set("apiCode", apiCode)
|
||||
req.Header.Set("timestamp", timestamp)
|
||||
req.Header.Set("signMethod", string(j.config.SignMethod))
|
||||
req.Header.Set("sign", sign)
|
||||
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: j.config.Timeout,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
httpResp, clientDoErr := client.Do(req)
|
||||
if clientDoErr != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, clientDoErr)
|
||||
}
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
closeErr := Body.Close()
|
||||
if closeErr != nil {
|
||||
// 记录关闭错误
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), params)
|
||||
}
|
||||
}
|
||||
}(httpResp.Body)
|
||||
|
||||
// 计算请求耗时
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// 读取响应体
|
||||
bodyBytes, readErr := io.ReadAll(httpResp.Body)
|
||||
if readErr != nil {
|
||||
err = errors.Join(ErrSystem, readErr)
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查HTTP状态码
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("极光请求失败,状态码: %d", httpResp.StatusCode))
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析响应结构
|
||||
var jiguangResp JiguangResponse
|
||||
if err := json.Unmarshal(bodyBytes, &jiguangResp); err != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
|
||||
if j.logger != nil {
|
||||
j.logger.LogError(requestID, transactionID, apiCode, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
jiguangResp.normalize()
|
||||
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
if j.logger != nil {
|
||||
if jiguangResp.OrderID != "" {
|
||||
j.logger.LogResponseWithID(requestID, transactionID, apiCode, httpResp.StatusCode, duration, jiguangResp.OrderID)
|
||||
} else {
|
||||
j.logger.LogResponse(requestID, transactionID, apiCode, httpResp.StatusCode, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查业务状态码
|
||||
if jiguangResp.Code != 0 && jiguangResp.Code != 200 {
|
||||
// 创建极光错误
|
||||
jiguangErr := NewJiguangErrorFromCode(jiguangResp.Code)
|
||||
if jiguangErr.Message == fmt.Sprintf("未知错误码: %d", jiguangResp.Code) {
|
||||
if jiguangResp.Msg != "" {
|
||||
jiguangErr.Message = jiguangResp.Msg
|
||||
} else if jiguangResp.Message != "" {
|
||||
jiguangErr.Message = jiguangResp.Message
|
||||
}
|
||||
}
|
||||
// 根据错误类型返回不同的错误
|
||||
if jiguangErr.IsNoRecord() {
|
||||
// 从context中获取apiCode,判断是否需要抛出异常
|
||||
var processorCode string
|
||||
if ctxProcessorCode, ok := ctx.Value("api_code").(string); ok {
|
||||
processorCode = ctxProcessorCode
|
||||
}
|
||||
// 定义不需要抛出异常的处理器列表(默认情况下查无记录时抛出异常)
|
||||
processorsNotToThrowError := map[string]bool{
|
||||
// 在这个列表中的处理器,查无记录时返回空数组,不抛出异常
|
||||
// 示例:如果需要添加某个处理器,取消下面的注释
|
||||
// "QCXG9P1C": true,
|
||||
}
|
||||
// 如果是不需要抛出异常的处理器,返回空数组;否则(默认)抛出异常
|
||||
if processorsNotToThrowError[processorCode] {
|
||||
// 查无记录时,返回空数组,API调用记录为成功
|
||||
return []byte("[]"), nil
|
||||
}
|
||||
// 记录错误日志
|
||||
if j.logger != nil {
|
||||
j.logger.LogErrorWithResponseID(requestID, transactionID, apiCode, jiguangErr, params, jiguangResp.OrderID)
|
||||
}
|
||||
return nil, errors.Join(ErrNotFound, jiguangErr)
|
||||
}
|
||||
// 记录错误日志(查无记录的情况不记录错误日志)
|
||||
if j.logger != nil {
|
||||
j.logger.LogErrorWithResponseID(requestID, transactionID, apiCode, jiguangErr, params, jiguangResp.OrderID)
|
||||
}
|
||||
if jiguangErr.IsQueryFailed() {
|
||||
return nil, errors.Join(ErrDatasource, jiguangErr)
|
||||
} else if jiguangErr.IsSystemError() {
|
||||
return nil, errors.Join(ErrSystem, jiguangErr)
|
||||
} else {
|
||||
return nil, errors.Join(ErrDatasource, jiguangErr)
|
||||
}
|
||||
}
|
||||
|
||||
// 成功时业务体在 data 或 result
|
||||
payload := jiguangResp.Data
|
||||
if payload == nil {
|
||||
payload = jiguangResp.Result
|
||||
}
|
||||
if payload == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
|
||||
dataBytes, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("业务数据序列化失败: %w", err))
|
||||
if j.logger != nil {
|
||||
j.logger.LogErrorWithResponseID(requestID, transactionID, apiCode, err, params, jiguangResp.OrderID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dataBytes, nil
|
||||
}
|
||||
|
||||
// GetConfig 获取配置信息
|
||||
func (j *JiguangService) GetConfig() JiguangConfig {
|
||||
return j.config
|
||||
}
|
||||
25
internal/infrastructure/external/muzi/muzi_errors.go
vendored
Normal file
25
internal/infrastructure/external/muzi/muzi_errors.go
vendored
Normal file
@@ -0,0 +1,25 @@
|
||||
package muzi
|
||||
|
||||
import "fmt"
|
||||
|
||||
// MuziError 木子数据业务错误
|
||||
type MuziError struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
// Error implements error interface.
|
||||
func (e *MuziError) Error() string {
|
||||
return fmt.Sprintf("木子数据返回错误,代码: %d,信息: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// NewMuziError 根据错误码创建业务错误
|
||||
func NewMuziError(code int, message string) *MuziError {
|
||||
if message == "" {
|
||||
message = "木子数据返回未知错误"
|
||||
}
|
||||
return &MuziError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
61
internal/infrastructure/external/muzi/muzi_factory.go
vendored
Normal file
61
internal/infrastructure/external/muzi/muzi_factory.go
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
package muzi
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewMuziServiceWithConfig 使用配置创建木子数据服务
|
||||
func NewMuziServiceWithConfig(cfg *config.Config) (*MuziService, error) {
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.Muzi.Logging.Enabled,
|
||||
LogDir: cfg.Muzi.Logging.LogDir,
|
||||
ServiceName: "muzi",
|
||||
UseDaily: cfg.Muzi.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.Muzi.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
|
||||
for level, levelCfg := range cfg.Muzi.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[level] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: levelCfg.MaxSize,
|
||||
MaxBackups: levelCfg.MaxBackups,
|
||||
MaxAge: levelCfg.MaxAge,
|
||||
Compress: levelCfg.Compress,
|
||||
}
|
||||
}
|
||||
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
service := NewMuziService(
|
||||
cfg.Muzi.URL,
|
||||
cfg.Muzi.AppID,
|
||||
cfg.Muzi.AppSecret,
|
||||
cfg.Muzi.Timeout,
|
||||
logger,
|
||||
)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewMuziServiceWithLogging 使用自定义日志配置创建木子数据服务
|
||||
func NewMuziServiceWithLogging(url, appID, appSecret string, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig) (*MuziService, error) {
|
||||
loggingConfig.ServiceName = "muzi"
|
||||
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewMuziService(url, appID, appSecret, timeout, logger), nil
|
||||
}
|
||||
|
||||
// NewMuziServiceSimple 创建无日志的木子数据服务
|
||||
func NewMuziServiceSimple(url, appID, appSecret string, timeout time.Duration) *MuziService {
|
||||
return NewMuziService(url, appID, appSecret, timeout, nil)
|
||||
}
|
||||
406
internal/infrastructure/external/muzi/muzi_service.go
vendored
Normal file
406
internal/infrastructure/external/muzi/muzi_service.go
vendored
Normal file
@@ -0,0 +1,406 @@
|
||||
package muzi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
const defaultRequestTimeout = 60 * time.Second
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
)
|
||||
|
||||
// Muzi状态码常量
|
||||
const (
|
||||
CodeSuccess = 0 // 成功查询
|
||||
CodeSystemError = 500 // 系统异常
|
||||
CodeParamMissing = 601 // 参数不全
|
||||
CodeInterfaceExpired = 602 // 接口已过期
|
||||
CodeVerifyFailed = 603 // 接口校验失败
|
||||
CodeIPNotInWhitelist = 604 // IP不在白名单中
|
||||
CodeProductNotFound = 701 // 产品编号不存在
|
||||
CodeUserNotFound = 702 // 用户名不存在
|
||||
CodeUnauthorizedAPI = 703 // 接口未授权
|
||||
CodeInsufficientFund = 704 // 商户余额不足
|
||||
)
|
||||
|
||||
// MuziResponse 木子数据接口通用响应
|
||||
type MuziResponse struct {
|
||||
Code int `json:"code"`
|
||||
Msg string `json:"msg"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Timestamp int64 `json:"timestamp"`
|
||||
ExecuteTime int64 `json:"executeTime"`
|
||||
}
|
||||
|
||||
// MuziConfig 木子数据接口配置
|
||||
type MuziConfig struct {
|
||||
URL string
|
||||
AppID string
|
||||
AppSecret string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// MuziService 木子数据接口服务封装
|
||||
type MuziService struct {
|
||||
config MuziConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewMuziService 创建木子数据服务实例
|
||||
func NewMuziService(url, appID, appSecret string, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *MuziService {
|
||||
if timeout <= 0 {
|
||||
timeout = defaultRequestTimeout
|
||||
}
|
||||
|
||||
return &MuziService{
|
||||
config: MuziConfig{
|
||||
URL: url,
|
||||
AppID: appID,
|
||||
AppSecret: appSecret,
|
||||
Timeout: timeout,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (m *MuziService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
raw := fmt.Sprintf("%d_%s", timestamp, m.config.AppID)
|
||||
sum := md5.Sum([]byte(raw))
|
||||
return fmt.Sprintf("muzi_%x", sum[:8])
|
||||
}
|
||||
|
||||
// CallAPI 调用木子数据接口
|
||||
func (m *MuziService) CallAPI(ctx context.Context, prodCode string, path string, params map[string]interface{},paramSign map[string]interface{}) (json.RawMessage, error) {
|
||||
requestID := m.generateRequestID()
|
||||
now := time.Now()
|
||||
timestamp := strconv.FormatInt(now.UnixMilli(), 10)
|
||||
|
||||
flatParams := flattenParams(params)
|
||||
|
||||
signParts := collectSignatureValues(paramSign)
|
||||
signature := m.GenerateSignature(prodCode, timestamp, signParts...)
|
||||
|
||||
// 从上下文获取链路ID
|
||||
var transactionID string
|
||||
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = ctxTransactionID
|
||||
}
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"appId": m.config.AppID,
|
||||
"prodCode": prodCode,
|
||||
"timestamp": timestamp,
|
||||
"signature": signature,
|
||||
}
|
||||
for key, value := range flatParams {
|
||||
requestBody[key] = value
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
m.logger.LogRequest(requestID, transactionID, prodCode, m.config.URL)
|
||||
}
|
||||
|
||||
bodyBytes, marshalErr := json.Marshal(requestBody)
|
||||
if marshalErr != nil {
|
||||
err := errors.Join(ErrSystem, marshalErr)
|
||||
if m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建完整的URL,拼接路径参数
|
||||
fullURL := m.config.URL
|
||||
if path != "" {
|
||||
// 确保路径以/开头
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
// 确保URL不以/结尾,避免双斜杠
|
||||
if strings.HasSuffix(fullURL, "/") {
|
||||
fullURL = fullURL[:len(fullURL)-1]
|
||||
}
|
||||
fullURL += path
|
||||
}
|
||||
|
||||
req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewBuffer(bodyBytes))
|
||||
if reqErr != nil {
|
||||
err := errors.Join(ErrSystem, reqErr)
|
||||
if m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
client := &http.Client{
|
||||
Timeout: m.config.Timeout,
|
||||
}
|
||||
|
||||
resp, httpErr := client.Do(req)
|
||||
if httpErr != nil {
|
||||
err := wrapHTTPError(httpErr)
|
||||
if errors.Is(err, ErrDatasource) {
|
||||
err = errors.Join(err, fmt.Errorf("API请求超时: %v", httpErr))
|
||||
}
|
||||
if m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer func(body io.ReadCloser) {
|
||||
closeErr := body.Close()
|
||||
if closeErr != nil && m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), requestBody)
|
||||
}
|
||||
}(resp.Body)
|
||||
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
err := errors.Join(ErrSystem, readErr)
|
||||
if m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
m.logger.LogResponse(requestID, transactionID, prodCode, resp.StatusCode, time.Since(now))
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err := errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", resp.StatusCode))
|
||||
if m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var muziResp MuziResponse
|
||||
if err := json.Unmarshal(respBody, &muziResp); err != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %v", err))
|
||||
if m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if muziResp.Code != CodeSuccess {
|
||||
muziErr := NewMuziError(muziResp.Code, muziResp.Msg)
|
||||
var resultErr error
|
||||
|
||||
switch muziResp.Code {
|
||||
case CodeSystemError:
|
||||
resultErr = errors.Join(ErrDatasource, muziErr)
|
||||
default:
|
||||
resultErr = errors.Join(ErrSystem, muziErr)
|
||||
}
|
||||
|
||||
if m.logger != nil {
|
||||
m.logger.LogError(requestID, transactionID, prodCode, muziErr, requestBody)
|
||||
}
|
||||
return nil, resultErr
|
||||
}
|
||||
|
||||
return muziResp.Data, nil
|
||||
}
|
||||
|
||||
func wrapHTTPError(err error) error {
|
||||
var timeout bool
|
||||
if err == context.DeadlineExceeded {
|
||||
timeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
timeout = true
|
||||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
timeout = true
|
||||
}
|
||||
|
||||
if timeout {
|
||||
return errors.Join(ErrDatasource, err)
|
||||
}
|
||||
return errors.Join(ErrSystem, err)
|
||||
}
|
||||
|
||||
func pkcs5Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
func flattenParams(params map[string]interface{}) map[string]interface{} {
|
||||
result := make(map[string]interface{})
|
||||
if params == nil {
|
||||
return result
|
||||
}
|
||||
for key, value := range params {
|
||||
flattenValue(key, value, result)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func flattenValue(prefix string, value interface{}, out map[string]interface{}) {
|
||||
switch val := value.(type) {
|
||||
case map[string]interface{}:
|
||||
for k, v := range val {
|
||||
flattenValue(combinePrefix(prefix, k), v, out)
|
||||
}
|
||||
case map[interface{}]interface{}:
|
||||
for k, v := range val {
|
||||
keyStr := fmt.Sprint(k)
|
||||
flattenValue(combinePrefix(prefix, keyStr), v, out)
|
||||
}
|
||||
case []interface{}:
|
||||
for i, item := range val {
|
||||
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
|
||||
flattenValue(nextPrefix, item, out)
|
||||
}
|
||||
case []string:
|
||||
for i, item := range val {
|
||||
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
|
||||
flattenValue(nextPrefix, item, out)
|
||||
}
|
||||
case []int:
|
||||
for i, item := range val {
|
||||
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
|
||||
flattenValue(nextPrefix, item, out)
|
||||
}
|
||||
case []float64:
|
||||
for i, item := range val {
|
||||
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
|
||||
flattenValue(nextPrefix, item, out)
|
||||
}
|
||||
case []bool:
|
||||
for i, item := range val {
|
||||
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
|
||||
flattenValue(nextPrefix, item, out)
|
||||
}
|
||||
default:
|
||||
out[prefix] = val
|
||||
}
|
||||
}
|
||||
|
||||
func combinePrefix(prefix, key string) string {
|
||||
if prefix == "" {
|
||||
return key
|
||||
}
|
||||
return prefix + "." + key
|
||||
}
|
||||
|
||||
// Encrypt 使用 AES/ECB/PKCS5Padding 对单个字符串进行加密并返回 Base64 结果
|
||||
func (m *MuziService) Encrypt(value string) (string, error) {
|
||||
if len(m.config.AppSecret) != 32 {
|
||||
return "", fmt.Errorf("AppSecret长度必须为32位")
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher([]byte(m.config.AppSecret))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("初始化加密器失败: %w", err)
|
||||
}
|
||||
|
||||
padded := pkcs5Padding([]byte(value), block.BlockSize())
|
||||
encrypted := make([]byte, len(padded))
|
||||
|
||||
for bs, be := 0, block.BlockSize(); bs < len(padded); bs, be = bs+block.BlockSize(), be+block.BlockSize() {
|
||||
block.Encrypt(encrypted[bs:be], padded[bs:be])
|
||||
}
|
||||
|
||||
return base64.StdEncoding.EncodeToString(encrypted), nil
|
||||
}
|
||||
|
||||
// GenerateSignature 根据协议生成签名,extraValues 会按顺序追加在待签名字符串之后
|
||||
func (m *MuziService) GenerateSignature(prodCode, timestamp string, extraValues ...string) string {
|
||||
signStr := m.config.AppID + prodCode + timestamp
|
||||
for _, extra := range extraValues {
|
||||
signStr += extra
|
||||
}
|
||||
hash := md5.Sum([]byte(signStr))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// GenerateTimestamp 生成当前毫秒级时间戳字符串
|
||||
func (m *MuziService) GenerateTimestamp() string {
|
||||
return strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||
}
|
||||
|
||||
// FlattenParams 将嵌套参数展平为一维键值对
|
||||
func (m *MuziService) FlattenParams(params map[string]interface{}) map[string]interface{} {
|
||||
return flattenParams(params)
|
||||
}
|
||||
|
||||
func collectSignatureValues(data interface{}) []string {
|
||||
var result []string
|
||||
collectSignatureValuesRecursive(reflect.ValueOf(data), &result)
|
||||
return result
|
||||
}
|
||||
|
||||
func collectSignatureValuesRecursive(value reflect.Value, result *[]string) {
|
||||
if !value.IsValid() {
|
||||
*result = append(*result, "")
|
||||
return
|
||||
}
|
||||
|
||||
switch value.Kind() {
|
||||
case reflect.Pointer, reflect.Interface:
|
||||
if value.IsNil() {
|
||||
*result = append(*result, "")
|
||||
return
|
||||
}
|
||||
collectSignatureValuesRecursive(value.Elem(), result)
|
||||
case reflect.Map:
|
||||
keys := value.MapKeys()
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
return fmt.Sprint(keys[i].Interface()) < fmt.Sprint(keys[j].Interface())
|
||||
})
|
||||
for _, key := range keys {
|
||||
collectSignatureValuesRecursive(value.MapIndex(key), result)
|
||||
}
|
||||
case reflect.Slice, reflect.Array:
|
||||
for i := 0; i < value.Len(); i++ {
|
||||
collectSignatureValuesRecursive(value.Index(i), result)
|
||||
}
|
||||
case reflect.Struct:
|
||||
typeInfo := value.Type()
|
||||
fieldNames := make([]string, 0, value.NumField())
|
||||
fieldIndices := make(map[string]int, value.NumField())
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
field := typeInfo.Field(i)
|
||||
if field.PkgPath != "" {
|
||||
continue
|
||||
}
|
||||
fieldNames = append(fieldNames, field.Name)
|
||||
fieldIndices[field.Name] = i
|
||||
}
|
||||
sort.Strings(fieldNames)
|
||||
for _, name := range fieldNames {
|
||||
collectSignatureValuesRecursive(value.Field(fieldIndices[name]), result)
|
||||
}
|
||||
default:
|
||||
*result = append(*result, fmt.Sprint(value.Interface()))
|
||||
}
|
||||
}
|
||||
573
internal/infrastructure/external/notification/wechat_work_service.go
vendored
Normal file
573
internal/infrastructure/external/notification/wechat_work_service.go
vendored
Normal file
@@ -0,0 +1,573 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WeChatWorkService 企业微信通知服务
|
||||
type WeChatWorkService struct {
|
||||
webhookURL string
|
||||
secret string
|
||||
timeout time.Duration
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// WechatWorkConfig 企业微信配置
|
||||
type WechatWorkConfig struct {
|
||||
WebhookURL string `yaml:"webhook_url"`
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
}
|
||||
|
||||
// WechatWorkMessage 企业微信消息
|
||||
type WechatWorkMessage struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *WechatWorkText `json:"text,omitempty"`
|
||||
Markdown *WechatWorkMarkdown `json:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// WechatWorkText 文本消息
|
||||
type WechatWorkText struct {
|
||||
Content string `json:"content"`
|
||||
MentionedList []string `json:"mentioned_list,omitempty"`
|
||||
MentionedMobileList []string `json:"mentioned_mobile_list,omitempty"`
|
||||
}
|
||||
|
||||
// WechatWorkMarkdown Markdown消息
|
||||
type WechatWorkMarkdown struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// NewWeChatWorkService 创建企业微信通知服务
|
||||
func NewWeChatWorkService(webhookURL, secret string, logger *zap.Logger) *WeChatWorkService {
|
||||
return &WeChatWorkService{
|
||||
webhookURL: webhookURL,
|
||||
secret: secret,
|
||||
timeout: 60 * time.Second,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendTextMessage 发送文本消息
|
||||
func (s *WeChatWorkService) SendTextMessage(ctx context.Context, content string, mentionedList []string, mentionedMobileList []string) error {
|
||||
s.logger.Info("发送企业微信文本消息",
|
||||
zap.String("content", content),
|
||||
zap.Strings("mentioned_list", mentionedList),
|
||||
)
|
||||
|
||||
message := map[string]interface{}{
|
||||
"msgtype": "text",
|
||||
"text": map[string]interface{}{
|
||||
"content": content,
|
||||
"mentioned_list": mentionedList,
|
||||
"mentioned_mobile_list": mentionedMobileList,
|
||||
},
|
||||
}
|
||||
|
||||
return s.sendMessage(ctx, message)
|
||||
}
|
||||
|
||||
// SendMarkdownMessage 发送Markdown消息
|
||||
func (s *WeChatWorkService) SendMarkdownMessage(ctx context.Context, content string) error {
|
||||
s.logger.Info("发送企业微信Markdown消息", zap.String("content", content))
|
||||
|
||||
message := map[string]interface{}{
|
||||
"msgtype": "markdown",
|
||||
"markdown": map[string]interface{}{
|
||||
"content": content,
|
||||
},
|
||||
}
|
||||
|
||||
return s.sendMessage(ctx, message)
|
||||
}
|
||||
|
||||
// SendCardMessage 发送卡片消息
|
||||
func (s *WeChatWorkService) SendCardMessage(ctx context.Context, title, description, url string, btnText string) error {
|
||||
s.logger.Info("发送企业微信卡片消息",
|
||||
zap.String("title", title),
|
||||
zap.String("description", description),
|
||||
)
|
||||
|
||||
message := map[string]interface{}{
|
||||
"msgtype": "template_card",
|
||||
"template_card": map[string]interface{}{
|
||||
"card_type": "text_notice",
|
||||
"source": map[string]interface{}{
|
||||
"icon_url": "https://example.com/icon.png",
|
||||
"desc": "企业认证系统",
|
||||
},
|
||||
"main_title": map[string]interface{}{
|
||||
"title": title,
|
||||
},
|
||||
"horizontal_content_list": []map[string]interface{}{
|
||||
{
|
||||
"keyname": "描述",
|
||||
"value": description,
|
||||
},
|
||||
},
|
||||
"jump_list": []map[string]interface{}{
|
||||
{
|
||||
"type": "1",
|
||||
"title": btnText,
|
||||
"url": url,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return s.sendMessage(ctx, message)
|
||||
}
|
||||
|
||||
// SendCertificationNotification 发送认证相关通知
|
||||
func (s *WeChatWorkService) SendCertificationNotification(ctx context.Context, notificationType string, data map[string]interface{}) error {
|
||||
s.logger.Info("发送认证通知", zap.String("type", notificationType))
|
||||
|
||||
switch notificationType {
|
||||
case "new_application":
|
||||
return s.sendNewApplicationNotification(ctx, data)
|
||||
case "pending_manual_review":
|
||||
return s.sendPendingManualReviewNotification(ctx, data)
|
||||
case "ocr_success":
|
||||
return s.sendOCRSuccessNotification(ctx, data)
|
||||
case "ocr_failed":
|
||||
return s.sendOCRFailedNotification(ctx, data)
|
||||
case "face_verify_success":
|
||||
return s.sendFaceVerifySuccessNotification(ctx, data)
|
||||
case "face_verify_failed":
|
||||
return s.sendFaceVerifyFailedNotification(ctx, data)
|
||||
case "admin_approved":
|
||||
return s.sendAdminApprovedNotification(ctx, data)
|
||||
case "admin_rejected":
|
||||
return s.sendAdminRejectedNotification(ctx, data)
|
||||
case "contract_signed":
|
||||
return s.sendContractSignedNotification(ctx, data)
|
||||
case "certification_completed":
|
||||
return s.sendCertificationCompletedNotification(ctx, data)
|
||||
default:
|
||||
return fmt.Errorf("不支持的通知类型: %s", notificationType)
|
||||
}
|
||||
}
|
||||
|
||||
// sendNewApplicationNotification 发送新申请通知
|
||||
func (s *WeChatWorkService) sendNewApplicationNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicantName := data["applicant_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】🆕 新的企业认证申请
|
||||
|
||||
**企业名称**: %s
|
||||
**申请人**: %s
|
||||
**申请ID**: %s
|
||||
**申请时间**: %s
|
||||
|
||||
请管理员及时审核处理。`,
|
||||
companyName,
|
||||
applicantName,
|
||||
applicationID,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendPendingManualReviewNotification 用户已提交企业信息,待管理员人工审核(三真审核前序步骤)
|
||||
func (s *WeChatWorkService) sendPendingManualReviewNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := fmt.Sprint(data["company_name"])
|
||||
legalPersonName := fmt.Sprint(data["legal_person_name"])
|
||||
authorizedRepName := fmt.Sprint(data["authorized_rep_name"])
|
||||
contactPhone := fmt.Sprint(data["contact_phone"])
|
||||
apiUsage := fmt.Sprint(data["api_usage"])
|
||||
submitAt := fmt.Sprint(data["submit_at"])
|
||||
|
||||
if authorizedRepName == "" || authorizedRepName == "<nil>" {
|
||||
authorizedRepName = "—"
|
||||
}
|
||||
if apiUsage == "" || apiUsage == "<nil>" {
|
||||
apiUsage = "—"
|
||||
}
|
||||
if contactPhone == "" || contactPhone == "<nil>" {
|
||||
contactPhone = "—"
|
||||
}
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】📋 企业信息待人工审核
|
||||
|
||||
**企业名称**: %s
|
||||
**法人**: %s
|
||||
**授权申请人**: %s
|
||||
**联系电话**: %s
|
||||
**应用场景说明**: %s
|
||||
**提交时间**: %s
|
||||
|
||||
> 请管理员登录后台 **企业审核** 通过审核后,用户方可进行 e签宝企业认证。`,
|
||||
companyName,
|
||||
legalPersonName,
|
||||
authorizedRepName,
|
||||
contactPhone,
|
||||
apiUsage,
|
||||
submitAt)
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendOCRSuccessNotification 发送OCR识别成功通知
|
||||
func (s *WeChatWorkService) sendOCRSuccessNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
confidence := data["confidence"].(float64)
|
||||
applicationID := data["application_id"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】✅ OCR识别成功
|
||||
|
||||
**企业名称**: %s
|
||||
**识别置信度**: %.2f%%
|
||||
**申请ID**: %s
|
||||
**识别时间**: %s
|
||||
|
||||
营业执照信息已自动提取,请用户确认信息。`,
|
||||
companyName,
|
||||
confidence*100,
|
||||
applicationID,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendOCRFailedNotification 发送OCR识别失败通知
|
||||
func (s *WeChatWorkService) sendOCRFailedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
applicationID := data["application_id"].(string)
|
||||
errorMsg := data["error_message"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】❌ OCR识别失败
|
||||
|
||||
**申请ID**: %s
|
||||
**错误信息**: %s
|
||||
**失败时间**: %s
|
||||
|
||||
请检查营业执照图片质量或联系技术支持。`,
|
||||
applicationID,
|
||||
errorMsg,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendFaceVerifySuccessNotification 发送人脸识别成功通知
|
||||
func (s *WeChatWorkService) sendFaceVerifySuccessNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
applicantName := data["applicant_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
confidence := data["confidence"].(float64)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】✅ 人脸识别成功
|
||||
|
||||
**申请人**: %s
|
||||
**申请ID**: %s
|
||||
**识别置信度**: %.2f%%
|
||||
**识别时间**: %s
|
||||
|
||||
身份验证通过,可以进行下一步操作。`,
|
||||
applicantName,
|
||||
applicationID,
|
||||
confidence*100,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendFaceVerifyFailedNotification 发送人脸识别失败通知
|
||||
func (s *WeChatWorkService) sendFaceVerifyFailedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
applicantName := data["applicant_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
errorMsg := data["error_message"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】❌ 人脸识别失败
|
||||
|
||||
**申请人**: %s
|
||||
**申请ID**: %s
|
||||
**错误信息**: %s
|
||||
**失败时间**: %s
|
||||
|
||||
请重新进行人脸识别或联系技术支持。`,
|
||||
applicantName,
|
||||
applicationID,
|
||||
errorMsg,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendAdminApprovedNotification 发送管理员审核通过通知
|
||||
func (s *WeChatWorkService) sendAdminApprovedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
adminName := data["admin_name"].(string)
|
||||
comment := data["comment"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】✅ 管理员审核通过
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**审核人**: %s
|
||||
**审核意见**: %s
|
||||
**审核时间**: %s
|
||||
|
||||
认证申请已通过审核,请用户签署电子合同。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
adminName,
|
||||
comment,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendAdminRejectedNotification 发送管理员审核拒绝通知
|
||||
func (s *WeChatWorkService) sendAdminRejectedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
adminName := data["admin_name"].(string)
|
||||
reason := data["reason"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】❌ 管理员审核拒绝
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**审核人**: %s
|
||||
**拒绝原因**: %s
|
||||
**审核时间**: %s
|
||||
|
||||
认证申请被拒绝,请根据反馈意见重新提交。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
adminName,
|
||||
reason,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendContractSignedNotification 发送合同签署通知
|
||||
func (s *WeChatWorkService) sendContractSignedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
signerName := data["signer_name"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】📝 电子合同已签署
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**签署人**: %s
|
||||
**签署时间**: %s
|
||||
|
||||
电子合同签署完成,系统将自动生成钱包和Access Key。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
signerName,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendCertificationCompletedNotification 发送认证完成通知
|
||||
func (s *WeChatWorkService) sendCertificationCompletedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
walletAddress := data["wallet_address"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】🎉 企业认证完成
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**钱包地址**: %s
|
||||
**完成时间**: %s
|
||||
|
||||
恭喜!企业认证流程已完成,钱包和Access Key已生成。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
walletAddress,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendMessage 发送消息到企业微信
|
||||
func (s *WeChatWorkService) sendMessage(ctx context.Context, message map[string]interface{}) error {
|
||||
// 生成签名URL
|
||||
signedURL := s.generateSignedURL()
|
||||
|
||||
// 序列化消息
|
||||
messageBytes, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: s.timeout,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", signedURL, bytes.NewBuffer(messageBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "hyapi-server/1.0")
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
errorMsg := "发送请求失败"
|
||||
if isTimeout {
|
||||
errorMsg = "发送请求超时"
|
||||
}
|
||||
return fmt.Errorf("%s: %w", errorMsg, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var response map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误码
|
||||
if errCode, ok := response["errcode"].(float64); ok && errCode != 0 {
|
||||
errmsg := response["errmsg"].(string)
|
||||
return fmt.Errorf("企业微信API错误: %d - %s", int(errCode), errmsg)
|
||||
}
|
||||
|
||||
s.logger.Info("企业微信消息发送成功", zap.Any("response", response))
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateSignedURL 生成带签名的URL
|
||||
func (s *WeChatWorkService) generateSignedURL() string {
|
||||
if s.secret == "" {
|
||||
return s.webhookURL
|
||||
}
|
||||
|
||||
// 生成时间戳
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// 生成随机字符串(这里简化处理,实际应该使用随机字符串)
|
||||
nonce := fmt.Sprintf("%d", timestamp)
|
||||
|
||||
// 构建签名字符串
|
||||
signStr := fmt.Sprintf("%d\n%s", timestamp, s.secret)
|
||||
|
||||
// 计算签名
|
||||
h := hmac.New(sha256.New, []byte(s.secret))
|
||||
h.Write([]byte(signStr))
|
||||
signature := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
// 构建签名URL
|
||||
return fmt.Sprintf("%s×tamp=%d&nonce=%s&sign=%s",
|
||||
s.webhookURL, timestamp, nonce, signature)
|
||||
}
|
||||
|
||||
// SendSystemAlert 发送系统告警
|
||||
func (s *WeChatWorkService) SendSystemAlert(ctx context.Context, level, title, message string) error {
|
||||
s.logger.Info("发送系统告警",
|
||||
zap.String("level", level),
|
||||
zap.String("title", title),
|
||||
)
|
||||
|
||||
// 根据告警级别选择图标
|
||||
var icon string
|
||||
switch level {
|
||||
case "info":
|
||||
icon = "ℹ️"
|
||||
case "warning":
|
||||
icon = "⚠️"
|
||||
case "error":
|
||||
icon = "🚨"
|
||||
case "critical":
|
||||
icon = "💥"
|
||||
default:
|
||||
icon = "📢"
|
||||
}
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】%s 系统告警
|
||||
|
||||
**级别**: %s
|
||||
**标题**: %s
|
||||
**消息**: %s
|
||||
**时间**: %s
|
||||
|
||||
请相关人员及时处理。`,
|
||||
icon,
|
||||
level,
|
||||
title,
|
||||
message,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// SendDailyReport 发送每日报告
|
||||
func (s *WeChatWorkService) SendDailyReport(ctx context.Context, reportData map[string]interface{}) error {
|
||||
s.logger.Info("发送每日报告")
|
||||
|
||||
content := fmt.Sprintf(`## 【海宇数据】📊 企业认证系统每日报告
|
||||
|
||||
**报告日期**: %s
|
||||
|
||||
### 统计数据
|
||||
- **新增申请**: %d
|
||||
- **OCR识别成功**: %d
|
||||
- **OCR识别失败**: %d
|
||||
- **人脸识别成功**: %d
|
||||
- **人脸识别失败**: %d
|
||||
- **审核通过**: %d
|
||||
- **审核拒绝**: %d
|
||||
- **认证完成**: %d
|
||||
|
||||
### 系统状态
|
||||
- **系统运行时间**: %s
|
||||
- **API调用次数**: %d
|
||||
- **错误次数**: %d
|
||||
|
||||
祝您工作愉快!`,
|
||||
time.Now().Format("2006-01-02"),
|
||||
reportData["new_applications"],
|
||||
reportData["ocr_success"],
|
||||
reportData["ocr_failed"],
|
||||
reportData["face_verify_success"],
|
||||
reportData["face_verify_failed"],
|
||||
reportData["admin_approved"],
|
||||
reportData["admin_rejected"],
|
||||
reportData["certification_completed"],
|
||||
reportData["uptime"],
|
||||
reportData["api_calls"],
|
||||
reportData["errors"])
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
147
internal/infrastructure/external/notification/wechat_work_service_test.go
vendored
Normal file
147
internal/infrastructure/external/notification/wechat_work_service_test.go
vendored
Normal file
@@ -0,0 +1,147 @@
|
||||
package notification_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/infrastructure/external/notification"
|
||||
)
|
||||
|
||||
// newTestWeChatWorkService 创建用于测试的企业微信服务实例
|
||||
// 默认使用环境变量 WECOM_WEBHOOK,若未设置则使用项目配置中的 webhook。
|
||||
func newTestWeChatWorkService(t *testing.T) *notification.WeChatWorkService {
|
||||
t.Helper()
|
||||
|
||||
webhook := os.Getenv("WECOM_WEBHOOK")
|
||||
if webhook == "" {
|
||||
// 使用你提供的 webhook 地址
|
||||
webhook = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=649bf737-28ca-4f30-ad5f-cfb65b2af113"
|
||||
}
|
||||
|
||||
logger, _ := zap.NewDevelopment()
|
||||
return notification.NewWeChatWorkService(webhook, "", logger)
|
||||
}
|
||||
|
||||
// TestWeChatWork_SendAllBusinessNotifications
|
||||
// 手动运行该用例,将依次向企业微信群推送 5 种业务场景的通知:
|
||||
// 1. 用户充值成功
|
||||
// 2. 用户申请开发票
|
||||
// 3. 用户企业认证成功
|
||||
// 4. 用户余额低于阈值
|
||||
// 5. 用户余额欠费
|
||||
//
|
||||
// 注意:
|
||||
// - 通知中只使用企业名称和手机号码,不展示用户ID
|
||||
// - 默认使用示例企业名称和手机号,实际使用时请根据需要修改
|
||||
func TestWeChatWork_SendAllBusinessNotifications(t *testing.T) {
|
||||
svc := newTestWeChatWorkService(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 示例企业信息(实际可按需修改)
|
||||
enterpriseName := "测试企业有限公司"
|
||||
phone := "13800000000"
|
||||
|
||||
now := time.Now().Format("2006-01-02 15:04:05")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
content string
|
||||
}{
|
||||
{
|
||||
name: "recharge_success",
|
||||
content: fmt.Sprintf(
|
||||
"### 【海宇数据】用户充值成功通知\n"+
|
||||
"> 企业名称:%s\n"+
|
||||
"> 联系手机:%s\n"+
|
||||
"> 充值金额:%s 元\n"+
|
||||
"> 入账总额:%s 元(含赠送)\n"+
|
||||
"> 时间:%s\n",
|
||||
enterpriseName,
|
||||
phone,
|
||||
"1000.00",
|
||||
"1050.00",
|
||||
now,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "invoice_applied",
|
||||
content: fmt.Sprintf(
|
||||
"### 【海宇数据】用户申请开发票\n"+
|
||||
"> 企业名称:%s\n"+
|
||||
"> 联系手机:%s\n"+
|
||||
"> 申请开票金额:%s 元\n"+
|
||||
"> 发票类型:%s\n"+
|
||||
"> 申请时间:%s\n"+
|
||||
"\n请财务尽快审核并开具发票。",
|
||||
enterpriseName,
|
||||
phone,
|
||||
"500.00",
|
||||
"增值税专用发票",
|
||||
now,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "certification_completed",
|
||||
content: fmt.Sprintf(
|
||||
"### 【海宇数据】企业认证成功\n"+
|
||||
"> 企业名称:%s\n"+
|
||||
"> 联系手机:%s\n"+
|
||||
"> 完成时间:%s\n"+
|
||||
"\n该企业已完成认证,请相关同事同步更新内部系统并关注后续接入情况。",
|
||||
enterpriseName,
|
||||
phone,
|
||||
now,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "low_balance_alert",
|
||||
content: fmt.Sprintf(
|
||||
"### 【海宇数据】用户余额预警\n"+
|
||||
"<font color=\"warning\">用户余额已低于预警阈值,请及时跟进。</font>\n"+
|
||||
"> 企业名称:%s\n"+
|
||||
"> 联系手机:%s\n"+
|
||||
"> 当前余额:%s 元\n"+
|
||||
"> 预警阈值:%s 元\n"+
|
||||
"> 时间:%s\n",
|
||||
enterpriseName,
|
||||
phone,
|
||||
"180.00",
|
||||
"200.00",
|
||||
now,
|
||||
),
|
||||
},
|
||||
{
|
||||
name: "arrears_alert",
|
||||
content: fmt.Sprintf(
|
||||
"### 【海宇数据】用户余额欠费告警\n"+
|
||||
"<font color=\"warning\">该企业已发生欠费,请及时联系并处理。</font>\n"+
|
||||
"> 企业名称:%s\n"+
|
||||
"> 联系手机:%s\n"+
|
||||
"> 当前余额:%s 元\n"+
|
||||
"> 欠费金额:%s 元\n"+
|
||||
"> 时间:%s\n",
|
||||
enterpriseName,
|
||||
phone,
|
||||
"-50.00",
|
||||
"50.00",
|
||||
now,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if err := svc.SendMarkdownMessage(ctx, tc.content); err != nil {
|
||||
t.Fatalf("发送场景[%s]通知失败: %v", tc.name, err)
|
||||
}
|
||||
// 简单间隔,避免瞬时发送过多消息
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
})
|
||||
}
|
||||
}
|
||||
531
internal/infrastructure/external/ocr/baidu_ocr_service.go
vendored
Normal file
531
internal/infrastructure/external/ocr/baidu_ocr_service.go
vendored
Normal file
@@ -0,0 +1,531 @@
|
||||
package ocr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/application/certification/dto/responses"
|
||||
)
|
||||
|
||||
// BaiduOCRService 百度OCR服务
|
||||
type BaiduOCRService struct {
|
||||
apiKey string
|
||||
secretKey string
|
||||
endpoint string
|
||||
timeout time.Duration
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewBaiduOCRService 创建百度OCR服务
|
||||
func NewBaiduOCRService(apiKey, secretKey string, logger *zap.Logger) *BaiduOCRService {
|
||||
return &BaiduOCRService{
|
||||
apiKey: apiKey,
|
||||
secretKey: secretKey,
|
||||
endpoint: "https://aip.baidubce.com",
|
||||
timeout: 60 * time.Second,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RecognizeBusinessLicense 识别营业执照
|
||||
func (s *BaiduOCRService) RecognizeBusinessLicense(ctx context.Context, imageBytes []byte) (*responses.BusinessLicenseResult, error) {
|
||||
s.logger.Info("开始识别营业执照", zap.Int("image_size", len(imageBytes)))
|
||||
|
||||
// 获取访问令牌
|
||||
accessToken, err := s.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 将图片转换为base64并进行URL编码
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
|
||||
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
|
||||
|
||||
// 构建请求URL(只包含access_token)
|
||||
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/business_license?access_token=%s", s.endpoint, accessToken)
|
||||
|
||||
// 构建POST请求体
|
||||
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
|
||||
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("营业执照识别请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
|
||||
errorMsg := result["error_msg"].(string)
|
||||
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
|
||||
}
|
||||
|
||||
// 解析识别结果
|
||||
licenseResult := s.parseBusinessLicenseResult(result)
|
||||
|
||||
s.logger.Info("营业执照识别成功",
|
||||
zap.String("company_name", licenseResult.CompanyName),
|
||||
zap.String("legal_representative", licenseResult.LegalPersonName),
|
||||
zap.String("registered_capital", licenseResult.RegisteredCapital),
|
||||
)
|
||||
|
||||
return licenseResult, nil
|
||||
}
|
||||
|
||||
// RecognizeIDCard 识别身份证
|
||||
func (s *BaiduOCRService) RecognizeIDCard(ctx context.Context, imageBytes []byte, side string) (*responses.IDCardResult, error) {
|
||||
s.logger.Info("开始识别身份证", zap.String("side", side), zap.Int("image_size", len(imageBytes)))
|
||||
|
||||
// 获取访问令牌
|
||||
accessToken, err := s.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 将图片转换为base64并进行URL编码
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
|
||||
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
|
||||
|
||||
// 构建请求URL(只包含access_token)
|
||||
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/idcard?access_token=%s", s.endpoint, accessToken)
|
||||
|
||||
// 构建POST请求体
|
||||
payload := strings.NewReader(fmt.Sprintf("image=%s&side=%s", imageBase64UrlEncoded, side))
|
||||
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("身份证识别请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
|
||||
errorMsg := result["error_msg"].(string)
|
||||
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
|
||||
}
|
||||
|
||||
// 解析识别结果
|
||||
idCardResult := s.parseIDCardResult(result, side)
|
||||
|
||||
s.logger.Info("身份证识别成功",
|
||||
zap.String("name", idCardResult.Name),
|
||||
zap.String("id_number", idCardResult.IDCardNumber),
|
||||
zap.String("side", side),
|
||||
)
|
||||
|
||||
return idCardResult, nil
|
||||
}
|
||||
|
||||
// RecognizeGeneralText 通用文字识别
|
||||
func (s *BaiduOCRService) RecognizeGeneralText(ctx context.Context, imageBytes []byte) (*responses.GeneralTextResult, error) {
|
||||
s.logger.Info("开始通用文字识别", zap.Int("image_size", len(imageBytes)))
|
||||
|
||||
// 获取访问令牌
|
||||
accessToken, err := s.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 将图片转换为base64并进行URL编码
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
|
||||
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
|
||||
|
||||
// 构建请求URL(只包含access_token)
|
||||
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/general_basic?access_token=%s", s.endpoint, accessToken)
|
||||
|
||||
// 构建POST请求体
|
||||
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
|
||||
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("通用文字识别请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
|
||||
errorMsg := result["error_msg"].(string)
|
||||
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
|
||||
}
|
||||
|
||||
// 解析识别结果
|
||||
textResult := s.parseGeneralTextResult(result)
|
||||
|
||||
s.logger.Info("通用文字识别成功",
|
||||
zap.Int("word_count", len(textResult.Words)),
|
||||
zap.Float64("confidence", textResult.Confidence),
|
||||
)
|
||||
|
||||
return textResult, nil
|
||||
}
|
||||
|
||||
// RecognizeFromURL 从URL识别图片
|
||||
func (s *BaiduOCRService) RecognizeFromURL(ctx context.Context, imageURL string, ocrType string) (interface{}, error) {
|
||||
s.logger.Info("从URL识别图片", zap.String("url", imageURL), zap.String("type", ocrType))
|
||||
|
||||
// 下载图片
|
||||
imageBytes, err := s.downloadImage(ctx, imageURL)
|
||||
if err != nil {
|
||||
s.logger.Error("下载图片失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("下载图片失败: %w", err)
|
||||
}
|
||||
|
||||
// 根据类型调用相应的识别方法
|
||||
switch ocrType {
|
||||
case "business_license":
|
||||
return s.RecognizeBusinessLicense(ctx, imageBytes)
|
||||
case "idcard_front":
|
||||
return s.RecognizeIDCard(ctx, imageBytes, "front")
|
||||
case "idcard_back":
|
||||
return s.RecognizeIDCard(ctx, imageBytes, "back")
|
||||
case "general_text":
|
||||
return s.RecognizeGeneralText(ctx, imageBytes)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的OCR类型: %s", ocrType)
|
||||
}
|
||||
}
|
||||
|
||||
// getAccessToken 获取百度API访问令牌
|
||||
func (s *BaiduOCRService) getAccessToken(ctx context.Context) (string, error) {
|
||||
// 构建获取访问令牌的URL
|
||||
tokenURL := fmt.Sprintf("%s/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
||||
s.endpoint, s.apiKey, s.secretKey)
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.sendRequest(ctx, "POST", tokenURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取访问令牌请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return "", fmt.Errorf("解析访问令牌响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error"].(string); ok && errCode != "" {
|
||||
errorDesc := result["error_description"].(string)
|
||||
return "", fmt.Errorf("获取访问令牌失败: %s - %s", errCode, errorDesc)
|
||||
}
|
||||
|
||||
// 提取访问令牌
|
||||
accessToken, ok := result["access_token"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("响应中未找到访问令牌")
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// sendRequest 发送HTTP请求
|
||||
func (s *BaiduOCRService) sendRequest(ctx context.Context, method, url string, body io.Reader) ([]byte, error) {
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: s.timeout,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("User-Agent", "hyapi-server/1.0")
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error();
|
||||
errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
return nil, fmt.Errorf("API请求超时: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应内容
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应内容失败: %w", err)
|
||||
}
|
||||
|
||||
return responseBody, nil
|
||||
}
|
||||
|
||||
// parseBusinessLicenseResult 解析营业执照识别结果
|
||||
func (s *BaiduOCRService) parseBusinessLicenseResult(result map[string]interface{}) *responses.BusinessLicenseResult {
|
||||
wordsResult := result["words_result"].(map[string]interface{})
|
||||
|
||||
// 提取企业信息
|
||||
companyName := ""
|
||||
if companyNameObj, ok := wordsResult["单位名称"].(map[string]interface{}); ok {
|
||||
companyName = companyNameObj["words"].(string)
|
||||
}
|
||||
|
||||
unifiedSocialCode := ""
|
||||
if socialCreditCodeObj, ok := wordsResult["社会信用代码"].(map[string]interface{}); ok {
|
||||
unifiedSocialCode = socialCreditCodeObj["words"].(string)
|
||||
}
|
||||
|
||||
legalPersonName := ""
|
||||
if legalPersonObj, ok := wordsResult["法人"].(map[string]interface{}); ok {
|
||||
legalPersonName = legalPersonObj["words"].(string)
|
||||
}
|
||||
|
||||
// 提取注册资本等其他信息
|
||||
registeredCapital := ""
|
||||
if registeredCapitalObj, ok := wordsResult["注册资本"].(map[string]interface{}); ok {
|
||||
registeredCapital = registeredCapitalObj["words"].(string)
|
||||
}
|
||||
|
||||
// 提取企业地址
|
||||
address := ""
|
||||
if addressObj, ok := wordsResult["地址"].(map[string]interface{}); ok {
|
||||
address = addressObj["words"].(string)
|
||||
}
|
||||
|
||||
// 计算置信度(这里简化处理,实际应该从OCR结果中获取)
|
||||
confidence := 0.9 // 默认置信度
|
||||
|
||||
return &responses.BusinessLicenseResult{
|
||||
CompanyName: companyName,
|
||||
UnifiedSocialCode: unifiedSocialCode,
|
||||
LegalPersonName: legalPersonName,
|
||||
LegalPersonID: "", // 营业执照上没有法人身份证号
|
||||
RegisteredCapital: registeredCapital,
|
||||
Address: address,
|
||||
Confidence: confidence,
|
||||
ProcessedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// parseIDCardResult 解析身份证识别结果
|
||||
func (s *BaiduOCRService) parseIDCardResult(result map[string]interface{}, side string) *responses.IDCardResult {
|
||||
wordsResult := result["words_result"].(map[string]interface{})
|
||||
|
||||
idCardResult := &responses.IDCardResult{
|
||||
Side: side,
|
||||
Confidence: s.extractConfidence(result),
|
||||
}
|
||||
|
||||
if side == "front" {
|
||||
if name, ok := wordsResult["姓名"]; ok {
|
||||
if word, ok := name.(map[string]interface{}); ok {
|
||||
idCardResult.Name = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if gender, ok := wordsResult["性别"]; ok {
|
||||
if word, ok := gender.(map[string]interface{}); ok {
|
||||
idCardResult.Gender = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if nation, ok := wordsResult["民族"]; ok {
|
||||
if word, ok := nation.(map[string]interface{}); ok {
|
||||
idCardResult.Nation = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if birthday, ok := wordsResult["出生"]; ok {
|
||||
if word, ok := birthday.(map[string]interface{}); ok {
|
||||
idCardResult.Birthday = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if address, ok := wordsResult["住址"]; ok {
|
||||
if word, ok := address.(map[string]interface{}); ok {
|
||||
idCardResult.Address = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if idNumber, ok := wordsResult["公民身份号码"]; ok {
|
||||
if word, ok := idNumber.(map[string]interface{}); ok {
|
||||
idCardResult.IDCardNumber = word["words"].(string)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if issuingAgency, ok := wordsResult["签发机关"]; ok {
|
||||
if word, ok := issuingAgency.(map[string]interface{}); ok {
|
||||
idCardResult.IssuingAgency = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if validPeriod, ok := wordsResult["有效期限"]; ok {
|
||||
if word, ok := validPeriod.(map[string]interface{}); ok {
|
||||
idCardResult.ValidPeriod = word["words"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return idCardResult
|
||||
}
|
||||
|
||||
// parseGeneralTextResult 解析通用文字识别结果
|
||||
func (s *BaiduOCRService) parseGeneralTextResult(result map[string]interface{}) *responses.GeneralTextResult {
|
||||
wordsResult := result["words_result"].([]interface{})
|
||||
|
||||
textResult := &responses.GeneralTextResult{
|
||||
Confidence: s.extractConfidence(result),
|
||||
Words: make([]responses.TextLine, 0, len(wordsResult)),
|
||||
}
|
||||
|
||||
for _, word := range wordsResult {
|
||||
if wordMap, ok := word.(map[string]interface{}); ok {
|
||||
line := responses.TextLine{
|
||||
Text: wordMap["words"].(string),
|
||||
Confidence: 1.0, // 百度返回的通用文字识别没有单独置信度
|
||||
}
|
||||
textResult.Words = append(textResult.Words, line)
|
||||
}
|
||||
}
|
||||
|
||||
return textResult
|
||||
}
|
||||
|
||||
// extractConfidence 提取置信度
|
||||
func (s *BaiduOCRService) extractConfidence(result map[string]interface{}) float64 {
|
||||
if confidence, ok := result["confidence"].(float64); ok {
|
||||
return confidence
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// extractWords 提取识别的文字
|
||||
func (s *BaiduOCRService) extractWords(result map[string]interface{}) []string {
|
||||
words := make([]string, 0)
|
||||
|
||||
if wordsResult, ok := result["words_result"]; ok {
|
||||
switch v := wordsResult.(type) {
|
||||
case map[string]interface{}:
|
||||
// 营业执照等结构化文档
|
||||
for _, word := range v {
|
||||
if wordMap, ok := word.(map[string]interface{}); ok {
|
||||
if wordsStr, ok := wordMap["words"].(string); ok {
|
||||
words = append(words, wordsStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
// 通用文字识别
|
||||
for _, word := range v {
|
||||
if wordMap, ok := word.(map[string]interface{}); ok {
|
||||
if wordsStr, ok := wordMap["words"].(string); ok {
|
||||
words = append(words, wordsStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return words
|
||||
}
|
||||
|
||||
// downloadImage 下载图片
|
||||
func (s *BaiduOCRService) downloadImage(ctx context.Context, imageURL string) ([]byte, error) {
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载图片失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("下载图片失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应内容
|
||||
imageBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取图片内容失败: %w", err)
|
||||
}
|
||||
|
||||
return imageBytes, nil
|
||||
}
|
||||
|
||||
// ValidateBusinessLicense 验证营业执照识别结果
|
||||
func (s *BaiduOCRService) ValidateBusinessLicense(result *responses.BusinessLicenseResult) error {
|
||||
if result.Confidence < 0.8 {
|
||||
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
|
||||
}
|
||||
if result.CompanyName == "" {
|
||||
return fmt.Errorf("未能识别公司名称")
|
||||
}
|
||||
if result.LegalPersonName == "" {
|
||||
return fmt.Errorf("未能识别法定代表人")
|
||||
}
|
||||
if result.UnifiedSocialCode == "" {
|
||||
return fmt.Errorf("未能识别统一社会信用代码")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateIDCard 验证身份证识别结果
|
||||
func (s *BaiduOCRService) ValidateIDCard(result *responses.IDCardResult) error {
|
||||
if result.Confidence < 0.8 {
|
||||
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
|
||||
}
|
||||
if result.Side == "front" {
|
||||
if result.Name == "" {
|
||||
return fmt.Errorf("未能识别姓名")
|
||||
}
|
||||
if result.IDCardNumber == "" {
|
||||
return fmt.Errorf("未能识别身份证号码")
|
||||
}
|
||||
} else {
|
||||
if result.IssuingAgency == "" {
|
||||
return fmt.Errorf("未能识别签发机关")
|
||||
}
|
||||
if result.ValidPeriod == "" {
|
||||
return fmt.Errorf("未能识别有效期限")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
164
internal/infrastructure/external/pdfgen/pdfgen_service.go
vendored
Normal file
164
internal/infrastructure/external/pdfgen/pdfgen_service.go
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
package pdfgen
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// PDFGenService PDF生成服务客户端
|
||||
type PDFGenService struct {
|
||||
baseURL string
|
||||
apiPath string
|
||||
logger *zap.Logger
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewPDFGenService 创建PDF生成服务客户端
|
||||
func NewPDFGenService(cfg *config.Config, logger *zap.Logger) *PDFGenService {
|
||||
// 根据环境选择服务地址
|
||||
var baseURL string
|
||||
if cfg.App.IsProduction() {
|
||||
baseURL = cfg.PDFGen.ProductionURL
|
||||
} else {
|
||||
baseURL = cfg.PDFGen.DevelopmentURL
|
||||
}
|
||||
|
||||
// 如果配置为空,使用默认值
|
||||
if baseURL == "" {
|
||||
if cfg.App.IsProduction() {
|
||||
baseURL = "http://localhost:15990"
|
||||
} else {
|
||||
baseURL = "http://101.43.41.217:15990"
|
||||
}
|
||||
}
|
||||
|
||||
// 获取API路径,如果为空使用默认值
|
||||
apiPath := cfg.PDFGen.APIPath
|
||||
if apiPath == "" {
|
||||
apiPath = "/api/v1/generate/guangzhou"
|
||||
}
|
||||
|
||||
// 获取超时时间,如果为0使用默认值
|
||||
timeout := cfg.PDFGen.Timeout
|
||||
if timeout == 0 {
|
||||
timeout = 120 * time.Second
|
||||
}
|
||||
|
||||
logger.Info("PDF生成服务已初始化",
|
||||
zap.String("base_url", baseURL),
|
||||
zap.String("api_path", apiPath),
|
||||
zap.Duration("timeout", timeout),
|
||||
)
|
||||
|
||||
return &PDFGenService{
|
||||
baseURL: baseURL,
|
||||
apiPath: apiPath,
|
||||
logger: logger,
|
||||
client: &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
Proxy: nil, // 不使用任何代理
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// GeneratePDFRequest PDF生成请求
|
||||
type GeneratePDFRequest struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
ReportNumber string `json:"report_number,omitempty"`
|
||||
GenerateTime string `json:"generate_time,omitempty"`
|
||||
}
|
||||
|
||||
// GeneratePDFResponse PDF生成响应
|
||||
type GeneratePDFResponse struct {
|
||||
PDFBytes []byte
|
||||
FileName string
|
||||
}
|
||||
|
||||
// GenerateGuangzhouPDF 生成广州大数据租赁风险PDF报告
|
||||
func (s *PDFGenService) GenerateGuangzhouPDF(ctx context.Context, req *GeneratePDFRequest) (*GeneratePDFResponse, error) {
|
||||
// 构建请求体
|
||||
reqBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建请求URL
|
||||
url := fmt.Sprintf("%s%s", s.baseURL, s.apiPath)
|
||||
|
||||
// 创建HTTP请求
|
||||
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBody))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
start := time.Now()
|
||||
|
||||
// 发送请求
|
||||
s.logger.Info("开始调用PDF生成服务",
|
||||
zap.String("url", url),
|
||||
zap.Int("data_count", len(req.Data)),
|
||||
zap.ByteString("reqBody", reqBody),
|
||||
)
|
||||
|
||||
resp, err := s.client.Do(httpReq)
|
||||
if err != nil {
|
||||
s.logger.Error("调用PDF生成服务失败",
|
||||
zap.String("url", url),
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("调用PDF生成服务失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查HTTP状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// 尝试读取错误信息
|
||||
errorBody, _ := io.ReadAll(resp.Body)
|
||||
s.logger.Error("PDF生成服务返回错误",
|
||||
zap.String("url", url),
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
zap.String("error_body", string(errorBody)),
|
||||
)
|
||||
return nil, fmt.Errorf("PDF生成失败,状态码: %d, 错误: %s", resp.StatusCode, string(errorBody))
|
||||
}
|
||||
|
||||
// 读取PDF文件
|
||||
pdfBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取PDF文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 生成文件名
|
||||
fileName := "大数据租赁风险报告.pdf"
|
||||
if req.ReportNumber != "" {
|
||||
fileName = fmt.Sprintf("%s.pdf", req.ReportNumber)
|
||||
}
|
||||
|
||||
s.logger.Info("PDF生成成功",
|
||||
zap.String("url", url),
|
||||
zap.String("file_name", fileName),
|
||||
zap.Int("file_size", len(pdfBytes)),
|
||||
zap.Duration("duration", time.Since(start)),
|
||||
)
|
||||
|
||||
return &GeneratePDFResponse{
|
||||
PDFBytes: pdfBytes,
|
||||
FileName: fileName,
|
||||
}, nil
|
||||
}
|
||||
47
internal/infrastructure/external/shujubao/crypto.go
vendored
Normal file
47
internal/infrastructure/external/shujubao/crypto.go
vendored
Normal file
@@ -0,0 +1,47 @@
|
||||
package shujubao
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SignMethod 签名方法
|
||||
type SignMethod string
|
||||
|
||||
const (
|
||||
SignMethodMD5 SignMethod = "md5"
|
||||
SignMethodHMACMD5 SignMethod = "hmac"
|
||||
)
|
||||
|
||||
// GenerateSignMD5 使用 MD5 生成签名:md5(app_secret + timestamp),32 位小写
|
||||
func GenerateSignMD5(appSecret, timestamp string) string {
|
||||
h := md5.Sum([]byte(appSecret + timestamp))
|
||||
sign := strings.ToLower(hex.EncodeToString(h[:]))
|
||||
return sign
|
||||
}
|
||||
|
||||
// GenerateSignHMAC 使用 HMAC-MD5 生成签名(仅 timestamp,兼容旧逻辑)
|
||||
func GenerateSignHMAC(appSecret, timestamp string) string {
|
||||
mac := hmac.New(md5.New, []byte(appSecret))
|
||||
mac.Write([]byte(timestamp))
|
||||
sign := strings.ToLower(hex.EncodeToString(mac.Sum(nil)))
|
||||
return sign
|
||||
}
|
||||
|
||||
// GenerateSignFromParamsMD5 根据入参生成签名:入参按 ASCII 排序组合后与 app_secret 做 MD5。
|
||||
// sortedParamStr 格式为 key1=value1&key2=value2&...(key 按字母序)。
|
||||
func GenerateSignFromParamsMD5(appSecret, sortedParamStr string) string {
|
||||
h := md5.Sum([]byte(appSecret + sortedParamStr))
|
||||
sign := strings.ToLower(hex.EncodeToString(h[:]))
|
||||
return sign
|
||||
}
|
||||
|
||||
// GenerateSignFromParamsHMAC 根据入参生成签名:入参按 ASCII 排序组合后与 app_secret 做 HMAC-MD5。
|
||||
func GenerateSignFromParamsHMAC(appSecret, sortedParamStr string) string {
|
||||
mac := hmac.New(md5.New, []byte(appSecret))
|
||||
mac.Write([]byte(sortedParamStr))
|
||||
sign := strings.ToLower(hex.EncodeToString(mac.Sum(nil)))
|
||||
return sign
|
||||
}
|
||||
135
internal/infrastructure/external/shujubao/shujubao_errors.go
vendored
Normal file
135
internal/infrastructure/external/shujubao/shujubao_errors.go
vendored
Normal file
@@ -0,0 +1,135 @@
|
||||
package shujubao
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// GetQueryEmptyErrByCode 将数据宝错误码归类为“查询为空/不扣费”错误。
|
||||
// 说明:上游通常依赖 errors.Is(err, ErrQueryEmpty) 来决定是否扣费。
|
||||
func GetQueryEmptyErrByCode(code string) error {
|
||||
switch code {
|
||||
case "10001", "10006":
|
||||
return ErrQueryEmpty
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ShujubaoError 数据宝服务错误
|
||||
type ShujubaoError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Error 实现 error 接口
|
||||
func (e *ShujubaoError) Error() string {
|
||||
return fmt.Sprintf("数据宝错误 [%s]: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// IsSuccess 检查是否成功
|
||||
func (e *ShujubaoError) IsSuccess() bool {
|
||||
return e.Code == "200" || e.Code == "0" || e.Code == "10000"
|
||||
}
|
||||
|
||||
// NewShujubaoError 创建新的数据宝错误
|
||||
func NewShujubaoError(code, message string) *ShujubaoError {
|
||||
return &ShujubaoError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// 数据宝全系统错误码与描述映射(Code -> Desc)
|
||||
var systemErrorCodeDesc = map[string]string{
|
||||
"10000": "成功",
|
||||
"10001": "查空",
|
||||
"10002": "查询失败",
|
||||
"10003": "系统处理异常",
|
||||
"10004": "系统处理超时",
|
||||
"10005": "服务异常",
|
||||
"10006": "查无",
|
||||
"10017": "查询失败",
|
||||
"10018": "参数错误",
|
||||
"10019": "系统异常",
|
||||
"10020": "同一参数请求次数超限",
|
||||
"99999": "其他错误",
|
||||
"999": "接口处理异常",
|
||||
"000": "key参数不能为空",
|
||||
"001": "找不到这个key",
|
||||
"002": "调用次数已用完",
|
||||
"003": "用户该接口状态不可用",
|
||||
"004": "接口信息不存在",
|
||||
"005": "你没有认证信息",
|
||||
"008": "当前接口只允许“企业认证”通过的账户进行调用,请在数据宝官网个人中心进行企业认证后再进行调用,谢谢!",
|
||||
"009": "触发风控",
|
||||
"011": "接口缺少参数",
|
||||
"012": "没有ip访问权限",
|
||||
"013": "接口模板不存在",
|
||||
"015": "该接口已下架",
|
||||
"020": "调用第三方产生异常",
|
||||
"022": "调用第三方返回的数据格式错误",
|
||||
"025": "你没有购买此接口",
|
||||
"026": "用户信息不存在",
|
||||
"027": "请求第三方地址超时,请稍后再试",
|
||||
"028": "请求第三方地址被拒绝,请稍后再试",
|
||||
"034": "签名不合法",
|
||||
"035": "请求参数加密有误",
|
||||
"036": "验签失败",
|
||||
"037": "timestamp不能为空",
|
||||
"038": "请求繁忙,请稍后联系管理员再试",
|
||||
"039": "请在个人中心接口设置加密状态",
|
||||
"040": "timestamp不合法",
|
||||
"041": "timestamp已过期",
|
||||
"042": "身份证手机号姓名银行卡等不符合规则",
|
||||
"043": "该号段不支持验证",
|
||||
"047": "请在个人中心获取密钥",
|
||||
"048": "找不到这个secretKey",
|
||||
"049": "用户还未申购该产品",
|
||||
"050": "请联系客服开启验签",
|
||||
"051": "超过当日调用次数",
|
||||
"052": "机房限制调用,请联系客服切换其他机房",
|
||||
"053": "系统错误",
|
||||
"054": "token无效",
|
||||
"055": "配置信息未完善,请联系数据宝工作人员",
|
||||
"056": "apiName参数不能为空",
|
||||
"057": "并发量超过限制,请联系客服",
|
||||
"058": "撞库风控预警,请联系客服",
|
||||
}
|
||||
|
||||
// GetSystemErrorDesc 根据错误码获取系统错误描述(支持带 SYSTEM_ 前缀或纯数字)
|
||||
func GetSystemErrorDesc(code string) string {
|
||||
// 去掉 SYSTEM_ 前缀
|
||||
key := code
|
||||
if len(code) > 7 && code[:7] == "SYSTEM_" {
|
||||
key = code[7:]
|
||||
}
|
||||
if desc, ok := systemErrorCodeDesc[key]; ok {
|
||||
return desc
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// NewShujubaoErrorFromCode 根据状态码创建错误
|
||||
func NewShujubaoErrorFromCode(code, message string) *ShujubaoError {
|
||||
if message != "" {
|
||||
return NewShujubaoError(code, message)
|
||||
}
|
||||
if desc := GetSystemErrorDesc(code); desc != "" {
|
||||
return NewShujubaoError(code, desc)
|
||||
}
|
||||
return NewShujubaoError(code, "未知错误")
|
||||
}
|
||||
|
||||
// IsShujubaoError 检查是否是数据宝错误
|
||||
func IsShujubaoError(err error) bool {
|
||||
_, ok := err.(*ShujubaoError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetShujubaoError 获取数据宝错误
|
||||
func GetShujubaoError(err error) *ShujubaoError {
|
||||
if shujubaoErr, ok := err.(*ShujubaoError); ok {
|
||||
return shujubaoErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
66
internal/infrastructure/external/shujubao/shujubao_factory.go
vendored
Normal file
66
internal/infrastructure/external/shujubao/shujubao_factory.go
vendored
Normal file
@@ -0,0 +1,66 @@
|
||||
package shujubao
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewShujubaoServiceWithConfig 使用配置创建数据宝服务
|
||||
func NewShujubaoServiceWithConfig(cfg *config.Config) (*ShujubaoService, error) {
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.Shujubao.Logging.Enabled,
|
||||
LogDir: cfg.Shujubao.Logging.LogDir,
|
||||
ServiceName: "shujubao",
|
||||
UseDaily: cfg.Shujubao.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.Shujubao.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
for k, v := range cfg.Shujubao.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[k] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: v.MaxSize,
|
||||
MaxBackups: v.MaxBackups,
|
||||
MaxAge: v.MaxAge,
|
||||
Compress: v.Compress,
|
||||
}
|
||||
}
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var signMethod SignMethod
|
||||
if cfg.Shujubao.SignMethod == "md5" {
|
||||
signMethod = SignMethodMD5
|
||||
} else {
|
||||
signMethod = SignMethodHMACMD5
|
||||
}
|
||||
timeout := 60 * time.Second
|
||||
if cfg.Shujubao.Timeout > 0 {
|
||||
timeout = cfg.Shujubao.Timeout
|
||||
}
|
||||
|
||||
return NewShujubaoService(
|
||||
cfg.Shujubao.URL,
|
||||
cfg.Shujubao.AppSecret,
|
||||
signMethod,
|
||||
timeout,
|
||||
logger,
|
||||
), nil
|
||||
}
|
||||
|
||||
// NewShujubaoServiceWithLogging 使用自定义日志配置创建数据宝服务
|
||||
func NewShujubaoServiceWithLogging(url, appSecret string, signMethod SignMethod, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig) (*ShujubaoService, error) {
|
||||
loggingConfig.ServiceName = "shujubao"
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewShujubaoService(url, appSecret, signMethod, timeout, logger), nil
|
||||
}
|
||||
|
||||
// NewShujubaoServiceSimple 创建无日志的数据宝服务
|
||||
func NewShujubaoServiceSimple(url, appSecret string, signMethod SignMethod, timeout time.Duration) *ShujubaoService {
|
||||
return NewShujubaoService(url, appSecret, signMethod, timeout, nil)
|
||||
}
|
||||
313
internal/infrastructure/external/shujubao/shujubao_service.go
vendored
Normal file
313
internal/infrastructure/external/shujubao/shujubao_service.go
vendored
Normal file
@@ -0,0 +1,313 @@
|
||||
package shujubao
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// 错误日志中单条入参值的最大长度,避免 base64 等长内容打满日志
|
||||
maxLogParamValueLen = 300
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
ErrQueryEmpty = errors.New("查询为空")
|
||||
)
|
||||
|
||||
// truncateForLog 将字符串截断到指定长度,用于错误日志,避免 base64 等过长内容
|
||||
func truncateForLog(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return s
|
||||
}
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "...[truncated, total " + strconv.Itoa(len(s)) + " chars]"
|
||||
}
|
||||
|
||||
// paramsForLog 返回适合写入错误日志的入参副本(长字符串会被截断)
|
||||
func paramsForLog(params map[string]interface{}) map[string]interface{} {
|
||||
if params == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]interface{}, len(params))
|
||||
for k, v := range params {
|
||||
if v == nil {
|
||||
out[k] = nil
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
out[k] = truncateForLog(val, maxLogParamValueLen)
|
||||
default:
|
||||
s := fmt.Sprint(v)
|
||||
out[k] = truncateForLog(s, maxLogParamValueLen)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ShujubaoResp 数据宝 API 通用响应(按实际文档调整)
|
||||
type ShujubaoResp struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
// ShujubaoConfig 数据宝服务配置
|
||||
type ShujubaoConfig struct {
|
||||
URL string
|
||||
AppSecret string
|
||||
SignMethod SignMethod
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// ShujubaoService 数据宝服务
|
||||
type ShujubaoService struct {
|
||||
config ShujubaoConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewShujubaoService 创建数据宝服务实例
|
||||
func NewShujubaoService(url, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *ShujubaoService {
|
||||
if signMethod == "" {
|
||||
signMethod = SignMethodHMACMD5
|
||||
}
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
return &ShujubaoService{
|
||||
config: ShujubaoConfig{
|
||||
URL: url,
|
||||
AppSecret: appSecret,
|
||||
SignMethod: signMethod,
|
||||
Timeout: timeout,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求 ID
|
||||
func (s *ShujubaoService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, s.config.AppSecret)))
|
||||
return fmt.Sprintf("shujubao_%x", hash[:8])
|
||||
}
|
||||
|
||||
// buildSortedParamStr 将入参按 key 的 ASCII 排序组合为 key1=value1&key2=value2&...
|
||||
func buildSortedParamStr(params map[string]interface{}) string {
|
||||
if len(params) == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := make([]string, 0, len(params))
|
||||
for k := range params {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var b strings.Builder
|
||||
for i, k := range keys {
|
||||
if i > 0 {
|
||||
b.WriteByte('&')
|
||||
}
|
||||
v := params[k]
|
||||
var vs string
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
vs = val
|
||||
case nil:
|
||||
vs = ""
|
||||
default:
|
||||
vs = fmt.Sprint(val)
|
||||
}
|
||||
b.WriteString(k)
|
||||
b.WriteByte('=')
|
||||
b.WriteString(vs)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// buildFormUrlEncodedBody 按 key 的 ASCII 排序构建 application/x-www-form-urlencoded 请求体(键与值均已 URL 编码)
|
||||
func buildFormUrlEncodedBody(params map[string]interface{}) string {
|
||||
if len(params) == 0 {
|
||||
return ""
|
||||
}
|
||||
keys := make([]string, 0, len(params))
|
||||
for k := range params {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
var b strings.Builder
|
||||
for i, k := range keys {
|
||||
if i > 0 {
|
||||
b.WriteByte('&')
|
||||
}
|
||||
v := params[k]
|
||||
var vs string
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
vs = val
|
||||
case nil:
|
||||
vs = ""
|
||||
default:
|
||||
vs = fmt.Sprint(val)
|
||||
}
|
||||
b.WriteString(url.QueryEscape(k))
|
||||
b.WriteByte('=')
|
||||
b.WriteString(url.QueryEscape(vs))
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// generateSign 根据入参与时间戳生成签名。入参按 ASCII 排序组合后与 app_secret 做 MD5/HMAC。
|
||||
// 对于开启了加密的接口需传 sign 与 timestamp;明文传输的接口则无需传这两个参数。
|
||||
func (s *ShujubaoService) generateSign(timestamp string, params map[string]interface{}) string {
|
||||
// 合并 timestamp 到入参后参与排序
|
||||
merged := make(map[string]interface{}, len(params)+1)
|
||||
for k, v := range params {
|
||||
merged[k] = v
|
||||
}
|
||||
merged["timestamp"] = timestamp
|
||||
sortedStr := buildSortedParamStr(merged)
|
||||
switch s.config.SignMethod {
|
||||
case SignMethodMD5:
|
||||
return GenerateSignFromParamsMD5(s.config.AppSecret, sortedStr)
|
||||
default:
|
||||
return GenerateSignFromParamsHMAC(s.config.AppSecret, sortedStr)
|
||||
}
|
||||
}
|
||||
|
||||
// buildRequestURL 拼接接口地址得到最终请求 URL,如 https://api.chinadatapay.com/communication/personal/197
|
||||
func (s *ShujubaoService) buildRequestURL(apiPath string) string {
|
||||
base := strings.TrimSuffix(s.config.URL, "/")
|
||||
if apiPath == "" {
|
||||
return base
|
||||
}
|
||||
return base + "/" + strings.TrimPrefix(apiPath, "/")
|
||||
}
|
||||
|
||||
// CallAPI 调用数据宝 API(POST)。最终请求地址 = url + 拼接接口地址值;body 为业务参数;sign、timestamp 按原样传 header。
|
||||
func (s *ShujubaoService) CallAPI(ctx context.Context, apiPath string, params map[string]interface{}) (data interface{}, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := s.generateRequestID()
|
||||
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
|
||||
// 最终请求 URL = https://api.chinadatapay.com/communication + 拼接接口地址值,如 /personal/197
|
||||
requestURL := s.buildRequestURL(apiPath)
|
||||
|
||||
var transactionID string
|
||||
if id, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = id
|
||||
}
|
||||
|
||||
if s.logger != nil {
|
||||
s.logger.LogRequest(requestID, transactionID, apiPath, requestURL)
|
||||
}
|
||||
|
||||
// 使用 application/x-www-form-urlencoded,贵司接口暂不支持 JSON 入参
|
||||
formBody := buildFormUrlEncodedBody(params)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, strings.NewReader(formBody))
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("timestamp", timestamp)
|
||||
req.Header.Set("sign", s.generateSign(timestamp, params))
|
||||
|
||||
client := &http.Client{Timeout: s.config.Timeout}
|
||||
response, err := client.Do(req)
|
||||
if err != nil {
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
if isTimeout {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
}
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
respBody, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.logger != nil {
|
||||
duration := time.Since(startTime)
|
||||
s.logger.LogResponse(requestID, transactionID, apiPath, response.StatusCode, duration)
|
||||
}
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shujubaoResp ShujubaoResp
|
||||
if err := json.Unmarshal(respBody, &shujubaoResp); err != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
code := shujubaoResp.Code
|
||||
|
||||
// 成功码只有这三类:其它 code 都走统一错误映射返回
|
||||
if code != "10000" && code != "200" && code != "0" {
|
||||
shujubaoErr := NewShujubaoErrorFromCode(code, shujubaoResp.Message)
|
||||
if queryEmptyErr := GetQueryEmptyErrByCode(code); queryEmptyErr != nil {
|
||||
err = errors.Join(queryEmptyErr, shujubaoErr)
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, shujubaoErr, paramsForLog(params))
|
||||
}
|
||||
return nil, errors.Join(ErrDatasource, shujubaoErr)
|
||||
}
|
||||
|
||||
return shujubaoResp.Data, nil
|
||||
}
|
||||
199
internal/infrastructure/external/shumai/crypto.go
vendored
Normal file
199
internal/infrastructure/external/shumai/crypto.go
vendored
Normal file
@@ -0,0 +1,199 @@
|
||||
package shumai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SignMethod 签名方法
|
||||
type SignMethod string
|
||||
|
||||
const (
|
||||
SignMethodMD5 SignMethod = "md5"
|
||||
SignMethodHMACMD5 SignMethod = "hmac"
|
||||
)
|
||||
|
||||
// GenerateSignForm 生成表单接口签名(appid & timestamp & app_security)
|
||||
// 拼接规则:appid + "&" + timestamp + "&" + app_security,对拼接串做 MD5,32 位小写十六进制;
|
||||
// 不足 32 位左侧补 0。
|
||||
func GenerateSignForm(appid, timestamp, appSecret string) string {
|
||||
str := appid + "&" + timestamp + "&" + appSecret
|
||||
hash := md5.Sum([]byte(str))
|
||||
sign := strings.ToLower(hex.EncodeToString(hash[:]))
|
||||
if n := 32 - len(sign); n > 0 {
|
||||
sign = strings.Repeat("0", n) + sign
|
||||
}
|
||||
return sign
|
||||
}
|
||||
|
||||
// app_secret: "BnJWo61hUgNEa5fqBCueiT1IZ1e0DxPU"
|
||||
|
||||
// Encrypt 使用 AES/ECB/PKCS5Padding 加密数据
|
||||
// 加密算法:AES,工作模式:ECB(无初始向量),填充方式:PKCS5Padding
|
||||
// 加密 key 是服务商分配的 app_security,AES 加密之后再进行 base64 编码
|
||||
func Encrypt(data, appSecurity string) (string, error) {
|
||||
key := prepareAESKey([]byte(appSecurity))
|
||||
ciphertext, err := aesEncryptECB([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt 解密 base64 编码的 AES/ECB/PKCS5Padding 加密数据
|
||||
func Decrypt(encodedData, appSecurity string) ([]byte, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := prepareAESKey([]byte(appSecurity))
|
||||
plaintext, err := aesDecryptECB(ciphertext, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// prepareAESKey 准备 AES 密钥,确保长度为 16/24/32 字节
|
||||
// 如果 key 长度不足,用 0 填充;如果过长,截取前 32 字节
|
||||
func prepareAESKey(key []byte) []byte {
|
||||
keyLen := len(key)
|
||||
if keyLen == 16 || keyLen == 24 || keyLen == 32 {
|
||||
return key
|
||||
}
|
||||
if keyLen < 16 {
|
||||
// 不足 16 字节,用 0 填充到 16 字节(AES-128)
|
||||
padded := make([]byte, 16)
|
||||
copy(padded, key)
|
||||
return padded
|
||||
}
|
||||
if keyLen < 24 {
|
||||
// 不足 24 字节,用 0 填充到 24 字节(AES-192)
|
||||
padded := make([]byte, 24)
|
||||
copy(padded, key)
|
||||
return padded
|
||||
}
|
||||
if keyLen < 32 {
|
||||
// 不足 32 字节,用 0 填充到 32 字节(AES-256)
|
||||
padded := make([]byte, 32)
|
||||
copy(padded, key)
|
||||
return padded
|
||||
}
|
||||
// 超过 32 字节,截取前 32 字节(AES-256)
|
||||
return key[:32]
|
||||
}
|
||||
|
||||
// aesEncryptECB 使用 AES ECB 模式加密,PKCS5 填充
|
||||
func aesEncryptECB(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
|
||||
ciphertext := make([]byte, len(paddedPlaintext))
|
||||
mode := newECBEncrypter(block)
|
||||
mode.CryptBlocks(ciphertext, paddedPlaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// aesDecryptECB 使用 AES ECB 模式解密,PKCS5 去填充
|
||||
func aesDecryptECB(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ciphertext)%block.BlockSize() != 0 {
|
||||
return nil, errors.New("ciphertext length is not a multiple of block size")
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode := newECBDecrypter(block)
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs5Unpadding(plaintext), nil
|
||||
}
|
||||
|
||||
// pkcs5Padding PKCS5 填充
|
||||
func pkcs5Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
// pkcs5Unpadding 去除 PKCS5 填充
|
||||
func pkcs5Unpadding(src []byte) []byte {
|
||||
length := len(src)
|
||||
if length == 0 {
|
||||
return src
|
||||
}
|
||||
unpadding := int(src[length-1])
|
||||
if unpadding > length {
|
||||
return src
|
||||
}
|
||||
return src[:length-unpadding]
|
||||
}
|
||||
|
||||
// ECB 模式加密/解密实现
|
||||
type ecb struct {
|
||||
b cipher.Block
|
||||
blockSize int
|
||||
}
|
||||
|
||||
func newECB(b cipher.Block) *ecb {
|
||||
return &ecb{
|
||||
b: b,
|
||||
blockSize: b.BlockSize(),
|
||||
}
|
||||
}
|
||||
|
||||
type ecbEncrypter ecb
|
||||
|
||||
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbEncrypter) BlockSize() int {
|
||||
return x.blockSize
|
||||
}
|
||||
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Encrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
|
||||
type ecbDecrypter ecb
|
||||
|
||||
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbDecrypter) BlockSize() int {
|
||||
return x.blockSize
|
||||
}
|
||||
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Decrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
108
internal/infrastructure/external/shumai/shumai_errors.go
vendored
Normal file
108
internal/infrastructure/external/shumai/shumai_errors.go
vendored
Normal file
@@ -0,0 +1,108 @@
|
||||
package shumai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ShumaiError 数脉服务错误
|
||||
type ShumaiError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Error 实现 error 接口
|
||||
func (e *ShumaiError) Error() string {
|
||||
return fmt.Sprintf("数脉错误 [%s]: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// IsSuccess 是否成功
|
||||
func (e *ShumaiError) IsSuccess() bool {
|
||||
return e.Code == "0" || e.Code == "200"
|
||||
}
|
||||
|
||||
// IsNoRecord 是否查无记录
|
||||
func (e *ShumaiError) IsNoRecord() bool {
|
||||
return e.Code == "404"
|
||||
}
|
||||
|
||||
// IsParamError 是否参数错误
|
||||
func (e *ShumaiError) IsParamError() bool {
|
||||
return e.Code == "400"
|
||||
}
|
||||
|
||||
// IsAuthError 是否认证错误
|
||||
func (e *ShumaiError) IsAuthError() bool {
|
||||
return e.Code == "601" || e.Code == "602"
|
||||
}
|
||||
|
||||
// IsSystemError 是否系统错误
|
||||
func (e *ShumaiError) IsSystemError() bool {
|
||||
return e.Code == "500" || e.Code == "501"
|
||||
}
|
||||
|
||||
// 预定义错误
|
||||
var (
|
||||
ErrSuccess = &ShumaiError{Code: "200", Message: "成功"}
|
||||
ErrParamError = &ShumaiError{Code: "400", Message: "参数错误"}
|
||||
ErrNoRecord = &ShumaiError{Code: "404", Message: "请求资源不存在"}
|
||||
ErrSystemError = &ShumaiError{Code: "500", Message: "系统内部错误,请联系服务商"}
|
||||
ErrThirdPartyError = &ShumaiError{Code: "501", Message: "第三方服务异常"}
|
||||
ErrNoPermission = &ShumaiError{Code: "601", Message: "服务商未开通接口权限"}
|
||||
ErrAccountDisabled = &ShumaiError{Code: "602", Message: "账号停用"}
|
||||
ErrInsufficientBalance = &ShumaiError{Code: "603", Message: "余额不足请充值"}
|
||||
ErrInterfaceDisabled = &ShumaiError{Code: "604", Message: "接口停用"}
|
||||
ErrInsufficientQuota = &ShumaiError{Code: "605", Message: "次数不足,请购买套餐"}
|
||||
ErrRateLimitExceeded = &ShumaiError{Code: "606", Message: "调用超限,请联系服务商"}
|
||||
ErrOther = &ShumaiError{Code: "1001", Message: "其他,以实际返回为准"}
|
||||
)
|
||||
|
||||
// NewShumaiError 创建数脉错误
|
||||
func NewShumaiError(code, message string) *ShumaiError {
|
||||
return &ShumaiError{Code: code, Message: message}
|
||||
}
|
||||
|
||||
// NewShumaiErrorFromCode 根据状态码创建错误
|
||||
func NewShumaiErrorFromCode(code string) *ShumaiError {
|
||||
switch code {
|
||||
case "0", "200":
|
||||
return ErrSuccess
|
||||
case "400":
|
||||
return ErrParamError
|
||||
case "404":
|
||||
return ErrNoRecord
|
||||
case "500":
|
||||
return ErrSystemError
|
||||
case "501":
|
||||
return ErrThirdPartyError
|
||||
case "601":
|
||||
return ErrNoPermission
|
||||
case "602":
|
||||
return ErrAccountDisabled
|
||||
case "603":
|
||||
return ErrInsufficientBalance
|
||||
case "604":
|
||||
return ErrInterfaceDisabled
|
||||
case "605":
|
||||
return ErrInsufficientQuota
|
||||
case "606":
|
||||
return ErrRateLimitExceeded
|
||||
case "1001":
|
||||
return ErrOther
|
||||
default:
|
||||
return &ShumaiError{Code: code, Message: "未知错误"}
|
||||
}
|
||||
}
|
||||
|
||||
// IsShumaiError 是否为数脉错误
|
||||
func IsShumaiError(err error) bool {
|
||||
_, ok := err.(*ShumaiError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetShumaiError 获取数脉错误
|
||||
func GetShumaiError(err error) *ShumaiError {
|
||||
if e, ok := err.(*ShumaiError); ok {
|
||||
return e
|
||||
}
|
||||
return nil
|
||||
}
|
||||
69
internal/infrastructure/external/shumai/shumai_factory.go
vendored
Normal file
69
internal/infrastructure/external/shumai/shumai_factory.go
vendored
Normal file
@@ -0,0 +1,69 @@
|
||||
package shumai
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewShumaiServiceWithConfig 使用 config 创建数脉服务
|
||||
func NewShumaiServiceWithConfig(cfg *config.Config) (*ShumaiService, error) {
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.Shumai.Logging.Enabled,
|
||||
LogDir: cfg.Shumai.Logging.LogDir,
|
||||
ServiceName: "shumai",
|
||||
UseDaily: cfg.Shumai.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.Shumai.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
for k, v := range cfg.Shumai.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[k] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: v.MaxSize,
|
||||
MaxBackups: v.MaxBackups,
|
||||
MaxAge: v.MaxAge,
|
||||
Compress: v.Compress,
|
||||
}
|
||||
}
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var signMethod SignMethod
|
||||
if cfg.Shumai.SignMethod == "md5" {
|
||||
signMethod = SignMethodMD5
|
||||
} else {
|
||||
signMethod = SignMethodHMACMD5
|
||||
}
|
||||
timeout := 60 * time.Second
|
||||
if cfg.Shumai.Timeout > 0 {
|
||||
timeout = cfg.Shumai.Timeout
|
||||
}
|
||||
|
||||
return NewShumaiService(
|
||||
cfg.Shumai.URL,
|
||||
cfg.Shumai.AppID,
|
||||
cfg.Shumai.AppSecret,
|
||||
signMethod,
|
||||
timeout,
|
||||
logger,
|
||||
cfg.Shumai.AppID2, // 走政务接口使用这个
|
||||
cfg.Shumai.AppSecret2, // 走政务接口使用这个
|
||||
), nil
|
||||
}
|
||||
|
||||
// NewShumaiServiceWithLogging 使用自定义日志配置创建数脉服务
|
||||
func NewShumaiServiceWithLogging(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig, appID2, appSecret2 string) (*ShumaiService, error) {
|
||||
loggingConfig.ServiceName = "shumai"
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewShumaiService(url, appID, appSecret, signMethod, timeout, logger, appID2, appSecret2), nil
|
||||
}
|
||||
|
||||
// NewShumaiServiceSimple 创建无数脉日志的数脉服务
|
||||
func NewShumaiServiceSimple(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, appID2, appSecret2 string) *ShumaiService {
|
||||
return NewShumaiService(url, appID, appSecret, signMethod, timeout, nil, appID2, appSecret2)
|
||||
}
|
||||
360
internal/infrastructure/external/shumai/shumai_service.go
vendored
Normal file
360
internal/infrastructure/external/shumai/shumai_service.go
vendored
Normal file
@@ -0,0 +1,360 @@
|
||||
package shumai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// 错误日志中单条入参值的最大长度,避免 base64 等长内容打满日志
|
||||
maxLogParamValueLen = 300
|
||||
// 错误日志中 response_body 的最大长度
|
||||
maxLogResponseBodyLen = 500
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
ErrNotFound = errors.New("查询为空")
|
||||
)
|
||||
|
||||
// truncateForLog 将字符串截断到指定长度,用于错误日志,避免 base64 等过长内容
|
||||
func truncateForLog(s string, maxLen int) string {
|
||||
if maxLen <= 0 {
|
||||
return s
|
||||
}
|
||||
if len(s) <= maxLen {
|
||||
return s
|
||||
}
|
||||
return s[:maxLen] + "...[truncated, total " + strconv.Itoa(len(s)) + " chars]"
|
||||
}
|
||||
|
||||
// requestParamsForLog 返回适合写入错误日志的入参副本(长字符串会被截断)
|
||||
func requestParamsForLog(reqFormData map[string]interface{}) map[string]interface{} {
|
||||
if reqFormData == nil {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]interface{}, len(reqFormData))
|
||||
for k, v := range reqFormData {
|
||||
if v == nil {
|
||||
out[k] = nil
|
||||
continue
|
||||
}
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
out[k] = truncateForLog(val, maxLogParamValueLen)
|
||||
default:
|
||||
s := fmt.Sprint(v)
|
||||
out[k] = truncateForLog(s, maxLogParamValueLen)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ShumaiResponse 数脉 API 通用响应(占位,按实际文档调整)
|
||||
type ShumaiResponse struct {
|
||||
Code int `json:"code"` // 状态码
|
||||
Msg string `json:"msg"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// ShumaiConfig 数脉服务配置
|
||||
type ShumaiConfig struct {
|
||||
URL string
|
||||
AppID string
|
||||
AppSecret string
|
||||
AppID2 string // 走政务接口使用这个
|
||||
AppSecret2 string // 走政务接口使用这个
|
||||
SignMethod SignMethod
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// ShumaiService 数脉服务
|
||||
type ShumaiService struct {
|
||||
config ShumaiConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
useGovernment bool // 是否使用政务接口(app_id2)
|
||||
}
|
||||
|
||||
// NewShumaiService 创建数脉服务实例
|
||||
// appID2 和 appSecret2 用于政务接口,如果为空则只使用普通接口
|
||||
func NewShumaiService(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger, appID2, appSecret2 string) *ShumaiService {
|
||||
if signMethod == "" {
|
||||
signMethod = SignMethodHMACMD5
|
||||
}
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
return &ShumaiService{
|
||||
config: ShumaiConfig{
|
||||
URL: url,
|
||||
AppID: appID,
|
||||
AppSecret: appSecret,
|
||||
AppID2: appID2, // 走政务接口使用这个
|
||||
AppSecret2: appSecret2, // 走政务接口使用这个
|
||||
SignMethod: signMethod,
|
||||
Timeout: timeout,
|
||||
},
|
||||
logger: logger,
|
||||
useGovernment: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ShumaiService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
appID := s.getCurrentAppID()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, appID)))
|
||||
return fmt.Sprintf("shumai_%x", hash[:8])
|
||||
}
|
||||
|
||||
// generateRequestIDWithAppID 根据指定的 AppID 生成请求ID(用于不依赖全局状态的情况)
|
||||
func (s *ShumaiService) generateRequestIDWithAppID(appID string) string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, appID)))
|
||||
return fmt.Sprintf("shumai_%x", hash[:8])
|
||||
}
|
||||
|
||||
// getCurrentAppID 获取当前使用的 AppID
|
||||
func (s *ShumaiService) getCurrentAppID() string {
|
||||
if s.useGovernment && s.config.AppID2 != "" {
|
||||
return s.config.AppID2
|
||||
}
|
||||
return s.config.AppID
|
||||
}
|
||||
|
||||
// getCurrentAppSecret 获取当前使用的 AppSecret
|
||||
func (s *ShumaiService) getCurrentAppSecret() string {
|
||||
if s.useGovernment && s.config.AppSecret2 != "" {
|
||||
return s.config.AppSecret2
|
||||
}
|
||||
return s.config.AppSecret
|
||||
}
|
||||
|
||||
// UseGovernment 切换到政务接口(使用 app_id2 和 app_secret2)
|
||||
func (s *ShumaiService) UseGovernment() {
|
||||
s.useGovernment = true
|
||||
}
|
||||
|
||||
// UseNormal 切换到普通接口(使用 app_id 和 app_secret)
|
||||
func (s *ShumaiService) UseNormal() {
|
||||
s.useGovernment = false
|
||||
}
|
||||
|
||||
// IsUsingGovernment 检查是否正在使用政务接口
|
||||
func (s *ShumaiService) IsUsingGovernment() bool {
|
||||
return s.useGovernment
|
||||
}
|
||||
|
||||
// GetConfig 返回当前配置
|
||||
func (s *ShumaiService) GetConfig() ShumaiConfig {
|
||||
return s.config
|
||||
}
|
||||
|
||||
// CallAPIForm 以表单方式调用数脉 API(application/x-www-form-urlencoded)
|
||||
// 在方法内部将 reqFormData 转为表单:先写入业务参数,再追加 appid、timestamp、sign。
|
||||
// 签名算法:md5(appid×tamp&app_security),32 位小写,不足补 0。
|
||||
// useGovernment 可选参数:true 表示使用政务接口(app_id2),false 表示使用实时接口(app_id)
|
||||
// 如果未提供参数,则使用全局状态(通过 UseGovernment()/UseNormal() 设置)
|
||||
func (s *ShumaiService) CallAPIForm(ctx context.Context, apiPath string, reqFormData map[string]interface{}, useGovernment ...bool) ([]byte, error) {
|
||||
// 确定是否使用政务接口:如果提供了参数则使用参数值,否则使用全局状态
|
||||
var useGov bool
|
||||
if len(useGovernment) > 0 {
|
||||
useGov = useGovernment[0]
|
||||
} else {
|
||||
// 未提供参数时,使用全局状态以保持向后兼容
|
||||
useGov = s.useGovernment
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
|
||||
|
||||
// 根据参数选择使用的 AppID 和 AppSecret,而不是依赖全局状态
|
||||
var appID, appSecret string
|
||||
if useGov && s.config.AppID2 != "" {
|
||||
appID = s.config.AppID2
|
||||
appSecret = s.config.AppSecret2
|
||||
} else {
|
||||
appID = s.config.AppID
|
||||
appSecret = s.config.AppSecret
|
||||
}
|
||||
|
||||
// 使用指定的 AppID 生成请求ID
|
||||
requestID := s.generateRequestIDWithAppID(appID)
|
||||
sign := GenerateSignForm(appID, timestamp, appSecret)
|
||||
|
||||
var transactionID string
|
||||
if id, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = id
|
||||
}
|
||||
|
||||
form := url.Values{}
|
||||
form.Set("appid", appID)
|
||||
form.Set("timestamp", timestamp)
|
||||
form.Set("sign", sign)
|
||||
for k, v := range reqFormData {
|
||||
if v == nil {
|
||||
continue
|
||||
}
|
||||
form.Set(k, fmt.Sprint(v))
|
||||
}
|
||||
body := form.Encode()
|
||||
|
||||
baseURL := strings.TrimSuffix(s.config.URL, "/")
|
||||
|
||||
reqURL := baseURL
|
||||
if apiPath != "" {
|
||||
reqURL = baseURL + "/" + strings.TrimPrefix(apiPath, "/")
|
||||
}
|
||||
if apiPath == "" {
|
||||
apiPath = "shumai_form"
|
||||
}
|
||||
if s.logger != nil {
|
||||
s.logger.LogRequest(requestID, transactionID, apiPath, reqURL)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(body))
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, map[string]interface{}{"request_params": requestParamsForLog(reqFormData)})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
|
||||
client := &http.Client{Timeout: s.config.Timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
isTimeout := ctx.Err() == context.DeadlineExceeded
|
||||
if !isTimeout {
|
||||
if te, ok := err.(interface{ Timeout() bool }); ok && te.Timeout() {
|
||||
isTimeout = true
|
||||
}
|
||||
}
|
||||
if !isTimeout {
|
||||
es := err.Error()
|
||||
if strings.Contains(es, "deadline exceeded") || strings.Contains(es, "timeout") || strings.Contains(es, "canceled") {
|
||||
isTimeout = true
|
||||
}
|
||||
}
|
||||
if isTimeout {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
}
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, map[string]interface{}{"request_params": requestParamsForLog(reqFormData)})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
duration := time.Since(startTime)
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, map[string]interface{}{"request_params": requestParamsForLog(reqFormData)})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP %d", resp.StatusCode))
|
||||
if s.logger != nil {
|
||||
errorPayload := map[string]interface{}{
|
||||
"request_params": requestParamsForLog(reqFormData),
|
||||
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
|
||||
}
|
||||
s.logger.LogError(requestID, transactionID, apiPath, err, errorPayload)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if s.logger != nil {
|
||||
s.logger.LogResponse(requestID, transactionID, apiPath, resp.StatusCode, duration)
|
||||
}
|
||||
|
||||
var shumaiResp ShumaiResponse
|
||||
if err := json.Unmarshal(raw, &shumaiResp); err != nil {
|
||||
parseErr := errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, parseErr, map[string]interface{}{
|
||||
"request_params": requestParamsForLog(reqFormData),
|
||||
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
|
||||
})
|
||||
}
|
||||
return nil, parseErr
|
||||
}
|
||||
|
||||
codeStr := strconv.Itoa(shumaiResp.Code)
|
||||
msg := shumaiResp.Msg
|
||||
if msg == "" {
|
||||
msg = shumaiResp.Message
|
||||
}
|
||||
|
||||
shumaiErr := NewShumaiErrorFromCode(codeStr)
|
||||
if !shumaiErr.IsSuccess() {
|
||||
if shumaiErr.Message == "未知错误" && msg != "" {
|
||||
shumaiErr = NewShumaiError(codeStr, msg)
|
||||
}
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, shumaiErr, map[string]interface{}{
|
||||
"request_params": requestParamsForLog(reqFormData),
|
||||
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
|
||||
})
|
||||
}
|
||||
if shumaiErr.IsNoRecord() {
|
||||
return nil, errors.Join(ErrNotFound, shumaiErr)
|
||||
}
|
||||
return nil, errors.Join(ErrDatasource, shumaiErr)
|
||||
}
|
||||
|
||||
if shumaiResp.Data == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
|
||||
dataBytes, err := json.Marshal(shumaiResp.Data)
|
||||
if err != nil {
|
||||
marshalErr := errors.Join(ErrSystem, fmt.Errorf("data 序列化失败: %w", err))
|
||||
if s.logger != nil {
|
||||
s.logger.LogError(requestID, transactionID, apiPath, marshalErr, map[string]interface{}{
|
||||
"request_params": requestParamsForLog(reqFormData),
|
||||
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
|
||||
})
|
||||
}
|
||||
return nil, marshalErr
|
||||
}
|
||||
return dataBytes, nil
|
||||
}
|
||||
|
||||
func (s *ShumaiService) Encrypt(data string) (string, error) {
|
||||
appSecret := s.getCurrentAppSecret()
|
||||
encryptedValue, err := Encrypt(data, appSecret)
|
||||
if err != nil {
|
||||
return "", ErrSystem
|
||||
}
|
||||
return encryptedValue, nil
|
||||
}
|
||||
|
||||
func (s *ShumaiService) Decrypt(encodedData string) ([]byte, error) {
|
||||
appSecret := s.getCurrentAppSecret()
|
||||
decryptedValue, err := Decrypt(encodedData, appSecret)
|
||||
if err != nil {
|
||||
return nil, ErrSystem
|
||||
}
|
||||
return decryptedValue, nil
|
||||
}
|
||||
148
internal/infrastructure/external/sms/aliyun_sms.go
vendored
Normal file
148
internal/infrastructure/external/sms/aliyun_sms.go
vendored
Normal file
@@ -0,0 +1,148 @@
|
||||
package sms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"time"
|
||||
|
||||
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
)
|
||||
|
||||
// AliSMSService 阿里云短信服务
|
||||
type AliSMSService struct {
|
||||
client *dysmsapi.Client
|
||||
config config.SMSConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAliSMSService 创建阿里云短信服务
|
||||
func NewAliSMSService(cfg config.SMSConfig, logger *zap.Logger) (*AliSMSService, error) {
|
||||
client, err := dysmsapi.NewClientWithAccessKey("cn-hangzhou", cfg.AccessKeyID, cfg.AccessKeySecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建短信客户端失败: %w", err)
|
||||
}
|
||||
return &AliSMSService{
|
||||
client: client,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
|
||||
request := dysmsapi.CreateSendSmsRequest()
|
||||
request.Scheme = "https"
|
||||
request.PhoneNumbers = phone
|
||||
request.SignName = s.config.SignName
|
||||
request.TemplateCode = s.config.TemplateCode
|
||||
request.TemplateParam = fmt.Sprintf(`{"code":"%s"}`, code)
|
||||
|
||||
response, err := s.client.SendSms(request)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to send SMS",
|
||||
zap.String("phone", phone),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("短信发送失败: %w", err)
|
||||
}
|
||||
|
||||
if response.Code != "OK" {
|
||||
s.logger.Error("SMS send failed",
|
||||
zap.String("phone", phone),
|
||||
zap.String("code", response.Code),
|
||||
zap.String("message", response.Message))
|
||||
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
|
||||
}
|
||||
|
||||
s.logger.Info("SMS sent successfully",
|
||||
zap.String("phone", phone),
|
||||
zap.String("bizId", response.BizId))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendBalanceAlert 发送余额预警短信(低余额与欠费共用 balance_alert_template_code;模板需包含 name、time、money)
|
||||
func (s *AliSMSService) SendBalanceAlert(ctx context.Context, phone string, balance float64, threshold float64, alertType string, enterpriseName ...string) error {
|
||||
request := dysmsapi.CreateSendSmsRequest()
|
||||
request.Scheme = "https"
|
||||
request.PhoneNumbers = phone
|
||||
request.SignName = s.config.SignName
|
||||
|
||||
name := "海宇数据用户"
|
||||
if len(enterpriseName) > 0 && enterpriseName[0] != "" {
|
||||
name = enterpriseName[0]
|
||||
}
|
||||
t := time.Now().Format("2006-01-02 15:04:05")
|
||||
var money float64
|
||||
if alertType == "low_balance" {
|
||||
money = threshold
|
||||
} else {
|
||||
money = balance
|
||||
}
|
||||
|
||||
templateCode := s.config.BalanceAlertTemplateCode
|
||||
if templateCode == "" {
|
||||
templateCode = "SMS_500565339"
|
||||
}
|
||||
tp, err := json.Marshal(struct {
|
||||
Name string `json:"name"`
|
||||
Time string `json:"time"`
|
||||
Money string `json:"money"`
|
||||
}{Name: name, Time: t, Money: fmt.Sprintf("%.2f", money)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("构建短信模板参数失败: %w", err)
|
||||
}
|
||||
request.TemplateCode = templateCode
|
||||
request.TemplateParam = string(tp)
|
||||
|
||||
response, err := s.client.SendSms(request)
|
||||
if err != nil {
|
||||
s.logger.Error("发送余额预警短信失败",
|
||||
zap.String("phone", phone),
|
||||
zap.String("alert_type", alertType),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("短信发送失败: %w", err)
|
||||
}
|
||||
|
||||
if response.Code != "OK" {
|
||||
s.logger.Error("余额预警短信发送失败",
|
||||
zap.String("phone", phone),
|
||||
zap.String("alert_type", alertType),
|
||||
zap.String("code", response.Code),
|
||||
zap.String("message", response.Message))
|
||||
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
|
||||
}
|
||||
|
||||
s.logger.Info("余额预警短信发送成功",
|
||||
zap.String("phone", phone),
|
||||
zap.String("alert_type", alertType),
|
||||
zap.String("bizId", response.BizId))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode 生成验证码
|
||||
func (s *AliSMSService) GenerateCode(length int) string {
|
||||
if length <= 0 {
|
||||
length = 6
|
||||
}
|
||||
|
||||
max := big.NewInt(int64(pow10(length)))
|
||||
n, _ := rand.Int(rand.Reader, max)
|
||||
|
||||
format := fmt.Sprintf("%%0%dd", length)
|
||||
return fmt.Sprintf(format, n.Int64())
|
||||
}
|
||||
|
||||
func pow10(n int) int {
|
||||
result := 1
|
||||
for i := 0; i < n; i++ {
|
||||
result *= 10
|
||||
}
|
||||
return result
|
||||
}
|
||||
48
internal/infrastructure/external/sms/mock_sms.go
vendored
Normal file
48
internal/infrastructure/external/sms/mock_sms.go
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
package sms
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// MockSMSService 模拟短信服务(用于开发和测试)
|
||||
type MockSMSService struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewMockSMSService 创建模拟短信服务
|
||||
func NewMockSMSService(logger *zap.Logger) *MockSMSService {
|
||||
return &MockSMSService{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendVerificationCode 模拟发送验证码
|
||||
func (s *MockSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
|
||||
s.logger.Info("Mock SMS sent",
|
||||
zap.String("phone", phone),
|
||||
zap.String("code", code))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendBalanceAlert 模拟余额预警
|
||||
func (s *MockSMSService) SendBalanceAlert(ctx context.Context, phone string, balance float64, threshold float64, alertType string, enterpriseName ...string) error {
|
||||
s.logger.Info("Mock balance alert SMS",
|
||||
zap.String("phone", phone),
|
||||
zap.Float64("balance", balance),
|
||||
zap.String("alert_type", alertType))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode 生成验证码
|
||||
func (s *MockSMSService) GenerateCode(length int) string {
|
||||
if length <= 0 {
|
||||
length = 6
|
||||
}
|
||||
result := ""
|
||||
for i := 0; i < length; i++ {
|
||||
result += "1"
|
||||
}
|
||||
return result
|
||||
}
|
||||
38
internal/infrastructure/external/sms/sender.go
vendored
Normal file
38
internal/infrastructure/external/sms/sender.go
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
package sms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
)
|
||||
|
||||
// SMSSender 短信发送抽象(验证码 + 余额预警),支持阿里云与腾讯云等实现。
|
||||
type SMSSender interface {
|
||||
SendVerificationCode(ctx context.Context, phone, code string) error
|
||||
SendBalanceAlert(ctx context.Context, phone string, balance, threshold float64, alertType string, enterpriseName ...string) error
|
||||
GenerateCode(length int) string
|
||||
}
|
||||
|
||||
// NewSMSSender 根据 sms.provider 创建实现;mock_enabled 时返回模拟发送器。
|
||||
// provider 为空时默认 tencent。
|
||||
func NewSMSSender(cfg config.SMSConfig, logger *zap.Logger) (SMSSender, error) {
|
||||
if cfg.MockEnabled {
|
||||
return NewMockSMSService(logger), nil
|
||||
}
|
||||
p := strings.ToLower(strings.TrimSpace(cfg.Provider))
|
||||
if p == "" {
|
||||
p = "tencent"
|
||||
}
|
||||
switch p {
|
||||
case "tencent":
|
||||
return NewTencentSMSService(cfg, logger)
|
||||
case "aliyun", "alicloud", "ali":
|
||||
return NewAliSMSService(cfg, logger)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的短信服务商: %s(支持 aliyun、tencent)", cfg.Provider)
|
||||
}
|
||||
}
|
||||
187
internal/infrastructure/external/sms/tencent_sms.go
vendored
Normal file
187
internal/infrastructure/external/sms/tencent_sms.go
vendored
Normal file
@@ -0,0 +1,187 @@
|
||||
package sms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||||
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
sms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
)
|
||||
|
||||
// TencentSMSService 腾讯云短信(与 bdqr-server 接入方式一致)
|
||||
type TencentSMSService struct {
|
||||
client *sms.Client
|
||||
cfg config.SMSConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewTencentSMSService 创建腾讯云短信客户端
|
||||
func NewTencentSMSService(cfg config.SMSConfig, logger *zap.Logger) (*TencentSMSService, error) {
|
||||
tc := cfg.TencentCloud
|
||||
if tc.SecretId == "" || tc.SecretKey == "" {
|
||||
return nil, fmt.Errorf("腾讯云短信未配置 secret_id / secret_key")
|
||||
}
|
||||
credential := common.NewCredential(tc.SecretId, tc.SecretKey)
|
||||
cpf := profile.NewClientProfile()
|
||||
cpf.HttpProfile.ReqMethod = "POST"
|
||||
cpf.HttpProfile.ReqTimeout = 10
|
||||
cpf.HttpProfile.Endpoint = "sms.tencentcloudapi.com"
|
||||
if tc.Endpoint != "" {
|
||||
cpf.HttpProfile.Endpoint = tc.Endpoint
|
||||
}
|
||||
|
||||
region := tc.Region
|
||||
if region == "" {
|
||||
region = "ap-guangzhou"
|
||||
}
|
||||
|
||||
client, err := sms.NewClient(credential, region, cpf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建腾讯云短信客户端失败: %w", err)
|
||||
}
|
||||
|
||||
return &TencentSMSService{
|
||||
client: client,
|
||||
cfg: cfg,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func normalizeTencentPhone(phone string) string {
|
||||
if strings.HasPrefix(phone, "+86") {
|
||||
return phone
|
||||
}
|
||||
return "+86" + phone
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码(模板参数为单个验证码,与 bdqr 一致)
|
||||
func (s *TencentSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
|
||||
tc := s.cfg.TencentCloud
|
||||
request := &sms.SendSmsRequest{}
|
||||
request.SmsSdkAppId = common.StringPtr(tc.SmsSdkAppId)
|
||||
request.SignName = common.StringPtr(tc.SignName)
|
||||
request.TemplateId = common.StringPtr(tc.TemplateID)
|
||||
request.TemplateParamSet = common.StringPtrs([]string{code})
|
||||
request.PhoneNumberSet = common.StringPtrs([]string{normalizeTencentPhone(phone)})
|
||||
|
||||
response, err := s.client.SendSms(request)
|
||||
if err != nil {
|
||||
s.logger.Error("腾讯云短信发送失败",
|
||||
zap.String("phone", phone),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("短信发送失败: %w", err)
|
||||
}
|
||||
|
||||
if response.Response == nil || len(response.Response.SendStatusSet) == 0 {
|
||||
return fmt.Errorf("腾讯云短信返回空响应")
|
||||
}
|
||||
|
||||
st := response.Response.SendStatusSet[0]
|
||||
if st.Code == nil || *st.Code != "Ok" {
|
||||
msg := ""
|
||||
if st.Message != nil {
|
||||
msg = *st.Message
|
||||
}
|
||||
s.logger.Error("腾讯云短信业务失败",
|
||||
zap.String("phone", phone),
|
||||
zap.String("message", msg))
|
||||
return fmt.Errorf("短信发送失败: %s", msg)
|
||||
}
|
||||
|
||||
s.logger.Info("腾讯云短信发送成功",
|
||||
zap.String("phone", phone),
|
||||
zap.String("serial_no", safeStrPtr(st.SerialNo)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendBalanceAlert 发送余额类预警。低余额与欠费使用不同模板(见 low_balance_template_id / arrears_template_id),
|
||||
// 若未分别配置则回退 balance_alert_template_id。除验证码外,腾讯云短信按无变量模板发送。
|
||||
func (s *TencentSMSService) SendBalanceAlert(ctx context.Context, phone string, balance float64, threshold float64, alertType string, enterpriseName ...string) error {
|
||||
tc := s.cfg.TencentCloud
|
||||
tplID := resolveTencentBalanceTemplateID(tc, alertType)
|
||||
if tplID == "" {
|
||||
return fmt.Errorf("腾讯云余额类短信模板未配置(请设置 sms.tencent_cloud.low_balance_template_id 与 arrears_template_id,或回退 balance_alert_template_id)")
|
||||
}
|
||||
|
||||
request := &sms.SendSmsRequest{}
|
||||
request.SmsSdkAppId = common.StringPtr(tc.SmsSdkAppId)
|
||||
request.SignName = common.StringPtr(tc.SignName)
|
||||
request.TemplateId = common.StringPtr(tplID)
|
||||
request.PhoneNumberSet = common.StringPtrs([]string{normalizeTencentPhone(phone)})
|
||||
|
||||
response, err := s.client.SendSms(request)
|
||||
if err != nil {
|
||||
s.logger.Error("腾讯云余额预警短信失败",
|
||||
zap.String("phone", phone),
|
||||
zap.String("alert_type", alertType),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("短信发送失败: %w", err)
|
||||
}
|
||||
|
||||
if response.Response == nil || len(response.Response.SendStatusSet) == 0 {
|
||||
return fmt.Errorf("腾讯云短信返回空响应")
|
||||
}
|
||||
|
||||
st := response.Response.SendStatusSet[0]
|
||||
if st.Code == nil || *st.Code != "Ok" {
|
||||
msg := ""
|
||||
if st.Message != nil {
|
||||
msg = *st.Message
|
||||
}
|
||||
return fmt.Errorf("短信发送失败: %s", msg)
|
||||
}
|
||||
|
||||
s.logger.Info("腾讯云余额预警短信发送成功",
|
||||
zap.String("phone", phone),
|
||||
zap.String("alert_type", alertType))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode 生成数字验证码
|
||||
func (s *TencentSMSService) GenerateCode(length int) string {
|
||||
if length <= 0 {
|
||||
length = 6
|
||||
}
|
||||
max := big.NewInt(int64(pow10Tencent(length)))
|
||||
n, _ := rand.Int(rand.Reader, max)
|
||||
format := fmt.Sprintf("%%0%dd", length)
|
||||
return fmt.Sprintf(format, n.Int64())
|
||||
}
|
||||
|
||||
func safeStrPtr(p *string) string {
|
||||
if p == nil {
|
||||
return ""
|
||||
}
|
||||
return *p
|
||||
}
|
||||
|
||||
func pow10Tencent(n int) int {
|
||||
result := 1
|
||||
for i := 0; i < n; i++ {
|
||||
result *= 10
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func resolveTencentBalanceTemplateID(tc config.TencentSMSConfig, alertType string) string {
|
||||
switch alertType {
|
||||
case "low_balance":
|
||||
if tc.LowBalanceTemplateID != "" {
|
||||
return tc.LowBalanceTemplateID
|
||||
}
|
||||
case "arrears":
|
||||
if tc.ArrearsTemplateID != "" {
|
||||
return tc.ArrearsTemplateID
|
||||
}
|
||||
}
|
||||
return tc.BalanceAlertTemplateID
|
||||
}
|
||||
115
internal/infrastructure/external/storage/local_file_storage_service.go
vendored
Normal file
115
internal/infrastructure/external/storage/local_file_storage_service.go
vendored
Normal file
@@ -0,0 +1,115 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LocalFileStorageService 本地文件存储服务
|
||||
type LocalFileStorageService struct {
|
||||
basePath string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// LocalFileStorageConfig 本地文件存储配置
|
||||
type LocalFileStorageConfig struct {
|
||||
BasePath string `yaml:"base_path"`
|
||||
}
|
||||
|
||||
// NewLocalFileStorageService 创建本地文件存储服务
|
||||
func NewLocalFileStorageService(basePath string, logger *zap.Logger) *LocalFileStorageService {
|
||||
// 确保基础路径存在
|
||||
if err := os.MkdirAll(basePath, 0755); err != nil {
|
||||
logger.Error("创建基础存储目录失败", zap.Error(err), zap.String("path", basePath))
|
||||
}
|
||||
|
||||
return &LocalFileStorageService{
|
||||
basePath: basePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// StoreFile 存储文件
|
||||
func (s *LocalFileStorageService) StoreFile(ctx context.Context, file io.Reader, filename string) (string, error) {
|
||||
// 构建完整文件路径
|
||||
fullPath := filepath.Join(s.basePath, filename)
|
||||
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
s.logger.Error("创建目录失败", zap.Error(err), zap.String("dir", dir))
|
||||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建文件
|
||||
dst, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
s.logger.Error("创建文件失败", zap.Error(err), zap.String("path", fullPath))
|
||||
return "", fmt.Errorf("创建文件失败: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// 复制文件内容
|
||||
if _, err := io.Copy(dst, file); err != nil {
|
||||
s.logger.Error("写入文件失败", zap.Error(err), zap.String("path", fullPath))
|
||||
// 删除部分写入的文件
|
||||
_ = os.Remove(fullPath)
|
||||
return "", fmt.Errorf("写入文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("文件存储成功", zap.String("path", fullPath))
|
||||
return fullPath, nil
|
||||
}
|
||||
|
||||
// StoreMultipartFile 存储multipart文件
|
||||
func (s *LocalFileStorageService) StoreMultipartFile(ctx context.Context, file *multipart.FileHeader, filename string) (string, error) {
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("打开上传文件失败: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
return s.StoreFile(ctx, src, filename)
|
||||
}
|
||||
|
||||
// GetFileURL 获取文件URL
|
||||
func (s *LocalFileStorageService) GetFileURL(ctx context.Context, filePath string) (string, error) {
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("文件不存在: %s", filePath)
|
||||
}
|
||||
|
||||
// 返回文件路径(在实际应用中,这里应该返回可访问的URL)
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// DeleteFile 删除文件
|
||||
func (s *LocalFileStorageService) DeleteFile(ctx context.Context, filePath string) error {
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// 文件不存在,不视为错误
|
||||
return nil
|
||||
}
|
||||
s.logger.Error("删除文件失败", zap.Error(err), zap.String("path", filePath))
|
||||
return fmt.Errorf("删除文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("文件删除成功", zap.String("path", filePath))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFileReader 获取文件读取器
|
||||
func (s *LocalFileStorageService) GetFileReader(ctx context.Context, filePath string) (io.ReadCloser, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开文件失败: %w", err)
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
110
internal/infrastructure/external/storage/local_file_storage_service_impl.go
vendored
Normal file
110
internal/infrastructure/external/storage/local_file_storage_service_impl.go
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// LocalFileStorageServiceImpl 本地文件存储服务实现
|
||||
type LocalFileStorageServiceImpl struct {
|
||||
basePath string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewLocalFileStorageServiceImpl 创建本地文件存储服务实现
|
||||
func NewLocalFileStorageServiceImpl(basePath string, logger *zap.Logger) *LocalFileStorageServiceImpl {
|
||||
// 确保基础路径存在
|
||||
if err := os.MkdirAll(basePath, 0755); err != nil {
|
||||
logger.Error("创建基础存储目录失败", zap.Error(err), zap.String("path", basePath))
|
||||
}
|
||||
|
||||
return &LocalFileStorageServiceImpl{
|
||||
basePath: basePath,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// StoreFile 存储文件
|
||||
func (s *LocalFileStorageServiceImpl) StoreFile(ctx context.Context, file io.Reader, filename string) (string, error) {
|
||||
// 构建完整文件路径
|
||||
fullPath := filepath.Join(s.basePath, filename)
|
||||
|
||||
// 确保目录存在
|
||||
dir := filepath.Dir(fullPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
s.logger.Error("创建目录失败", zap.Error(err), zap.String("dir", dir))
|
||||
return "", fmt.Errorf("创建目录失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建文件
|
||||
dst, err := os.Create(fullPath)
|
||||
if err != nil {
|
||||
s.logger.Error("创建文件失败", zap.Error(err), zap.String("path", fullPath))
|
||||
return "", fmt.Errorf("创建文件失败: %w", err)
|
||||
}
|
||||
defer dst.Close()
|
||||
|
||||
// 复制文件内容
|
||||
if _, err := io.Copy(dst, file); err != nil {
|
||||
s.logger.Error("写入文件失败", zap.Error(err), zap.String("path", fullPath))
|
||||
// 删除部分写入的文件
|
||||
_ = os.Remove(fullPath)
|
||||
return "", fmt.Errorf("写入文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("文件存储成功", zap.String("path", fullPath))
|
||||
return fullPath, nil
|
||||
}
|
||||
|
||||
// StoreMultipartFile 存储multipart文件
|
||||
func (s *LocalFileStorageServiceImpl) StoreMultipartFile(ctx context.Context, file *multipart.FileHeader, filename string) (string, error) {
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("打开上传文件失败: %w", err)
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
return s.StoreFile(ctx, src, filename)
|
||||
}
|
||||
|
||||
// GetFileURL 获取文件URL
|
||||
func (s *LocalFileStorageServiceImpl) GetFileURL(ctx context.Context, filePath string) (string, error) {
|
||||
// 检查文件是否存在
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
return "", fmt.Errorf("文件不存在: %s", filePath)
|
||||
}
|
||||
|
||||
// 返回文件路径(在实际应用中,这里应该返回可访问的URL)
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
// DeleteFile 删除文件
|
||||
func (s *LocalFileStorageServiceImpl) DeleteFile(ctx context.Context, filePath string) error {
|
||||
if err := os.Remove(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
// 文件不存在,不视为错误
|
||||
return nil
|
||||
}
|
||||
s.logger.Error("删除文件失败", zap.Error(err), zap.String("path", filePath))
|
||||
return fmt.Errorf("删除文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("文件删除成功", zap.String("path", filePath))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetFileReader 获取文件读取器
|
||||
func (s *LocalFileStorageServiceImpl) GetFileReader(ctx context.Context, filePath string) (io.ReadCloser, error) {
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("打开文件失败: %w", err)
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
353
internal/infrastructure/external/storage/qiniu_storage_service.go
vendored
Normal file
353
internal/infrastructure/external/storage/qiniu_storage_service.go
vendored
Normal file
@@ -0,0 +1,353 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
"go.uber.org/zap"
|
||||
|
||||
sharedStorage "hyapi-server/internal/shared/storage"
|
||||
)
|
||||
|
||||
// QiNiuStorageService 七牛云存储服务
|
||||
type QiNiuStorageService struct {
|
||||
accessKey string
|
||||
secretKey string
|
||||
bucket string
|
||||
domain string
|
||||
logger *zap.Logger
|
||||
mac *qbox.Mac
|
||||
bucketManager *storage.BucketManager
|
||||
}
|
||||
|
||||
// QiNiuStorageConfig 七牛云存储配置
|
||||
type QiNiuStorageConfig struct {
|
||||
AccessKey string `yaml:"access_key"`
|
||||
SecretKey string `yaml:"secret_key"`
|
||||
Bucket string `yaml:"bucket"`
|
||||
Domain string `yaml:"domain"`
|
||||
}
|
||||
|
||||
// NewQiNiuStorageService 创建七牛云存储服务
|
||||
func NewQiNiuStorageService(accessKey, secretKey, bucket, domain string, logger *zap.Logger) *QiNiuStorageService {
|
||||
mac := qbox.NewMac(accessKey, secretKey)
|
||||
// 使用默认配置,不需要指定region
|
||||
cfg := storage.Config{}
|
||||
bucketManager := storage.NewBucketManager(mac, &cfg)
|
||||
|
||||
return &QiNiuStorageService{
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
bucket: bucket,
|
||||
domain: domain,
|
||||
logger: logger,
|
||||
mac: mac,
|
||||
bucketManager: bucketManager,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadFile 上传文件到七牛云
|
||||
func (s *QiNiuStorageService) UploadFile(ctx context.Context, fileBytes []byte, fileName string) (*sharedStorage.UploadResult, error) {
|
||||
s.logger.Info("开始上传文件到七牛云",
|
||||
zap.String("file_name", fileName),
|
||||
zap.Int("file_size", len(fileBytes)),
|
||||
)
|
||||
|
||||
// 生成唯一的文件key
|
||||
key := s.generateFileKey(fileName)
|
||||
|
||||
// 创建上传凭证
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: s.bucket,
|
||||
}
|
||||
upToken := putPolicy.UploadToken(s.mac)
|
||||
|
||||
// 配置上传参数
|
||||
cfg := storage.Config{}
|
||||
formUploader := storage.NewFormUploader(&cfg)
|
||||
ret := storage.PutRet{}
|
||||
|
||||
// 上传文件
|
||||
err := formUploader.Put(ctx, &ret, upToken, key, strings.NewReader(string(fileBytes)), int64(len(fileBytes)), &storage.PutExtra{})
|
||||
if err != nil {
|
||||
s.logger.Error("文件上传失败",
|
||||
zap.String("file_name", fileName),
|
||||
zap.String("key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("文件上传失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建文件URL
|
||||
fileURL := s.GetFileURL(ctx, key)
|
||||
|
||||
s.logger.Info("文件上传成功",
|
||||
zap.String("file_name", fileName),
|
||||
zap.String("key", key),
|
||||
zap.String("url", fileURL),
|
||||
)
|
||||
|
||||
return &sharedStorage.UploadResult{
|
||||
Key: key,
|
||||
URL: fileURL,
|
||||
MimeType: s.getMimeType(fileName),
|
||||
Size: int64(len(fileBytes)),
|
||||
Hash: ret.Hash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateUploadToken 生成上传凭证
|
||||
func (s *QiNiuStorageService) GenerateUploadToken(ctx context.Context, key string) (string, error) {
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: s.bucket,
|
||||
// 设置过期时间(1小时)
|
||||
Expires: uint64(time.Now().Add(time.Hour).Unix()),
|
||||
}
|
||||
|
||||
token := putPolicy.UploadToken(s.mac)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetFileURL 获取文件访问URL
|
||||
func (s *QiNiuStorageService) GetFileURL(ctx context.Context, key string) string {
|
||||
// 如果是私有空间,需要生成带签名的URL
|
||||
if s.isPrivateBucket() {
|
||||
deadline := time.Now().Add(time.Hour).Unix() // 1小时过期
|
||||
privateAccessURL := storage.MakePrivateURL(s.mac, s.domain, key, deadline)
|
||||
return privateAccessURL
|
||||
}
|
||||
|
||||
// 公开空间直接返回URL
|
||||
return fmt.Sprintf("%s/%s", s.domain, key)
|
||||
}
|
||||
|
||||
// GetPrivateFileURL 获取私有文件访问URL
|
||||
func (s *QiNiuStorageService) GetPrivateFileURL(ctx context.Context, key string, expires int64) (string, error) {
|
||||
baseURL := s.GetFileURL(ctx, key)
|
||||
|
||||
// TODO: 实际集成七牛云SDK生成私有URL
|
||||
s.logger.Info("生成七牛云私有文件URL",
|
||||
zap.String("key", key),
|
||||
zap.Int64("expires", expires),
|
||||
)
|
||||
|
||||
// 模拟返回私有URL
|
||||
return fmt.Sprintf("%s?token=mock_private_token&expires=%d", baseURL, expires), nil
|
||||
}
|
||||
|
||||
// DeleteFile 删除文件
|
||||
func (s *QiNiuStorageService) DeleteFile(ctx context.Context, key string) error {
|
||||
s.logger.Info("删除七牛云文件", zap.String("key", key))
|
||||
|
||||
err := s.bucketManager.Delete(s.bucket, key)
|
||||
if err != nil {
|
||||
s.logger.Error("删除文件失败",
|
||||
zap.String("key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("删除文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("文件删除成功", zap.String("key", key))
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileExists 检查文件是否存在
|
||||
func (s *QiNiuStorageService) FileExists(ctx context.Context, key string) (bool, error) {
|
||||
// TODO: 实际集成七牛云SDK检查文件存在性
|
||||
s.logger.Info("检查七牛云文件存在性", zap.String("key", key))
|
||||
|
||||
// 模拟文件存在
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetFileInfo 获取文件信息
|
||||
func (s *QiNiuStorageService) GetFileInfo(ctx context.Context, key string) (*sharedStorage.FileInfo, error) {
|
||||
fileInfo, err := s.bucketManager.Stat(s.bucket, key)
|
||||
if err != nil {
|
||||
s.logger.Error("获取文件信息失败",
|
||||
zap.String("key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取文件信息失败: %w", err)
|
||||
}
|
||||
|
||||
return &sharedStorage.FileInfo{
|
||||
Key: key,
|
||||
Size: fileInfo.Fsize,
|
||||
MimeType: fileInfo.MimeType,
|
||||
Hash: fileInfo.Hash,
|
||||
PutTime: fileInfo.PutTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListFiles 列出文件
|
||||
func (s *QiNiuStorageService) ListFiles(ctx context.Context, prefix string, limit int) ([]*sharedStorage.FileInfo, error) {
|
||||
entries, _, _, hasMore, err := s.bucketManager.ListFiles(s.bucket, prefix, "", "", limit)
|
||||
if err != nil {
|
||||
s.logger.Error("列出文件失败",
|
||||
zap.String("prefix", prefix),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("列出文件失败: %w", err)
|
||||
}
|
||||
|
||||
var fileInfos []*sharedStorage.FileInfo
|
||||
for _, entry := range entries {
|
||||
fileInfo := &sharedStorage.FileInfo{
|
||||
Key: entry.Key,
|
||||
Size: entry.Fsize,
|
||||
MimeType: entry.MimeType,
|
||||
Hash: entry.Hash,
|
||||
PutTime: entry.PutTime,
|
||||
}
|
||||
fileInfos = append(fileInfos, fileInfo)
|
||||
}
|
||||
|
||||
_ = hasMore // 暂时忽略hasMore
|
||||
return fileInfos, nil
|
||||
}
|
||||
|
||||
// generateFileKey 生成文件key
|
||||
func (s *QiNiuStorageService) generateFileKey(fileName string) string {
|
||||
// 生成时间戳
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
// 生成随机字符串
|
||||
randomStr := fmt.Sprintf("%d", time.Now().UnixNano()%1000000)
|
||||
// 获取文件扩展名
|
||||
ext := filepath.Ext(fileName)
|
||||
// 构建key: 日期/时间戳_随机数.扩展名
|
||||
key := fmt.Sprintf("certification/%s/%s_%s%s",
|
||||
time.Now().Format("20060102"), timestamp, randomStr, ext)
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// getMimeType 根据文件名获取MIME类型
|
||||
func (s *QiNiuStorageService) getMimeType(fileName string) string {
|
||||
ext := strings.ToLower(filepath.Ext(fileName))
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".png":
|
||||
return "image/png"
|
||||
case ".pdf":
|
||||
return "application/pdf"
|
||||
case ".gif":
|
||||
return "image/gif"
|
||||
case ".bmp":
|
||||
return "image/bmp"
|
||||
case ".webp":
|
||||
return "image/webp"
|
||||
default:
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
||||
|
||||
// isPrivateBucket 判断是否为私有空间
|
||||
func (s *QiNiuStorageService) isPrivateBucket() bool {
|
||||
// 这里可以根据配置或域名特征判断
|
||||
// 私有空间的域名通常包含特定标识
|
||||
return strings.Contains(s.domain, "private") ||
|
||||
strings.Contains(s.domain, "auth") ||
|
||||
strings.Contains(s.domain, "secure")
|
||||
}
|
||||
|
||||
// generateSignature 生成签名(用于私有空间访问)
|
||||
func (s *QiNiuStorageService) generateSignature(data string) string {
|
||||
h := hmac.New(sha1.New, []byte(s.secretKey))
|
||||
h.Write([]byte(data))
|
||||
return base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// UploadFromReader 从Reader上传文件
|
||||
func (s *QiNiuStorageService) UploadFromReader(ctx context.Context, reader io.Reader, fileName string, fileSize int64) (*sharedStorage.UploadResult, error) {
|
||||
// 读取文件内容
|
||||
fileBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
return s.UploadFile(ctx, fileBytes, fileName)
|
||||
}
|
||||
|
||||
// DownloadFile 从七牛云下载文件
|
||||
func (s *QiNiuStorageService) DownloadFile(ctx context.Context, fileURL string) ([]byte, error) {
|
||||
s.logger.Info("开始从七牛云下载文件", zap.String("file_url", fileURL))
|
||||
|
||||
// 创建HTTP客户端,超时时间设置为60秒
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error();
|
||||
errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
errorMsg := "下载文件失败"
|
||||
if isTimeout {
|
||||
errorMsg = "下载文件超时"
|
||||
}
|
||||
s.logger.Error(errorMsg,
|
||||
zap.String("file_url", fileURL),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("%s: %w", errorMsg, err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
s.logger.Error("下载文件失败,状态码异常",
|
||||
zap.String("file_url", fileURL),
|
||||
zap.Int("status_code", resp.StatusCode),
|
||||
)
|
||||
return nil, fmt.Errorf("下载文件失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
fileContent, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
s.logger.Error("读取文件内容失败",
|
||||
zap.String("file_url", fileURL),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("读取文件内容失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("文件下载成功",
|
||||
zap.String("file_url", fileURL),
|
||||
zap.Int("file_size", len(fileContent)),
|
||||
)
|
||||
|
||||
return fileContent, nil
|
||||
}
|
||||
183
internal/infrastructure/external/tianyancha/tianyancha_service.go
vendored
Normal file
183
internal/infrastructure/external/tianyancha/tianyancha_service.go
vendored
Normal file
@@ -0,0 +1,183 @@
|
||||
package tianyancha
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrNotFound = errors.New("查询为空")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
ErrInvalidParam = errors.New("参数错误")
|
||||
)
|
||||
|
||||
// APIEndpoints 天眼查 API 端点映射
|
||||
var APIEndpoints = map[string]string{
|
||||
"VerifyThreeElements": "/open/ic/verify/2.0", // 企业三要素验证
|
||||
"InvestHistory": "/open/hi/invest/2.0", // 对外投资历史
|
||||
"FinancingHistory": "/open/cd/findHistoryRongzi/2.0", // 融资历史
|
||||
"PunishmentInfo": "/open/mr/punishmentInfo/3.0", // 行政处罚
|
||||
"AbnormalInfo": "/open/mr/abnormal/2.0", // 经营异常
|
||||
"OwnTax": "/open/mr/ownTax/2.0", // 欠税公告
|
||||
"TaxContravention": "/open/mr/taxContravention/2.0", // 税收违法
|
||||
"holderChange": "/open/ic/holderChange/2.0", // 股权变更
|
||||
"baseinfo": "/open/ic/baseinfo/normal", // 企业基本信息
|
||||
"investtree": "/v3/open/investtree", // 股权穿透
|
||||
}
|
||||
|
||||
// TianYanChaConfig 天眼查配置
|
||||
type TianYanChaConfig struct {
|
||||
BaseURL string
|
||||
Token string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
// TianYanChaService 天眼查服务
|
||||
type TianYanChaService struct {
|
||||
config TianYanChaConfig
|
||||
}
|
||||
|
||||
// APIResponse 标准API响应结构
|
||||
type APIResponse struct {
|
||||
Success bool `json:"success"`
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// TianYanChaResponse 天眼查原始响应结构
|
||||
type TianYanChaResponse struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
Reason string `json:"reason"`
|
||||
Result interface{} `json:"result"`
|
||||
}
|
||||
|
||||
// NewTianYanChaService 创建天眼查服务实例
|
||||
func NewTianYanChaService(baseURL, token string, timeout time.Duration) *TianYanChaService {
|
||||
if timeout == 0 {
|
||||
timeout = 60 * time.Second
|
||||
}
|
||||
|
||||
return &TianYanChaService{
|
||||
config: TianYanChaConfig{
|
||||
BaseURL: baseURL,
|
||||
Token: token,
|
||||
Timeout: timeout,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// CallAPI 调用天眼查API - 通用方法,由外部处理器传入具体参数
|
||||
func (t *TianYanChaService) CallAPI(ctx context.Context, apiCode string, params map[string]string) (*APIResponse, error) {
|
||||
// 从映射中获取 API 端点
|
||||
endpoint, exists := APIEndpoints[apiCode]
|
||||
if !exists {
|
||||
return nil, errors.Join(ErrInvalidParam, fmt.Errorf("未找到 API 代码对应的端点: %s", apiCode))
|
||||
}
|
||||
|
||||
// 构建完整 URL
|
||||
fullURL := strings.TrimRight(t.config.BaseURL, "/") + "/" + strings.TrimLeft(endpoint, "/")
|
||||
|
||||
// 检查 Token 是否配置
|
||||
if t.config.Token == "" {
|
||||
return nil, errors.Join(ErrSystem, fmt.Errorf("天眼查 API Token 未配置"))
|
||||
}
|
||||
|
||||
// 构建查询参数
|
||||
queryParams := url.Values{}
|
||||
for key, value := range params {
|
||||
queryParams.Set(key, value)
|
||||
}
|
||||
|
||||
// 构建完整URL
|
||||
requestURL := fullURL
|
||||
if len(queryParams) > 0 {
|
||||
requestURL += "?" + queryParams.Encode()
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", requestURL, nil)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrSystem, fmt.Errorf("创建请求失败: %v", err))
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Authorization", t.config.Token)
|
||||
|
||||
// 发送请求
|
||||
client := &http.Client{Timeout: t.config.Timeout}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
return nil, errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
|
||||
}
|
||||
return nil, errors.Join(ErrDatasource, fmt.Errorf("API 请求异常: %v", err))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查 HTTP 状态码
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, errors.Join(ErrDatasource, fmt.Errorf("API 请求失败,状态码: %d", resp.StatusCode))
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.Join(ErrSystem, fmt.Errorf("读取响应体失败: %v", err))
|
||||
}
|
||||
|
||||
// 解析 JSON 响应
|
||||
var tianYanChaResp TianYanChaResponse
|
||||
if err := json.Unmarshal(body, &tianYanChaResp); err != nil {
|
||||
return nil, errors.Join(ErrSystem, fmt.Errorf("解析响应 JSON 失败: %v", err))
|
||||
}
|
||||
|
||||
// 检查天眼查业务状态码
|
||||
if tianYanChaResp.ErrorCode != 0 {
|
||||
// 特殊处理:ErrorCode 300000 表示查询为空,返回成功但数据为空数组
|
||||
if tianYanChaResp.ErrorCode == 300000 {
|
||||
return &APIResponse{
|
||||
Success: true,
|
||||
Code: 0,
|
||||
Message: "",
|
||||
Data: []interface{}{}, // 返回空数组而不是nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &APIResponse{
|
||||
Success: false,
|
||||
Code: tianYanChaResp.ErrorCode,
|
||||
Message: tianYanChaResp.Reason,
|
||||
Data: tianYanChaResp.Result,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 成功情况
|
||||
return &APIResponse{
|
||||
Success: true,
|
||||
Code: 0,
|
||||
Message: tianYanChaResp.Reason,
|
||||
Data: tianYanChaResp.Result,
|
||||
}, nil
|
||||
}
|
||||
160
internal/infrastructure/external/westdex/crypto.go
vendored
Normal file
160
internal/infrastructure/external/westdex/crypto.go
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
package westdex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
const (
|
||||
KEY_SIZE = 16 // AES-128, 16 bytes
|
||||
)
|
||||
|
||||
// Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding
|
||||
func Encrypt(data, secretKey string) (string, error) {
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
ciphertext, err := aesEncrypt([]byte(data), key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding
|
||||
func Decrypt(encodedData, secretKey string) ([]byte, error) {
|
||||
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
|
||||
plaintext, err := aesDecrypt(ciphertext, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// generateAESKey generates a key for AES encryption using a SHA-1 based PRNG
|
||||
func generateAESKey(length int, password []byte) []byte {
|
||||
h := sha1.New()
|
||||
h.Write(password)
|
||||
state := h.Sum(nil)
|
||||
|
||||
keyBytes := make([]byte, 0, length/8)
|
||||
for len(keyBytes) < length/8 {
|
||||
h := sha1.New()
|
||||
h.Write(state)
|
||||
state = h.Sum(nil)
|
||||
keyBytes = append(keyBytes, state...)
|
||||
}
|
||||
|
||||
return keyBytes[:length/8]
|
||||
}
|
||||
|
||||
// aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding
|
||||
func aesEncrypt(plaintext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
|
||||
ciphertext := make([]byte, len(paddedPlaintext))
|
||||
mode := newECBEncrypter(block)
|
||||
mode.CryptBlocks(ciphertext, paddedPlaintext)
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding
|
||||
func aesDecrypt(ciphertext, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext := make([]byte, len(ciphertext))
|
||||
mode := newECBDecrypter(block)
|
||||
mode.CryptBlocks(plaintext, ciphertext)
|
||||
return pkcs5Unpadding(plaintext), nil
|
||||
}
|
||||
|
||||
// pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding
|
||||
func pkcs5Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
// pkcs5Unpadding removes PKCS5 padding from the input
|
||||
func pkcs5Unpadding(src []byte) []byte {
|
||||
length := len(src)
|
||||
unpadding := int(src[length-1])
|
||||
return src[:(length - unpadding)]
|
||||
}
|
||||
|
||||
// ECB mode encryption/decryption
|
||||
type ecb struct {
|
||||
b cipher.Block
|
||||
blockSize int
|
||||
}
|
||||
|
||||
func newECB(b cipher.Block) *ecb {
|
||||
return &ecb{
|
||||
b: b,
|
||||
blockSize: b.BlockSize(),
|
||||
}
|
||||
}
|
||||
|
||||
type ecbEncrypter ecb
|
||||
|
||||
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbEncrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Encrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
|
||||
type ecbDecrypter ecb
|
||||
|
||||
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
|
||||
return (*ecbDecrypter)(newECB(b))
|
||||
}
|
||||
|
||||
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
|
||||
|
||||
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
|
||||
if len(src)%x.blockSize != 0 {
|
||||
panic("crypto/cipher: input not full blocks")
|
||||
}
|
||||
if len(dst) < len(src) {
|
||||
panic("crypto/cipher: output smaller than input")
|
||||
}
|
||||
for len(src) > 0 {
|
||||
x.b.Decrypt(dst, src[:x.blockSize])
|
||||
src = src[x.blockSize:]
|
||||
dst = dst[x.blockSize:]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Md5Encrypt 用于对传入的message进行MD5加密
|
||||
func Md5Encrypt(message string) string {
|
||||
hash := md5.New()
|
||||
hash.Write([]byte(message)) // 将字符串转换为字节切片并写入
|
||||
return hex.EncodeToString(hash.Sum(nil)) // 将哈希值转换为16进制字符串并返回
|
||||
}
|
||||
63
internal/infrastructure/external/westdex/westdex_factory.go
vendored
Normal file
63
internal/infrastructure/external/westdex/westdex_factory.go
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
package westdex
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewWestDexServiceWithConfig 使用配置创建西部数据服务
|
||||
func NewWestDexServiceWithConfig(cfg *config.Config) (*WestDexService, error) {
|
||||
// 将配置类型转换为通用外部服务日志配置
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.WestDex.Logging.Enabled,
|
||||
LogDir: cfg.WestDex.Logging.LogDir,
|
||||
ServiceName: "westdex",
|
||||
UseDaily: cfg.WestDex.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.WestDex.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
|
||||
// 转换级别配置
|
||||
for key, value := range cfg.WestDex.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: value.MaxSize,
|
||||
MaxBackups: value.MaxBackups,
|
||||
MaxAge: value.MaxAge,
|
||||
Compress: value.Compress,
|
||||
}
|
||||
}
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建西部数据服务
|
||||
service := NewWestDexService(
|
||||
cfg.WestDex.URL,
|
||||
cfg.WestDex.Key,
|
||||
cfg.WestDex.SecretID,
|
||||
cfg.WestDex.SecretSecondID,
|
||||
logger,
|
||||
)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewWestDexServiceWithLogging 使用自定义日志配置创建西部数据服务
|
||||
func NewWestDexServiceWithLogging(url, key, secretID, secretSecondID string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*WestDexService, error) {
|
||||
// 设置服务名称
|
||||
loggingConfig.ServiceName = "westdex"
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建西部数据服务
|
||||
service := NewWestDexService(url, key, secretID, secretSecondID, logger)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
418
internal/infrastructure/external/westdex/westdex_service.go
vendored
Normal file
418
internal/infrastructure/external/westdex/westdex_service.go
vendored
Normal file
@@ -0,0 +1,418 @@
|
||||
package westdex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/crypto"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
ErrNotFound = errors.New("查询为空")
|
||||
)
|
||||
|
||||
type WestResp struct {
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code"`
|
||||
Data string `json:"data"`
|
||||
ID string `json:"id"`
|
||||
ErrorCode *int `json:"error_code"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type G05HZ01WestResp struct {
|
||||
Message string `json:"message"`
|
||||
Code string `json:"code"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
ID string `json:"id"`
|
||||
ErrorCode *int `json:"error_code"`
|
||||
Reason string `json:"reason"`
|
||||
}
|
||||
|
||||
type WestConfig struct {
|
||||
Url string
|
||||
Key string
|
||||
SecretID string
|
||||
SecretSecondID string
|
||||
}
|
||||
|
||||
type WestDexService struct {
|
||||
config WestConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewWestDexService 是一个构造函数,用于初始化 WestDexService
|
||||
func NewWestDexService(url, key, secretID, secretSecondID string, logger *external_logger.ExternalServiceLogger) *WestDexService {
|
||||
return &WestDexService{
|
||||
config: WestConfig{
|
||||
Url: url,
|
||||
Key: key,
|
||||
SecretID: secretID,
|
||||
SecretSecondID: secretSecondID,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (w *WestDexService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, w.config.Key)))
|
||||
return fmt.Sprintf("westdex_%x", hash[:8])
|
||||
}
|
||||
|
||||
// buildRequestURL 构建请求URL
|
||||
func (w *WestDexService) buildRequestURL(code string) string {
|
||||
timestamp := strconv.FormatInt(time.Now().UnixNano()/int64(time.Millisecond), 10)
|
||||
return fmt.Sprintf("%s/%s/%s?timestamp=%s", w.config.Url, w.config.SecretID, code, timestamp)
|
||||
}
|
||||
|
||||
// CallAPI 调用西部数据的 API
|
||||
func (w *WestDexService) CallAPI(ctx context.Context, code string, reqData map[string]interface{}) (resp []byte, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := w.generateRequestID()
|
||||
|
||||
// 从ctx中获取transactionId
|
||||
var transactionID string
|
||||
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = ctxTransactionID
|
||||
}
|
||||
|
||||
// 构建请求URL
|
||||
reqUrl := w.buildRequestURL(code)
|
||||
|
||||
// 记录请求日志
|
||||
if w.logger != nil {
|
||||
w.logger.LogRequest(requestID, transactionID, code, reqUrl)
|
||||
}
|
||||
|
||||
jsonData, marshalErr := json.Marshal(reqData)
|
||||
if marshalErr != nil {
|
||||
err = errors.Join(ErrSystem, marshalErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建HTTP POST请求
|
||||
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", reqUrl, bytes.NewBuffer(jsonData))
|
||||
if newRequestErr != nil {
|
||||
err = errors.Join(ErrSystem, newRequestErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求,超时时间设置为60秒
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
httpResp, clientDoErr := client.Do(req)
|
||||
if clientDoErr != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, clientDoErr)
|
||||
}
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
closeErr := Body.Close()
|
||||
if closeErr != nil {
|
||||
// 记录关闭错误
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), reqData)
|
||||
}
|
||||
}
|
||||
}(httpResp.Body)
|
||||
|
||||
// 计算请求耗时
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// 检查请求是否成功
|
||||
if httpResp.StatusCode == 200 {
|
||||
// 读取响应体
|
||||
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
|
||||
if ReadErr != nil {
|
||||
err = errors.Join(ErrSystem, ReadErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 手动调用 json.Unmarshal 触发自定义的 UnmarshalJSON 方法
|
||||
var westDexResp WestResp
|
||||
UnmarshalErr := json.Unmarshal(bodyBytes, &westDexResp)
|
||||
if UnmarshalErr != nil {
|
||||
err = errors.Join(ErrSystem, UnmarshalErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
if w.logger != nil {
|
||||
w.logger.LogResponseWithID(requestID, transactionID, code, httpResp.StatusCode, duration, westDexResp.ID)
|
||||
}
|
||||
|
||||
if westDexResp.Code != "00000" && westDexResp.Code != "200" && westDexResp.Code != "0" {
|
||||
if westDexResp.Data == "" {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf(westDexResp.Message))
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
decryptedData, DecryptErr := crypto.WestDexDecrypt(westDexResp.Data, w.config.Key)
|
||||
if DecryptErr != nil {
|
||||
err = errors.Join(ErrSystem, DecryptErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录业务错误日志,包含响应ID
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, errors.Join(ErrDatasource, fmt.Errorf(westDexResp.Message)), reqData, westDexResp.ID)
|
||||
}
|
||||
|
||||
// 记录性能日志(失败)
|
||||
// 注意:通用日志系统不包含性能日志功能
|
||||
|
||||
return decryptedData, errors.Join(ErrDatasource, fmt.Errorf(westDexResp.Message))
|
||||
}
|
||||
|
||||
if westDexResp.Data == "" {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf(westDexResp.Message))
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
decryptedData, DecryptErr := crypto.WestDexDecrypt(westDexResp.Data, w.config.Key)
|
||||
if DecryptErr != nil {
|
||||
err = errors.Join(ErrSystem, DecryptErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录性能日志(成功)
|
||||
// 注意:通用日志系统不包含性能日志功能
|
||||
|
||||
return decryptedData, nil
|
||||
}
|
||||
|
||||
// 记录HTTP错误
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("西部请求失败Code: %d", httpResp.StatusCode))
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
// 注意:通用日志系统不包含性能日志功能
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// G05HZ01CallAPI 调用西部数据的 G05HZ01 API
|
||||
func (w *WestDexService) G05HZ01CallAPI(ctx context.Context, code string, reqData map[string]interface{}) (resp []byte, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := w.generateRequestID()
|
||||
|
||||
// 从ctx中获取transactionId
|
||||
var transactionID string
|
||||
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = ctxTransactionID
|
||||
}
|
||||
|
||||
// 构建请求URL
|
||||
reqUrl := fmt.Sprintf("%s/%s/%s?timestamp=%d", w.config.Url, w.config.SecretSecondID, code, time.Now().UnixNano()/int64(time.Millisecond))
|
||||
|
||||
// 记录请求日志
|
||||
if w.logger != nil {
|
||||
w.logger.LogRequest(requestID, transactionID, code, reqUrl)
|
||||
}
|
||||
|
||||
jsonData, marshalErr := json.Marshal(reqData)
|
||||
if marshalErr != nil {
|
||||
err = errors.Join(ErrSystem, marshalErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建HTTP POST请求
|
||||
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", reqUrl, bytes.NewBuffer(jsonData))
|
||||
if newRequestErr != nil {
|
||||
err = errors.Join(ErrSystem, newRequestErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// 发送请求,超时时间设置为60秒
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
httpResp, clientDoErr := client.Do(req)
|
||||
if clientDoErr != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, clientDoErr)
|
||||
}
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
closeErr := Body.Close()
|
||||
if closeErr != nil {
|
||||
// 记录关闭错误
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), reqData)
|
||||
}
|
||||
}
|
||||
}(httpResp.Body)
|
||||
|
||||
// 计算请求耗时
|
||||
duration := time.Since(startTime)
|
||||
|
||||
if httpResp.StatusCode == 200 {
|
||||
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
|
||||
if ReadErr != nil {
|
||||
err = errors.Join(ErrSystem, ReadErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var westDexResp G05HZ01WestResp
|
||||
UnmarshalErr := json.Unmarshal(bodyBytes, &westDexResp)
|
||||
if UnmarshalErr != nil {
|
||||
err = errors.Join(ErrSystem, UnmarshalErr)
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
if w.logger != nil {
|
||||
w.logger.LogResponseWithID(requestID, transactionID, code, httpResp.StatusCode, duration, westDexResp.ID)
|
||||
}
|
||||
|
||||
if westDexResp.Code != "0000" {
|
||||
if westDexResp.Data == nil || westDexResp.Code == "1404" {
|
||||
err = errors.Join(ErrNotFound, fmt.Errorf(westDexResp.Message))
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
|
||||
}
|
||||
return nil, err
|
||||
} else {
|
||||
// 记录业务错误日志,包含响应ID
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, errors.Join(ErrSystem, fmt.Errorf(string(westDexResp.Data))), reqData, westDexResp.ID)
|
||||
}
|
||||
|
||||
// 记录性能日志(失败)
|
||||
// 注意:通用日志系统不包含性能日志功能
|
||||
|
||||
return westDexResp.Data, errors.Join(ErrSystem, fmt.Errorf(string(westDexResp.Data)))
|
||||
}
|
||||
}
|
||||
|
||||
if westDexResp.Data == nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf(westDexResp.Message))
|
||||
if w.logger != nil {
|
||||
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录性能日志(成功)
|
||||
// 注意:通用日志系统不包含性能日志功能
|
||||
|
||||
return westDexResp.Data, nil
|
||||
} else {
|
||||
// 记录HTTP错误
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("西部请求失败Code: %d", httpResp.StatusCode))
|
||||
if w.logger != nil {
|
||||
w.logger.LogError(requestID, transactionID, code, err, reqData)
|
||||
// 注意:通用日志系统不包含性能日志功能
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WestDexService) Encrypt(data string) (string, error) {
|
||||
encryptedValue, err := crypto.WestDexEncrypt(data, w.config.Key)
|
||||
if err != nil {
|
||||
return "", ErrSystem
|
||||
}
|
||||
|
||||
return encryptedValue, nil
|
||||
}
|
||||
func (w *WestDexService) Md5Encrypt(data string) string {
|
||||
result := Md5Encrypt(data)
|
||||
return result
|
||||
}
|
||||
|
||||
func (w *WestDexService) GetConfig() WestConfig {
|
||||
return w.config
|
||||
}
|
||||
62
internal/infrastructure/external/xingwei/xingwei_factory.go
vendored
Normal file
62
internal/infrastructure/external/xingwei/xingwei_factory.go
vendored
Normal file
@@ -0,0 +1,62 @@
|
||||
package xingwei
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewXingweiServiceWithConfig 使用配置创建行为数据服务
|
||||
func NewXingweiServiceWithConfig(cfg *config.Config) (*XingweiService, error) {
|
||||
// 将配置类型转换为通用外部服务日志配置
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.Xingwei.Logging.Enabled,
|
||||
LogDir: cfg.Xingwei.Logging.LogDir,
|
||||
ServiceName: "xingwei",
|
||||
UseDaily: cfg.Xingwei.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.Xingwei.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
|
||||
// 转换级别配置
|
||||
for key, value := range cfg.Xingwei.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: value.MaxSize,
|
||||
MaxBackups: value.MaxBackups,
|
||||
MaxAge: value.MaxAge,
|
||||
Compress: value.Compress,
|
||||
}
|
||||
}
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建行为数据服务
|
||||
service := NewXingweiService(
|
||||
cfg.Xingwei.URL,
|
||||
cfg.Xingwei.ApiID,
|
||||
cfg.Xingwei.ApiKey,
|
||||
logger,
|
||||
)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewXingweiServiceWithLogging 使用自定义日志配置创建行为数据服务
|
||||
func NewXingweiServiceWithLogging(url, apiID, apiKey string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*XingweiService, error) {
|
||||
// 设置服务名称
|
||||
loggingConfig.ServiceName = "xingwei"
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建行为数据服务
|
||||
service := NewXingweiService(url, apiID, apiKey, logger)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
296
internal/infrastructure/external/xingwei/xingwei_service.go
vendored
Normal file
296
internal/infrastructure/external/xingwei/xingwei_service.go
vendored
Normal file
@@ -0,0 +1,296 @@
|
||||
package xingwei
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// 行为数据API状态码常量
|
||||
const (
|
||||
CodeSuccess = 200 // 操作成功
|
||||
CodeSystemError = 500 // 系统内部错误
|
||||
CodeMerchantError = 3001 // 商家相关报错(商家不存在、商家被禁用、商家余额不足)
|
||||
CodeAccountExpired = 3002 // 账户已过期
|
||||
CodeIPWhitelistMissing = 3003 // 未添加ip白名单
|
||||
CodeUnauthorized = 3004 // 未授权调用该接口
|
||||
CodeProductIDError = 4001 // 产品id错误
|
||||
CodeInterfaceDisabled = 4002 // 接口被停用
|
||||
CodeQueryException = 5001 // 接口查询异常,请联系技术人员
|
||||
CodeNotFound = 6000 // 未查询到结果
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
ErrNotFound = errors.New("未查询到结果")
|
||||
|
||||
// 请求ID计数器,确保唯一性
|
||||
requestIDCounter int64
|
||||
)
|
||||
|
||||
// XingweiResponse 行为数据API响应结构
|
||||
type XingweiResponse struct {
|
||||
Msg string `json:"msg"`
|
||||
Code int `json:"code"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// XingweiErrorCode 行为数据错误码定义
|
||||
type XingweiErrorCode struct {
|
||||
Code int
|
||||
Message string
|
||||
}
|
||||
|
||||
// 行为数据错误码映射
|
||||
var XingweiErrorCodes = map[int]XingweiErrorCode{
|
||||
CodeSuccess: {Code: CodeSuccess, Message: "操作成功"},
|
||||
CodeSystemError: {Code: CodeSystemError, Message: "系统内部错误"},
|
||||
CodeMerchantError: {Code: CodeMerchantError, Message: "商家相关报错(商家不存在、商家被禁用、商家余额不足)"},
|
||||
CodeAccountExpired: {Code: CodeAccountExpired, Message: "账户已过期"},
|
||||
CodeIPWhitelistMissing: {Code: CodeIPWhitelistMissing, Message: "未添加ip白名单"},
|
||||
CodeUnauthorized: {Code: CodeUnauthorized, Message: "未授权调用该接口"},
|
||||
CodeProductIDError: {Code: CodeProductIDError, Message: "产品id错误"},
|
||||
CodeInterfaceDisabled: {Code: CodeInterfaceDisabled, Message: "接口被停用"},
|
||||
CodeQueryException: {Code: CodeQueryException, Message: "接口查询异常,请联系技术人员"},
|
||||
CodeNotFound: {Code: CodeNotFound, Message: "未查询到结果"},
|
||||
}
|
||||
|
||||
// GetXingweiErrorMessage 根据错误码获取错误消息
|
||||
func GetXingweiErrorMessage(code int) string {
|
||||
if errorCode, exists := XingweiErrorCodes[code]; exists {
|
||||
return errorCode.Message
|
||||
}
|
||||
return fmt.Sprintf("未知错误码: %d", code)
|
||||
}
|
||||
|
||||
type XingweiConfig struct {
|
||||
URL string
|
||||
ApiID string
|
||||
ApiKey string
|
||||
}
|
||||
|
||||
type XingweiService struct {
|
||||
config XingweiConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewXingweiService 是一个构造函数,用于初始化 XingweiService
|
||||
func NewXingweiService(url, apiID, apiKey string, logger *external_logger.ExternalServiceLogger) *XingweiService {
|
||||
return &XingweiService{
|
||||
config: XingweiConfig{
|
||||
URL: url,
|
||||
ApiID: apiID,
|
||||
ApiKey: apiKey,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (x *XingweiService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
// 使用原子计数器确保唯一性
|
||||
counter := atomic.AddInt64(&requestIDCounter, 1)
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%d_%s", timestamp, counter, x.config.ApiID)))
|
||||
return fmt.Sprintf("xingwei_%x", hash[:8])
|
||||
}
|
||||
|
||||
// createSign 创建签名:使用MD5算法将apiId、timestamp、apiKey字符串拼接生成sign
|
||||
// 参考Java示例:DigestUtils.md5Hex(apiId + timestamp + apiKey)
|
||||
func (x *XingweiService) createSign(timestamp int64) string {
|
||||
signStr := x.config.ApiID + strconv.FormatInt(timestamp, 10) + x.config.ApiKey
|
||||
hash := md5.Sum([]byte(signStr))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
|
||||
// CallAPI 调用行为数据的 API
|
||||
func (x *XingweiService) CallAPI(ctx context.Context, projectID string, params map[string]interface{}) (resp []byte, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := x.generateRequestID()
|
||||
timestamp := time.Now().UnixMilli()
|
||||
|
||||
// 从ctx中获取transactionId
|
||||
var transactionID string
|
||||
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = ctxTransactionID
|
||||
}
|
||||
|
||||
// 记录请求日志
|
||||
if x.logger != nil {
|
||||
x.logger.LogRequest(requestID, transactionID, "xingwei_api", x.config.URL)
|
||||
}
|
||||
|
||||
// 将请求参数转换为JSON
|
||||
jsonData, marshalErr := json.Marshal(params)
|
||||
if marshalErr != nil {
|
||||
err = errors.Join(ErrSystem, marshalErr)
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建HTTP POST请求
|
||||
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", x.config.URL, bytes.NewBuffer(jsonData))
|
||||
if newRequestErr != nil {
|
||||
err = errors.Join(ErrSystem, newRequestErr)
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("timestamp", strconv.FormatInt(timestamp, 10))
|
||||
req.Header.Set("sign", x.createSign(timestamp))
|
||||
req.Header.Set("API-ID", x.config.ApiID)
|
||||
req.Header.Set("project_id", projectID)
|
||||
|
||||
// 创建HTTP客户端,超时时间设置为60秒
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
httpResp, clientDoErr := client.Do(req)
|
||||
if clientDoErr != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, clientDoErr)
|
||||
}
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer func(Body io.ReadCloser) {
|
||||
closeErr := Body.Close()
|
||||
if closeErr != nil {
|
||||
// 记录关闭错误
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), params)
|
||||
}
|
||||
}
|
||||
}(httpResp.Body)
|
||||
|
||||
// 计算请求耗时
|
||||
duration := time.Since(startTime)
|
||||
|
||||
// 读取响应体
|
||||
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
|
||||
if ReadErr != nil {
|
||||
err = errors.Join(ErrSystem, ReadErr)
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
if x.logger != nil {
|
||||
x.logger.LogResponse(requestID, transactionID, "xingwei_api", httpResp.StatusCode, duration)
|
||||
}
|
||||
|
||||
// 检查HTTP状态码
|
||||
if httpResp.StatusCode != http.StatusOK {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("行为数据请求失败,状态码: %d", httpResp.StatusCode))
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析响应结构
|
||||
var xingweiResp XingweiResponse
|
||||
if err := json.Unmarshal(bodyBytes, &xingweiResp); err != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查业务状态码
|
||||
switch xingweiResp.Code {
|
||||
case CodeSuccess:
|
||||
// 成功响应,返回data字段
|
||||
if xingweiResp.Data == nil {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
|
||||
// 将data转换为JSON字节
|
||||
dataBytes, err := json.Marshal(xingweiResp.Data)
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("data字段序列化失败: %w", err))
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return dataBytes, nil
|
||||
|
||||
case CodeNotFound:
|
||||
// 未查询到结果,返回空数组
|
||||
if x.logger != nil {
|
||||
// 这里只记录有响应,不记录具体返回内容
|
||||
x.logger.LogResponse(requestID, transactionID, "xingwei_api", httpResp.StatusCode, duration)
|
||||
}
|
||||
return []byte("[]"), nil
|
||||
|
||||
case CodeSystemError:
|
||||
// 系统内部错误
|
||||
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
|
||||
systemErr := fmt.Errorf("行为数据系统错误[%d]: %s", xingweiResp.Code, errorMsg)
|
||||
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api",
|
||||
errors.Join(ErrSystem, systemErr), params)
|
||||
}
|
||||
|
||||
return nil, errors.Join(ErrSystem, systemErr)
|
||||
|
||||
default:
|
||||
// 其他业务错误
|
||||
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
|
||||
businessErr := fmt.Errorf("行为数据业务错误[%d]: %s", xingweiResp.Code, errorMsg)
|
||||
|
||||
if x.logger != nil {
|
||||
x.logger.LogError(requestID, transactionID, "xingwei_api",
|
||||
errors.Join(ErrDatasource, businessErr), params)
|
||||
}
|
||||
|
||||
return nil, errors.Join(ErrDatasource, businessErr)
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig 获取配置信息
|
||||
func (x *XingweiService) GetConfig() XingweiConfig {
|
||||
return x.config
|
||||
}
|
||||
241
internal/infrastructure/external/xingwei/xingwei_test.go
vendored
Normal file
241
internal/infrastructure/external/xingwei/xingwei_test.go
vendored
Normal file
@@ -0,0 +1,241 @@
|
||||
package xingwei
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestXingweiService_CreateSign(t *testing.T) {
|
||||
// 创建测试配置 - 使用nil logger来避免日志问题
|
||||
service := NewXingweiService(
|
||||
"https://sjztyh.chengdaoji.cn/dataCenterManageApi/manage/interface/doc/api/handle",
|
||||
"test_api_id",
|
||||
"test_api_key",
|
||||
nil, // 使用nil logger
|
||||
)
|
||||
|
||||
// 测试签名生成
|
||||
timestamp := int64(1743474772049)
|
||||
sign := service.createSign(timestamp)
|
||||
|
||||
// 验证签名不为空
|
||||
if sign == "" {
|
||||
t.Error("签名不能为空")
|
||||
}
|
||||
|
||||
// 验证签名长度(MD5应该是32位十六进制字符串)
|
||||
if len(sign) != 32 {
|
||||
t.Errorf("签名长度应该是32位,实际是%d位", len(sign))
|
||||
}
|
||||
|
||||
t.Logf("生成的签名: %s", sign)
|
||||
}
|
||||
|
||||
func TestXingweiService_CallAPI(t *testing.T) {
|
||||
// 创建测试配置 - 使用nil logger来避免日志问题
|
||||
service := NewXingweiService(
|
||||
"https://sjztyh.chengdaoji.cn/dataCenterManageApi/manage/interface/doc/api/handle",
|
||||
"test_api_id",
|
||||
"test_api_key",
|
||||
nil, // 使用nil logger
|
||||
)
|
||||
|
||||
// 创建测试上下文
|
||||
ctx := context.Background()
|
||||
|
||||
// 测试参数
|
||||
projectID := "test_project_id"
|
||||
params := map[string]interface{}{
|
||||
"test_param": "test_value",
|
||||
}
|
||||
|
||||
// 注意:这个测试会实际发送HTTP请求,所以可能会失败
|
||||
// 在实际使用中,应该使用mock或者测试服务器
|
||||
resp, err := service.CallAPI(ctx, projectID, params)
|
||||
|
||||
// 由于这是真实的外部API调用,我们主要测试错误处理
|
||||
if err != nil {
|
||||
t.Logf("预期的错误(真实API调用): %v", err)
|
||||
} else {
|
||||
t.Logf("API调用成功,响应长度: %d", len(resp))
|
||||
}
|
||||
}
|
||||
|
||||
func TestXingweiService_GenerateRequestID(t *testing.T) {
|
||||
// 创建测试配置 - 使用nil logger来避免日志问题
|
||||
service := NewXingweiService(
|
||||
"https://sjztyh.chengdaoji.cn/dataCenterManageApi/manage/interface/doc/api/handle",
|
||||
"test_api_id",
|
||||
"test_api_key",
|
||||
nil, // 使用nil logger
|
||||
)
|
||||
|
||||
// 测试请求ID生成
|
||||
requestID1 := service.generateRequestID()
|
||||
requestID2 := service.generateRequestID()
|
||||
|
||||
// 验证请求ID不为空
|
||||
if requestID1 == "" || requestID2 == "" {
|
||||
t.Error("请求ID不能为空")
|
||||
}
|
||||
|
||||
// 验证请求ID应该以xingwei_开头
|
||||
if len(requestID1) < 8 || requestID1[:8] != "xingwei_" {
|
||||
t.Error("请求ID应该以xingwei_开头")
|
||||
}
|
||||
|
||||
// 验证两次生成的请求ID应该不同
|
||||
if requestID1 == requestID2 {
|
||||
t.Error("两次生成的请求ID应该不同")
|
||||
}
|
||||
|
||||
t.Logf("请求ID1: %s", requestID1)
|
||||
t.Logf("请求ID2: %s", requestID2)
|
||||
}
|
||||
|
||||
func TestGetXingweiErrorMessage(t *testing.T) {
|
||||
// 测试已知错误码(使用常量)
|
||||
testCases := []struct {
|
||||
code int
|
||||
expected string
|
||||
}{
|
||||
{CodeSuccess, "操作成功"},
|
||||
{CodeSystemError, "系统内部错误"},
|
||||
{CodeMerchantError, "商家相关报错(商家不存在、商家被禁用、商家余额不足)"},
|
||||
{CodeAccountExpired, "账户已过期"},
|
||||
{CodeIPWhitelistMissing, "未添加ip白名单"},
|
||||
{CodeUnauthorized, "未授权调用该接口"},
|
||||
{CodeProductIDError, "产品id错误"},
|
||||
{CodeInterfaceDisabled, "接口被停用"},
|
||||
{CodeQueryException, "接口查询异常,请联系技术人员"},
|
||||
{CodeNotFound, "未查询到结果"},
|
||||
{9999, "未知错误码: 9999"}, // 测试未知错误码
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := GetXingweiErrorMessage(tc.code)
|
||||
if result != tc.expected {
|
||||
t.Errorf("错误码 %d 的消息不正确,期望: %s, 实际: %s", tc.code, tc.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestXingweiResponseParsing(t *testing.T) {
|
||||
// 测试响应结构解析
|
||||
testCases := []struct {
|
||||
name string
|
||||
response string
|
||||
expectedCode int
|
||||
}{
|
||||
{
|
||||
name: "成功响应",
|
||||
response: `{"msg": "操作成功", "code": 200, "data": {"result": "test"}}`,
|
||||
expectedCode: CodeSuccess,
|
||||
},
|
||||
{
|
||||
name: "商家错误",
|
||||
response: `{"msg": "商家相关报错", "code": 3001, "data": null}`,
|
||||
expectedCode: CodeMerchantError,
|
||||
},
|
||||
{
|
||||
name: "未查询到结果",
|
||||
response: `{"msg": "未查询到结果", "code": 6000, "data": null}`,
|
||||
expectedCode: CodeNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var resp XingweiResponse
|
||||
err := json.Unmarshal([]byte(tc.response), &resp)
|
||||
if err != nil {
|
||||
t.Errorf("解析响应失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if resp.Code != tc.expectedCode {
|
||||
t.Errorf("错误码不匹配,期望: %d, 实际: %d", tc.expectedCode, resp.Code)
|
||||
}
|
||||
|
||||
// 测试错误消息获取
|
||||
errorMsg := GetXingweiErrorMessage(resp.Code)
|
||||
if errorMsg == "" {
|
||||
t.Errorf("无法获取错误码 %d 的消息", resp.Code)
|
||||
}
|
||||
|
||||
t.Logf("响应: %+v, 错误消息: %s", resp, errorMsg)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestXingweiErrorHandling 测试错误处理逻辑
|
||||
func TestXingweiErrorHandling(t *testing.T) {
|
||||
// 注意:这个测试主要验证常量定义和错误消息,不需要实际的服务实例
|
||||
|
||||
// 测试查空错误
|
||||
t.Run("NotFound错误", func(t *testing.T) {
|
||||
// 模拟返回查空响应
|
||||
response := `{"msg": "未查询到结果", "code": 6000, "data": null}`
|
||||
var xingweiResp XingweiResponse
|
||||
err := json.Unmarshal([]byte(response), &xingweiResp)
|
||||
if err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
// 验证状态码
|
||||
if xingweiResp.Code != CodeNotFound {
|
||||
t.Errorf("期望状态码 %d, 实际 %d", CodeNotFound, xingweiResp.Code)
|
||||
}
|
||||
|
||||
// 验证错误消息
|
||||
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
|
||||
if errorMsg != "未查询到结果" {
|
||||
t.Errorf("期望错误消息 '未查询到结果', 实际 '%s'", errorMsg)
|
||||
}
|
||||
|
||||
t.Logf("查空错误测试通过: 状态码=%d, 消息=%s", xingweiResp.Code, errorMsg)
|
||||
})
|
||||
|
||||
// 测试系统错误
|
||||
t.Run("SystemError错误", func(t *testing.T) {
|
||||
response := `{"msg": "系统内部错误", "code": 500, "data": null}`
|
||||
var xingweiResp XingweiResponse
|
||||
err := json.Unmarshal([]byte(response), &xingweiResp)
|
||||
if err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if xingweiResp.Code != CodeSystemError {
|
||||
t.Errorf("期望状态码 %d, 实际 %d", CodeSystemError, xingweiResp.Code)
|
||||
}
|
||||
|
||||
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
|
||||
if errorMsg != "系统内部错误" {
|
||||
t.Errorf("期望错误消息 '系统内部错误', 实际 '%s'", errorMsg)
|
||||
}
|
||||
|
||||
t.Logf("系统错误测试通过: 状态码=%d, 消息=%s", xingweiResp.Code, errorMsg)
|
||||
})
|
||||
|
||||
// 测试成功响应
|
||||
t.Run("Success响应", func(t *testing.T) {
|
||||
response := `{"msg": "操作成功", "code": 200, "data": {"result": "test"}}`
|
||||
var xingweiResp XingweiResponse
|
||||
err := json.Unmarshal([]byte(response), &xingweiResp)
|
||||
if err != nil {
|
||||
t.Fatalf("解析响应失败: %v", err)
|
||||
}
|
||||
|
||||
if xingweiResp.Code != CodeSuccess {
|
||||
t.Errorf("期望状态码 %d, 实际 %d", CodeSuccess, xingweiResp.Code)
|
||||
}
|
||||
|
||||
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
|
||||
if errorMsg != "操作成功" {
|
||||
t.Errorf("期望错误消息 '操作成功', 实际 '%s'", errorMsg)
|
||||
}
|
||||
|
||||
t.Logf("成功响应测试通过: 状态码=%d, 消息=%s", xingweiResp.Code, errorMsg)
|
||||
})
|
||||
}
|
||||
67
internal/infrastructure/external/yushan/yushan_factory.go
vendored
Normal file
67
internal/infrastructure/external/yushan/yushan_factory.go
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
package yushan
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewYushanServiceWithConfig 使用配置创建羽山服务
|
||||
func NewYushanServiceWithConfig(cfg *config.Config) (*YushanService, error) {
|
||||
// 将配置类型转换为通用外部服务日志配置
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.Yushan.Logging.Enabled,
|
||||
LogDir: cfg.Yushan.Logging.LogDir,
|
||||
ServiceName: "yushan",
|
||||
UseDaily: cfg.Yushan.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.Yushan.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
|
||||
// 转换级别配置
|
||||
for key, value := range cfg.Yushan.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: value.MaxSize,
|
||||
MaxBackups: value.MaxBackups,
|
||||
MaxAge: value.MaxAge,
|
||||
Compress: value.Compress,
|
||||
}
|
||||
}
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建羽山服务
|
||||
service := NewYushanService(
|
||||
cfg.Yushan.URL,
|
||||
cfg.Yushan.APIKey,
|
||||
cfg.Yushan.AcctID,
|
||||
logger,
|
||||
)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewYushanServiceWithLogging 使用自定义日志配置创建羽山服务
|
||||
func NewYushanServiceWithLogging(url, apiKey, acctID string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*YushanService, error) {
|
||||
// 设置服务名称
|
||||
loggingConfig.ServiceName = "yushan"
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建羽山服务
|
||||
service := NewYushanService(url, apiKey, acctID, logger)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewYushanServiceSimple 创建简单的羽山服务(无日志)
|
||||
func NewYushanServiceSimple(url, apiKey, acctID string) *YushanService {
|
||||
return NewYushanService(url, apiKey, acctID, nil)
|
||||
}
|
||||
287
internal/infrastructure/external/yushan/yushan_service.go
vendored
Normal file
287
internal/infrastructure/external/yushan/yushan_service.go
vendored
Normal file
@@ -0,0 +1,287 @@
|
||||
package yushan
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrNotFound = errors.New("查询为空")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
)
|
||||
|
||||
type YushanConfig struct {
|
||||
URL string
|
||||
ApiKey string
|
||||
AcctID string
|
||||
}
|
||||
|
||||
type YushanService struct {
|
||||
config YushanConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewYushanService 是一个构造函数,用于初始化 YushanService
|
||||
func NewYushanService(url, apiKey, acctID string, logger *external_logger.ExternalServiceLogger) *YushanService {
|
||||
return &YushanService{
|
||||
config: YushanConfig{
|
||||
URL: url,
|
||||
ApiKey: apiKey,
|
||||
AcctID: acctID,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CallAPI 调用羽山数据的 API
|
||||
func (y *YushanService) CallAPI(ctx context.Context, code string, params map[string]interface{}) (respBytes []byte, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := y.generateRequestID()
|
||||
|
||||
// 从ctx中获取transactionId
|
||||
var transactionID string
|
||||
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = ctxTransactionID
|
||||
}
|
||||
|
||||
// 记录请求日志
|
||||
if y.logger != nil {
|
||||
y.logger.LogRequest(requestID, transactionID, code, y.config.URL)
|
||||
}
|
||||
|
||||
// 获取当前时间戳
|
||||
unixMilliseconds := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
|
||||
// 生成请求序列号
|
||||
requestSN, _ := y.GenerateRandomString()
|
||||
|
||||
// 构建请求数据
|
||||
reqData := map[string]interface{}{
|
||||
"prod_id": code,
|
||||
"req_time": unixMilliseconds,
|
||||
"request_sn": requestSN,
|
||||
"req_data": params,
|
||||
}
|
||||
|
||||
// 将请求数据转换为 JSON 字节数组
|
||||
messageBytes, err := json.Marshal(reqData)
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取 API 密钥
|
||||
key, err := hex.DecodeString(y.config.ApiKey)
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 使用 AES CBC 加密请求数据
|
||||
cipherText := y.AES_CBC_Encrypt(messageBytes, key)
|
||||
|
||||
// 将加密后的数据编码为 Base64 字符串
|
||||
content := base64.StdEncoding.EncodeToString(cipherText)
|
||||
|
||||
// 发起 HTTP 请求,超时时间设置为60秒
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", y.config.URL, strings.NewReader(content))
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("ACCT_ID", y.config.AcctID)
|
||||
|
||||
// 执行请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
}
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 读取响应体
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var respData []byte
|
||||
|
||||
if IsJSON(string(body)) {
|
||||
respData = body
|
||||
} else {
|
||||
sDec, err := base64.StdEncoding.DecodeString(string(body))
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
respData = y.AES_CBC_Decrypt(sDec, key)
|
||||
}
|
||||
retCode := gjson.GetBytes(respData, "retcode").String()
|
||||
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
if y.logger != nil {
|
||||
duration := time.Since(startTime)
|
||||
y.logger.LogResponse(requestID, transactionID, code, resp.StatusCode, duration)
|
||||
}
|
||||
|
||||
if retCode == "100000" {
|
||||
// retcode 为 100000,表示查询为空
|
||||
return nil, ErrNotFound
|
||||
} else if retCode == "000000" {
|
||||
// retcode 为 000000,表示有数据,返回 retdata
|
||||
retData := gjson.GetBytes(respData, "retdata")
|
||||
if !retData.Exists() {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("羽山请求retdata为空"))
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return []byte(retData.Raw), nil
|
||||
} else {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("羽山请求未知的状态码"))
|
||||
if y.logger != nil {
|
||||
y.logger.LogError(requestID, transactionID, code, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (y *YushanService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, y.config.ApiKey)))
|
||||
return fmt.Sprintf("yushan_%x", hash[:8])
|
||||
}
|
||||
|
||||
// GenerateRandomString 生成一个32位的随机字符串订单号
|
||||
func (y *YushanService) GenerateRandomString() (string, error) {
|
||||
// 创建一个16字节的数组
|
||||
bytes := make([]byte, 16)
|
||||
// 读取随机字节到数组中
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
// 将字节数组编码为16进制字符串
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// AEC加密(CBC模式)
|
||||
func (y *YushanService) AES_CBC_Encrypt(plainText []byte, key []byte) []byte {
|
||||
//指定加密算法,返回一个AES算法的Block接口对象
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
//进行填充
|
||||
plainText = Padding(plainText, block.BlockSize())
|
||||
//指定初始向量vi,长度和block的块尺寸一致
|
||||
iv := []byte("0000000000000000")
|
||||
//指定分组模式,返回一个BlockMode接口对象
|
||||
blockMode := cipher.NewCBCEncrypter(block, iv)
|
||||
//加密连续数据库
|
||||
cipherText := make([]byte, len(plainText))
|
||||
blockMode.CryptBlocks(cipherText, plainText)
|
||||
//返回base64密文
|
||||
return cipherText
|
||||
}
|
||||
|
||||
// AEC解密(CBC模式)
|
||||
func (y *YushanService) AES_CBC_Decrypt(cipherText []byte, key []byte) []byte {
|
||||
//指定解密算法,返回一个AES算法的Block接口对象
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
//指定初始化向量IV,和加密的一致
|
||||
iv := []byte("0000000000000000")
|
||||
//指定分组模式,返回一个BlockMode接口对象
|
||||
blockMode := cipher.NewCBCDecrypter(block, iv)
|
||||
//解密
|
||||
plainText := make([]byte, len(cipherText))
|
||||
blockMode.CryptBlocks(plainText, cipherText)
|
||||
//删除填充
|
||||
plainText = UnPadding(plainText)
|
||||
return plainText
|
||||
} // 对明文进行填充
|
||||
func Padding(plainText []byte, blockSize int) []byte {
|
||||
//计算要填充的长度
|
||||
n := blockSize - len(plainText)%blockSize
|
||||
//对原来的明文填充n个n
|
||||
temp := bytes.Repeat([]byte{byte(n)}, n)
|
||||
plainText = append(plainText, temp...)
|
||||
return plainText
|
||||
}
|
||||
|
||||
// 对密文删除填充
|
||||
func UnPadding(cipherText []byte) []byte {
|
||||
//取出密文最后一个字节end
|
||||
end := cipherText[len(cipherText)-1]
|
||||
//删除填充
|
||||
cipherText = cipherText[:len(cipherText)-int(end)]
|
||||
return cipherText
|
||||
}
|
||||
|
||||
// 判断字符串是否为 JSON 格式
|
||||
func IsJSON(s string) bool {
|
||||
var js interface{}
|
||||
return json.Unmarshal([]byte(s), &js) == nil
|
||||
}
|
||||
83
internal/infrastructure/external/yushan/yushan_test.go
vendored
Normal file
83
internal/infrastructure/external/yushan/yushan_test.go
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
package yushan
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateRequestID(t *testing.T) {
|
||||
service := &YushanService{
|
||||
config: YushanConfig{
|
||||
ApiKey: "test_api_key_123",
|
||||
},
|
||||
}
|
||||
|
||||
id1 := service.generateRequestID()
|
||||
|
||||
// 等待一小段时间确保时间戳不同
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
id2 := service.generateRequestID()
|
||||
|
||||
if id1 == "" || id2 == "" {
|
||||
t.Error("请求ID生成失败")
|
||||
}
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("不同时间生成的请求ID应该不同")
|
||||
}
|
||||
|
||||
// 验证ID格式
|
||||
if len(id1) < 20 { // yushan_ + 8位十六进制 + 其他
|
||||
t.Errorf("请求ID长度不足,实际: %s", id1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRandomString(t *testing.T) {
|
||||
service := &YushanService{}
|
||||
|
||||
str1, err := service.GenerateRandomString()
|
||||
if err != nil {
|
||||
t.Fatalf("生成随机字符串失败: %v", err)
|
||||
}
|
||||
|
||||
str2, err := service.GenerateRandomString()
|
||||
if err != nil {
|
||||
t.Fatalf("生成随机字符串失败: %v", err)
|
||||
}
|
||||
|
||||
if str1 == "" || str2 == "" {
|
||||
t.Error("随机字符串为空")
|
||||
}
|
||||
|
||||
if str1 == str2 {
|
||||
t.Error("两次生成的随机字符串应该不同")
|
||||
}
|
||||
|
||||
// 验证长度(16字节 = 32位十六进制字符)
|
||||
if len(str1) != 32 || len(str2) != 32 {
|
||||
t.Error("随机字符串长度应该是32位")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsJSON(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
expected bool
|
||||
}{
|
||||
{"{}", true},
|
||||
{"[]", true},
|
||||
{"{\"key\": \"value\"}", true},
|
||||
{"[1, 2, 3]", true},
|
||||
{"invalid json", false},
|
||||
{"", false},
|
||||
{"{invalid}", false},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
result := IsJSON(tc.input)
|
||||
if result != tc.expected {
|
||||
t.Errorf("输入: %s, 期望: %v, 实际: %v", tc.input, tc.expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
121
internal/infrastructure/external/zhicha/crypto.go
vendored
Normal file
121
internal/infrastructure/external/zhicha/crypto.go
vendored
Normal file
@@ -0,0 +1,121 @@
|
||||
package zhicha
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
KEY_SIZE = 16 // AES-128, 16 bytes
|
||||
)
|
||||
|
||||
// Encrypt 使用AES-128-CBC加密数据
|
||||
// 对应Python示例中的encrypt函数
|
||||
func Encrypt(data, key string) (string, error) {
|
||||
// 将十六进制密钥转换为字节
|
||||
binKey, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("密钥格式错误: %w", err)
|
||||
}
|
||||
|
||||
if len(binKey) < KEY_SIZE {
|
||||
return "", fmt.Errorf("密钥长度不足,需要至少%d字节", KEY_SIZE)
|
||||
}
|
||||
|
||||
// 从密钥前16个字符生成IV
|
||||
iv := []byte(key[:KEY_SIZE])
|
||||
|
||||
// 创建AES加密器
|
||||
block, err := aes.NewCipher(binKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建AES加密器失败: %w", err)
|
||||
}
|
||||
|
||||
// 对数据进行PKCS7填充
|
||||
paddedData := pkcs7Padding([]byte(data), aes.BlockSize)
|
||||
|
||||
// 创建CBC模式加密器
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
|
||||
// 加密
|
||||
ciphertext := make([]byte, len(paddedData))
|
||||
mode.CryptBlocks(ciphertext, paddedData)
|
||||
|
||||
// 返回Base64编码结果
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt 使用AES-128-CBC解密数据
|
||||
// 对应Python示例中的decrypt函数
|
||||
func Decrypt(encryptedData, key string) (string, error) {
|
||||
// 将十六进制密钥转换为字节
|
||||
binKey, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("密钥格式错误: %w", err)
|
||||
}
|
||||
|
||||
if len(binKey) < KEY_SIZE {
|
||||
return "", fmt.Errorf("密钥长度不足,需要至少%d字节", KEY_SIZE)
|
||||
}
|
||||
|
||||
// 从密钥前16个字符生成IV
|
||||
iv := []byte(key[:KEY_SIZE])
|
||||
|
||||
// 解码Base64数据
|
||||
decodedData, err := base64.StdEncoding.DecodeString(encryptedData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Base64解码失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查数据长度是否为AES块大小的倍数
|
||||
if len(decodedData) == 0 || len(decodedData)%aes.BlockSize != 0 {
|
||||
return "", fmt.Errorf("加密数据长度无效,必须是%d字节的倍数", aes.BlockSize)
|
||||
}
|
||||
|
||||
// 创建AES解密器
|
||||
block, err := aes.NewCipher(binKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建AES解密器失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建CBC模式解密器
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
|
||||
// 解密
|
||||
plaintext := make([]byte, len(decodedData))
|
||||
mode.CryptBlocks(plaintext, decodedData)
|
||||
|
||||
// 移除PKCS7填充
|
||||
unpadded, err := pkcs7Unpadding(plaintext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("移除填充失败: %w", err)
|
||||
}
|
||||
|
||||
return string(unpadded), nil
|
||||
}
|
||||
|
||||
// pkcs7Padding 使用PKCS7填充数据
|
||||
func pkcs7Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
// pkcs7Unpadding 移除PKCS7填充
|
||||
func pkcs7Unpadding(src []byte) ([]byte, error) {
|
||||
length := len(src)
|
||||
if length == 0 {
|
||||
return nil, fmt.Errorf("数据为空")
|
||||
}
|
||||
|
||||
unpadding := int(src[length-1])
|
||||
if unpadding > length {
|
||||
return nil, fmt.Errorf("填充长度无效")
|
||||
}
|
||||
|
||||
return src[:length-unpadding], nil
|
||||
}
|
||||
170
internal/infrastructure/external/zhicha/zhicha_errors.go
vendored
Normal file
170
internal/infrastructure/external/zhicha/zhicha_errors.go
vendored
Normal file
@@ -0,0 +1,170 @@
|
||||
package zhicha
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// ZhichaError 智查金控服务错误
|
||||
type ZhichaError struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (e *ZhichaError) Error() string {
|
||||
return fmt.Sprintf("智查金控错误 [%s]: %s", e.Code, e.Message)
|
||||
}
|
||||
|
||||
// IsSuccess 检查是否成功
|
||||
func (e *ZhichaError) IsSuccess() bool {
|
||||
return e.Code == "200"
|
||||
}
|
||||
|
||||
// IsNoRecord 检查是否查询无记录
|
||||
func (e *ZhichaError) IsNoRecord() bool {
|
||||
return e.Code == "201"
|
||||
}
|
||||
|
||||
// IsBusinessError 检查是否是业务错误(非系统错误)
|
||||
func (e *ZhichaError) IsBusinessError() bool {
|
||||
return e.Code >= "302" && e.Code <= "320"
|
||||
}
|
||||
|
||||
// IsSystemError 检查是否是系统错误
|
||||
func (e *ZhichaError) IsSystemError() bool {
|
||||
return e.Code == "500"
|
||||
}
|
||||
|
||||
// IsAuthError 检查是否是认证相关错误
|
||||
func (e *ZhichaError) IsAuthError() bool {
|
||||
return e.Code == "304" || e.Code == "318" || e.Code == "319" || e.Code == "320"
|
||||
}
|
||||
|
||||
// IsParamError 检查是否是参数相关错误
|
||||
func (e *ZhichaError) IsParamError() bool {
|
||||
return e.Code == "302" || e.Code == "303" || e.Code == "305" || e.Code == "306" || e.Code == "307" || e.Code == "316" || e.Code == "317"
|
||||
}
|
||||
|
||||
// IsServiceError 检查是否是服务相关错误
|
||||
func (e *ZhichaError) IsServiceError() bool {
|
||||
return e.Code == "308" || e.Code == "309" || e.Code == "310" || e.Code == "311"
|
||||
}
|
||||
|
||||
// IsUserError 检查是否是用户相关错误
|
||||
func (e *ZhichaError) IsUserError() bool {
|
||||
return e.Code == "312" || e.Code == "313" || e.Code == "314" || e.Code == "315"
|
||||
}
|
||||
|
||||
// 预定义错误常量
|
||||
var (
|
||||
// 成功状态
|
||||
ErrSuccess = &ZhichaError{Code: "200", Message: "请求成功"}
|
||||
ErrNoRecord = &ZhichaError{Code: "201", Message: "查询无记录"}
|
||||
|
||||
// 业务参数错误
|
||||
ErrBusinessParamMissing = &ZhichaError{Code: "302", Message: "业务参数缺失"}
|
||||
ErrParamError = &ZhichaError{Code: "303", Message: "参数错误"}
|
||||
ErrHeaderParamMissing = &ZhichaError{Code: "304", Message: "请求头参数缺失"}
|
||||
ErrNameError = &ZhichaError{Code: "305", Message: "姓名错误"}
|
||||
ErrPhoneError = &ZhichaError{Code: "306", Message: "手机号错误"}
|
||||
ErrIDCardError = &ZhichaError{Code: "307", Message: "身份证号错误"}
|
||||
|
||||
// 服务相关错误
|
||||
ErrServiceNotExist = &ZhichaError{Code: "308", Message: "服务不存在"}
|
||||
ErrServiceNotEnabled = &ZhichaError{Code: "309", Message: "服务未开通"}
|
||||
ErrInsufficientBalance = &ZhichaError{Code: "310", Message: "余额不足"}
|
||||
ErrRemoteDataError = &ZhichaError{Code: "311", Message: "调用远程数据异常"}
|
||||
|
||||
// 用户相关错误
|
||||
ErrUserNotExist = &ZhichaError{Code: "312", Message: "用户不存在"}
|
||||
ErrUserStatusError = &ZhichaError{Code: "313", Message: "用户状态异常"}
|
||||
ErrUserUnauthorized = &ZhichaError{Code: "314", Message: "用户未授权"}
|
||||
ErrWhitelistError = &ZhichaError{Code: "315", Message: "白名单错误"}
|
||||
|
||||
// 时间戳和认证错误
|
||||
ErrTimestampInvalid = &ZhichaError{Code: "316", Message: "timestamp不合法"}
|
||||
ErrTimestampExpired = &ZhichaError{Code: "317", Message: "timestamp已过期"}
|
||||
ErrSignVerifyFailed = &ZhichaError{Code: "318", Message: "验签失败"}
|
||||
ErrDecryptFailed = &ZhichaError{Code: "319", Message: "解密失败"}
|
||||
ErrUnauthorized = &ZhichaError{Code: "320", Message: "未授权"}
|
||||
|
||||
// 系统错误
|
||||
ErrSystemError = &ZhichaError{Code: "500", Message: "系统异常,请联系管理员"}
|
||||
)
|
||||
|
||||
// NewZhichaError 创建新的智查金控错误
|
||||
func NewZhichaError(code, message string) *ZhichaError {
|
||||
return &ZhichaError{
|
||||
Code: code,
|
||||
Message: message,
|
||||
}
|
||||
}
|
||||
|
||||
// NewZhichaErrorFromCode 根据状态码创建错误
|
||||
func NewZhichaErrorFromCode(code string) *ZhichaError {
|
||||
switch code {
|
||||
case "200":
|
||||
return ErrSuccess
|
||||
case "201":
|
||||
return ErrNoRecord
|
||||
case "302":
|
||||
return ErrBusinessParamMissing
|
||||
case "303":
|
||||
return ErrParamError
|
||||
case "304":
|
||||
return ErrHeaderParamMissing
|
||||
case "305":
|
||||
return ErrNameError
|
||||
case "306":
|
||||
return ErrPhoneError
|
||||
case "307":
|
||||
return ErrIDCardError
|
||||
case "308":
|
||||
return ErrServiceNotExist
|
||||
case "309":
|
||||
return ErrServiceNotEnabled
|
||||
case "310":
|
||||
return ErrInsufficientBalance
|
||||
case "311":
|
||||
return ErrRemoteDataError
|
||||
case "312":
|
||||
return ErrUserNotExist
|
||||
case "313":
|
||||
return ErrUserStatusError
|
||||
case "314":
|
||||
return ErrUserUnauthorized
|
||||
case "315":
|
||||
return ErrWhitelistError
|
||||
case "316":
|
||||
return ErrTimestampInvalid
|
||||
case "317":
|
||||
return ErrTimestampExpired
|
||||
case "318":
|
||||
return ErrSignVerifyFailed
|
||||
case "319":
|
||||
return ErrDecryptFailed
|
||||
case "320":
|
||||
return ErrUnauthorized
|
||||
case "500":
|
||||
return ErrSystemError
|
||||
default:
|
||||
return &ZhichaError{
|
||||
Code: code,
|
||||
Message: "未知错误",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsZhichaError 检查是否是智查金控错误
|
||||
func IsZhichaError(err error) bool {
|
||||
_, ok := err.(*ZhichaError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetZhichaError 获取智查金控错误
|
||||
func GetZhichaError(err error) *ZhichaError {
|
||||
if zhichaErr, ok := err.(*ZhichaError); ok {
|
||||
return zhichaErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
68
internal/infrastructure/external/zhicha/zhicha_factory.go
vendored
Normal file
68
internal/infrastructure/external/zhicha/zhicha_factory.go
vendored
Normal file
@@ -0,0 +1,68 @@
|
||||
package zhicha
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
// NewZhichaServiceWithConfig 使用配置创建智查金控服务
|
||||
func NewZhichaServiceWithConfig(cfg *config.Config) (*ZhichaService, error) {
|
||||
// 将配置类型转换为通用外部服务日志配置
|
||||
loggingConfig := external_logger.ExternalServiceLoggingConfig{
|
||||
Enabled: cfg.Zhicha.Logging.Enabled,
|
||||
LogDir: cfg.Zhicha.Logging.LogDir,
|
||||
ServiceName: "zhicha",
|
||||
UseDaily: cfg.Zhicha.Logging.UseDaily,
|
||||
EnableLevelSeparation: cfg.Zhicha.Logging.EnableLevelSeparation,
|
||||
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
|
||||
}
|
||||
|
||||
// 转换级别配置
|
||||
for key, value := range cfg.Zhicha.Logging.LevelConfigs {
|
||||
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
|
||||
MaxSize: value.MaxSize,
|
||||
MaxBackups: value.MaxBackups,
|
||||
MaxAge: value.MaxAge,
|
||||
Compress: value.Compress,
|
||||
}
|
||||
}
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建智查金控服务
|
||||
service := NewZhichaService(
|
||||
cfg.Zhicha.URL,
|
||||
cfg.Zhicha.AppID,
|
||||
cfg.Zhicha.AppSecret,
|
||||
cfg.Zhicha.EncryptKey,
|
||||
logger,
|
||||
)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewZhichaServiceWithLogging 使用自定义日志配置创建智查金控服务
|
||||
func NewZhichaServiceWithLogging(url, appID, appSecret, encryptKey string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*ZhichaService, error) {
|
||||
// 设置服务名称
|
||||
loggingConfig.ServiceName = "zhicha"
|
||||
|
||||
// 创建通用外部服务日志器
|
||||
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建智查金控服务
|
||||
service := NewZhichaService(url, appID, appSecret, encryptKey, logger)
|
||||
|
||||
return service, nil
|
||||
}
|
||||
|
||||
// NewZhichaServiceSimple 创建简单的智查金控服务(无日志)
|
||||
func NewZhichaServiceSimple(url, appID, appSecret, encryptKey string) *ZhichaService {
|
||||
return NewZhichaService(url, appID, appSecret, encryptKey, nil)
|
||||
}
|
||||
338
internal/infrastructure/external/zhicha/zhicha_service.go
vendored
Normal file
338
internal/infrastructure/external/zhicha/zhicha_service.go
vendored
Normal file
@@ -0,0 +1,338 @@
|
||||
package zhicha
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/md5"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"hyapi-server/internal/shared/external_logger"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDatasource = errors.New("数据源异常")
|
||||
ErrSystem = errors.New("系统异常")
|
||||
)
|
||||
|
||||
// contextKey 用于在 context 中存储不跳过 201 错误检查的标志
|
||||
type contextKey string
|
||||
|
||||
const dontSkipCode201CheckKey contextKey = "dont_skip_code_201_check"
|
||||
|
||||
type ZhichaResp struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data"`
|
||||
Success bool `json:"success"`
|
||||
}
|
||||
|
||||
type ZhichaConfig struct {
|
||||
URL string
|
||||
AppID string
|
||||
AppSecret string
|
||||
EncryptKey string
|
||||
}
|
||||
|
||||
type ZhichaService struct {
|
||||
config ZhichaConfig
|
||||
logger *external_logger.ExternalServiceLogger
|
||||
}
|
||||
|
||||
// NewZhichaService 是一个构造函数,用于初始化 ZhichaService
|
||||
func NewZhichaService(url, appID, appSecret, encryptKey string, logger *external_logger.ExternalServiceLogger) *ZhichaService {
|
||||
return &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: url,
|
||||
AppID: appID,
|
||||
AppSecret: appSecret,
|
||||
EncryptKey: encryptKey,
|
||||
},
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// generateRequestID 生成请求ID
|
||||
func (z *ZhichaService) generateRequestID() string {
|
||||
timestamp := time.Now().UnixNano()
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, z.config.AppID)))
|
||||
return fmt.Sprintf("zhicha_%x", hash[:8])
|
||||
}
|
||||
|
||||
// generateSign 生成签名
|
||||
func (z *ZhichaService) generateSign(timestamp int64) string {
|
||||
// 第一步:对app_secret进行MD5加密
|
||||
encryptedSecret := fmt.Sprintf("%x", md5.Sum([]byte(z.config.AppSecret)))
|
||||
|
||||
// 第二步:将加密后的密钥和时间戳拼接,再次MD5加密
|
||||
signStr := encryptedSecret + strconv.FormatInt(timestamp, 10)
|
||||
sign := fmt.Sprintf("%x", md5.Sum([]byte(signStr)))
|
||||
|
||||
return sign
|
||||
}
|
||||
|
||||
// CallAPI 调用智查金控的 API
|
||||
func (z *ZhichaService) CallAPI(ctx context.Context, proID string, params map[string]interface{}) (data interface{}, err error) {
|
||||
startTime := time.Now()
|
||||
requestID := z.generateRequestID()
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// 从ctx中获取transactionId
|
||||
var transactionID string
|
||||
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
|
||||
transactionID = ctxTransactionID
|
||||
}
|
||||
|
||||
// 记录请求日志
|
||||
if z.logger != nil {
|
||||
z.logger.LogRequest(requestID, transactionID, proID, z.config.URL)
|
||||
}
|
||||
|
||||
jsonData, marshalErr := json.Marshal(params)
|
||||
if marshalErr != nil {
|
||||
err = errors.Join(ErrSystem, marshalErr)
|
||||
if z.logger != nil {
|
||||
z.logger.LogError(requestID, transactionID, proID, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 创建HTTP POST请求
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", z.config.URL, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if z.logger != nil {
|
||||
z.logger.LogError(requestID, transactionID, proID, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("appId", z.config.AppID)
|
||||
req.Header.Set("proId", proID)
|
||||
req.Header.Set("timestamp", strconv.FormatInt(timestamp, 10))
|
||||
req.Header.Set("sign", z.generateSign(timestamp))
|
||||
|
||||
// 创建HTTP客户端,超时时间设置为60秒
|
||||
client := &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
response, err := client.Do(req)
|
||||
if err != nil {
|
||||
// 检查是否是超时错误
|
||||
isTimeout := false
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
isTimeout = true
|
||||
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
|
||||
// 检查是否是网络超时错误
|
||||
isTimeout = true
|
||||
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
|
||||
errStr == "timeout" ||
|
||||
errStr == "Client.Timeout exceeded" ||
|
||||
errStr == "net/http: request canceled" {
|
||||
isTimeout = true
|
||||
}
|
||||
|
||||
if isTimeout {
|
||||
// 超时错误应该返回数据源异常,而不是系统异常
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
|
||||
} else {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
}
|
||||
if z.logger != nil {
|
||||
z.logger.LogError(requestID, transactionID, proID, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
// 读取响应
|
||||
respBody, err := io.ReadAll(response.Body)
|
||||
if err != nil {
|
||||
err = errors.Join(ErrSystem, err)
|
||||
if z.logger != nil {
|
||||
z.logger.LogError(requestID, transactionID, proID, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 记录响应日志(不记录具体响应数据)
|
||||
if z.logger != nil {
|
||||
duration := time.Since(startTime)
|
||||
z.logger.LogResponse(requestID, transactionID, proID, response.StatusCode, duration)
|
||||
}
|
||||
|
||||
// 检查HTTP状态码
|
||||
if response.StatusCode != http.StatusOK {
|
||||
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
|
||||
if z.logger != nil {
|
||||
z.logger.LogError(requestID, transactionID, proID, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var zhichaResp ZhichaResp
|
||||
if err := json.Unmarshal(respBody, &zhichaResp); err != nil {
|
||||
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %s", err.Error()))
|
||||
if z.logger != nil {
|
||||
z.logger.LogError(requestID, transactionID, proID, err, params)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查业务状态码
|
||||
if zhichaResp.Code != "200" && zhichaResp.Code != "201" {
|
||||
// 创建智查金控错误用于日志记录
|
||||
zhichaErr := NewZhichaErrorFromCode(zhichaResp.Code)
|
||||
if zhichaErr.Code == "未知错误" {
|
||||
zhichaErr.Message = zhichaResp.Message
|
||||
}
|
||||
|
||||
// 记录智查金控的详细错误信息到日志
|
||||
if z.logger != nil {
|
||||
z.logger.LogError(requestID, transactionID, proID, zhichaErr, params)
|
||||
}
|
||||
|
||||
// 对外统一返回数据源异常错误
|
||||
return nil, ErrDatasource
|
||||
}
|
||||
|
||||
// 201 表示查询为空,兼容其它情况如果data也为空,则返回空对象
|
||||
if zhichaResp.Code == "201" {
|
||||
// 先做类型断言
|
||||
dataMap, ok := zhichaResp.Data.(map[string]interface{})
|
||||
if ok && len(dataMap) > 0 {
|
||||
return dataMap, nil
|
||||
}
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
// 返回data字段
|
||||
return zhichaResp.Data, nil
|
||||
}
|
||||
|
||||
// Encrypt 使用配置的加密密钥对数据进行AES-128-CBC加密
|
||||
func (z *ZhichaService) Encrypt(data string) (string, error) {
|
||||
if z.config.EncryptKey == "" {
|
||||
return "", fmt.Errorf("加密密钥未配置")
|
||||
}
|
||||
|
||||
// 将十六进制密钥转换为字节
|
||||
binKey, err := hex.DecodeString(z.config.EncryptKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("密钥格式错误: %w", err)
|
||||
}
|
||||
|
||||
if len(binKey) < 16 { // AES-128, 16 bytes
|
||||
return "", fmt.Errorf("密钥长度不足,需要至少16字节")
|
||||
}
|
||||
|
||||
// 从密钥前16个字符生成IV
|
||||
iv := []byte(z.config.EncryptKey[:16])
|
||||
|
||||
// 创建AES加密器
|
||||
block, err := aes.NewCipher(binKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建AES加密器失败: %w", err)
|
||||
}
|
||||
|
||||
// 对数据进行PKCS7填充
|
||||
paddedData := z.pkcs7Padding([]byte(data), aes.BlockSize)
|
||||
|
||||
// 创建CBC模式加密器
|
||||
mode := cipher.NewCBCEncrypter(block, iv)
|
||||
|
||||
// 加密
|
||||
ciphertext := make([]byte, len(paddedData))
|
||||
mode.CryptBlocks(ciphertext, paddedData)
|
||||
|
||||
// 返回Base64编码结果
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt 使用配置的加密密钥对数据进行AES-128-CBC解密
|
||||
func (z *ZhichaService) Decrypt(encryptedData string) (string, error) {
|
||||
if z.config.EncryptKey == "" {
|
||||
return "", fmt.Errorf("加密密钥未配置")
|
||||
}
|
||||
|
||||
// 将十六进制密钥转换为字节
|
||||
binKey, err := hex.DecodeString(z.config.EncryptKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("密钥格式错误: %w", err)
|
||||
}
|
||||
|
||||
if len(binKey) < 16 { // AES-128, 16 bytes
|
||||
return "", fmt.Errorf("密钥长度不足,需要至少16字节")
|
||||
}
|
||||
|
||||
// 从密钥前16个字符生成IV
|
||||
iv := []byte(z.config.EncryptKey[:16])
|
||||
|
||||
// 解码Base64数据
|
||||
decodedData, err := base64.StdEncoding.DecodeString(encryptedData)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("Base64解码失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查数据长度是否为AES块大小的倍数
|
||||
if len(decodedData) == 0 || len(decodedData)%aes.BlockSize != 0 {
|
||||
return "", fmt.Errorf("加密数据长度无效,必须是%d字节的倍数", aes.BlockSize)
|
||||
}
|
||||
|
||||
// 创建AES解密器
|
||||
block, err := aes.NewCipher(binKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("创建AES解密器失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建CBC模式解密器
|
||||
mode := cipher.NewCBCDecrypter(block, iv)
|
||||
|
||||
// 解密
|
||||
plaintext := make([]byte, len(decodedData))
|
||||
mode.CryptBlocks(plaintext, decodedData)
|
||||
|
||||
// 移除PKCS7填充
|
||||
unpadded, err := z.pkcs7Unpadding(plaintext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("移除填充失败: %w", err)
|
||||
}
|
||||
|
||||
return string(unpadded), nil
|
||||
}
|
||||
|
||||
// pkcs7Padding 使用PKCS7填充数据
|
||||
func (z *ZhichaService) pkcs7Padding(src []byte, blockSize int) []byte {
|
||||
padding := blockSize - len(src)%blockSize
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(src, padtext...)
|
||||
}
|
||||
|
||||
// pkcs7Unpadding 移除PKCS7填充
|
||||
func (z *ZhichaService) pkcs7Unpadding(src []byte) ([]byte, error) {
|
||||
length := len(src)
|
||||
if length == 0 {
|
||||
return nil, fmt.Errorf("数据为空")
|
||||
}
|
||||
|
||||
unpadding := int(src[length-1])
|
||||
if unpadding > length {
|
||||
return nil, fmt.Errorf("填充长度无效")
|
||||
}
|
||||
|
||||
return src[:length-unpadding], nil
|
||||
}
|
||||
703
internal/infrastructure/external/zhicha/zhicha_test.go
vendored
Normal file
703
internal/infrastructure/external/zhicha/zhicha_test.go
vendored
Normal file
@@ -0,0 +1,703 @@
|
||||
package zhicha
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGenerateSign(t *testing.T) {
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
AppSecret: "test_secret_123",
|
||||
},
|
||||
}
|
||||
|
||||
timestamp := int64(1640995200) // 2022-01-01 00:00:00
|
||||
sign := service.generateSign(timestamp)
|
||||
|
||||
if sign == "" {
|
||||
t.Error("签名生成失败,签名为空")
|
||||
}
|
||||
|
||||
// 验证签名长度(MD5是32位十六进制)
|
||||
if len(sign) != 32 {
|
||||
t.Errorf("签名长度错误,期望32位,实际%d位", len(sign))
|
||||
}
|
||||
|
||||
// 验证相同参数生成相同签名
|
||||
sign2 := service.generateSign(timestamp)
|
||||
if sign != sign2 {
|
||||
t.Error("相同参数生成的签名不一致")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// 测试密钥(32位十六进制)
|
||||
key := "1234567890abcdef1234567890abcdef"
|
||||
|
||||
// 测试数据
|
||||
testData := "这是一个测试数据,包含中文和English"
|
||||
|
||||
// 加密
|
||||
encrypted, err := Encrypt(testData, key)
|
||||
if err != nil {
|
||||
t.Fatalf("加密失败: %v", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Error("加密结果为空")
|
||||
}
|
||||
|
||||
// 解密
|
||||
decrypted, err := Decrypt(encrypted, key)
|
||||
if err != nil {
|
||||
t.Fatalf("解密失败: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Errorf("解密结果不匹配,期望: %s, 实际: %s", testData, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptWithInvalidKey(t *testing.T) {
|
||||
// 测试无效密钥
|
||||
invalidKeys := []string{
|
||||
"", // 空密钥
|
||||
"123", // 太短
|
||||
"invalid_key_string", // 非十六进制
|
||||
"1234567890abcdef", // 16位,不足32位
|
||||
}
|
||||
|
||||
testData := "test data"
|
||||
|
||||
for _, key := range invalidKeys {
|
||||
_, err := Encrypt(testData, key)
|
||||
if err == nil {
|
||||
t.Errorf("使用无效密钥 %s 应该返回错误", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWithInvalidData(t *testing.T) {
|
||||
key := "af4ca0098e6a202a5c08c413ebd9fd62"
|
||||
|
||||
// 测试无效的加密数据
|
||||
invalidData := []string{
|
||||
"", // 空数据
|
||||
"invalid_base64", // 无效的Base64
|
||||
"dGVzdA==", // 有效的Base64但不是AES加密数据
|
||||
"i96w+SDjwENjuvsokMFbLw==",
|
||||
"oaihmICgEcszWMk0gXoB12E/ygF4g78x0/sC3/KHnBk=",
|
||||
"5bx+WvXvdNRVVOp9UuNFHg==",
|
||||
}
|
||||
|
||||
for _, data := range invalidData {
|
||||
decrypted, err := Decrypt(data, key)
|
||||
if err == nil {
|
||||
t.Errorf("使用无效数据 %s 应该返回错误", data)
|
||||
}
|
||||
fmt.Println("data: ", data)
|
||||
fmt.Println("decrypted: ", decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPKCS7Padding(t *testing.T) {
|
||||
testCases := []struct {
|
||||
input string
|
||||
blockSize int
|
||||
expected int
|
||||
}{
|
||||
{"", 16, 16},
|
||||
{"a", 16, 16},
|
||||
{"ab", 16, 16},
|
||||
{"abc", 16, 16},
|
||||
{"abcd", 16, 16},
|
||||
{"abcde", 16, 16},
|
||||
{"abcdef", 16, 16},
|
||||
{"abcdefg", 16, 16},
|
||||
{"abcdefgh", 16, 16},
|
||||
{"abcdefghi", 16, 16},
|
||||
{"abcdefghij", 16, 16},
|
||||
{"abcdefghijk", 16, 16},
|
||||
{"abcdefghijkl", 16, 16},
|
||||
{"abcdefghijklm", 16, 16},
|
||||
{"abcdefghijklmn", 16, 16},
|
||||
{"abcdefghijklmno", 16, 16},
|
||||
{"abcdefghijklmnop", 16, 16},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
padded := pkcs7Padding([]byte(tc.input), tc.blockSize)
|
||||
if len(padded)%tc.blockSize != 0 {
|
||||
t.Errorf("输入: %s, 期望块大小倍数,实际: %d", tc.input, len(padded))
|
||||
}
|
||||
|
||||
// 测试移除填充
|
||||
unpadded, err := pkcs7Unpadding(padded)
|
||||
if err != nil {
|
||||
t.Errorf("移除填充失败: %v", err)
|
||||
}
|
||||
|
||||
if string(unpadded) != tc.input {
|
||||
t.Errorf("输入: %s, 期望: %s, 实际: %s", tc.input, tc.input, string(unpadded))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRequestID(t *testing.T) {
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
AppID: "test_app_id",
|
||||
},
|
||||
}
|
||||
|
||||
id1 := service.generateRequestID()
|
||||
|
||||
// 等待一小段时间确保时间戳不同
|
||||
time.Sleep(time.Millisecond)
|
||||
|
||||
id2 := service.generateRequestID()
|
||||
|
||||
if id1 == "" || id2 == "" {
|
||||
t.Error("请求ID生成失败")
|
||||
}
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("不同时间生成的请求ID应该不同")
|
||||
}
|
||||
|
||||
// 验证ID格式
|
||||
if len(id1) < 20 { // zhicha_ + 8位十六进制 + 其他
|
||||
t.Errorf("请求ID长度不足,实际: %s", id1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCallAPISuccess(t *testing.T) {
|
||||
// 创建测试服务
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "http://proxy.haiyudata.com/dataMiddle/api/handle",
|
||||
AppID: "4b78fff61ab8426f",
|
||||
AppSecret: "1128f01b94124ae899c2e9f2b1f37681",
|
||||
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
|
||||
},
|
||||
logger: nil, // 测试时不使用日志
|
||||
}
|
||||
|
||||
// 测试参数
|
||||
idCardEncrypted, err := service.Encrypt("45212220000827423X")
|
||||
if err != nil {
|
||||
t.Fatalf("加密身份证号失败: %v", err)
|
||||
}
|
||||
nameEncrypted, err := service.Encrypt("张荣宏")
|
||||
if err != nil {
|
||||
t.Fatalf("加密姓名失败: %v", err)
|
||||
}
|
||||
params := map[string]interface{}{
|
||||
"idCard": idCardEncrypted,
|
||||
"name": nameEncrypted,
|
||||
"authorized": "1",
|
||||
}
|
||||
|
||||
// 创建带超时的context
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 调用API
|
||||
data, err := service.CallAPI(ctx, "ZCI001", params)
|
||||
|
||||
// 注意:这是真实API调用,可能会因为网络、认证等原因失败
|
||||
// 我们主要测试方法调用是否正常,不强制要求API返回成功
|
||||
if err != nil {
|
||||
// 如果是网络错误或认证错误,这是正常的
|
||||
t.Logf("API调用返回错误: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 如果成功,验证响应
|
||||
if data == nil {
|
||||
t.Error("响应数据为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 将data转换为字符串进行显示
|
||||
var dataStr string
|
||||
if str, ok := data.(string); ok {
|
||||
dataStr = str
|
||||
} else {
|
||||
// 如果不是字符串,尝试JSON序列化
|
||||
if dataBytes, err := json.Marshal(data); err == nil {
|
||||
dataStr = string(dataBytes)
|
||||
} else {
|
||||
dataStr = fmt.Sprintf("%v", data)
|
||||
}
|
||||
}
|
||||
|
||||
t.Logf("API调用成功,响应内容: %s", dataStr)
|
||||
}
|
||||
|
||||
func TestCallAPIWithInvalidURL(t *testing.T) {
|
||||
// 创建使用无效URL的服务
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "https://invalid-url-that-does-not-exist.com/api",
|
||||
AppID: "test_app_id",
|
||||
AppSecret: "test_app_secret",
|
||||
EncryptKey: "test_encrypt_key",
|
||||
},
|
||||
logger: nil,
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"test": "data",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 应该返回错误
|
||||
_, err := service.CallAPI(ctx, "test_pro_id", params)
|
||||
if err == nil {
|
||||
t.Error("使用无效URL应该返回错误")
|
||||
}
|
||||
|
||||
t.Logf("预期的错误: %v", err)
|
||||
}
|
||||
|
||||
func TestCallAPIWithContextCancellation(t *testing.T) {
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "https://www.zhichajinkong.com/dataMiddle/api/handle",
|
||||
AppID: "4b78fff61ab8426f",
|
||||
AppSecret: "1128f01b94124ae899c2e9f2b1f37681",
|
||||
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
|
||||
},
|
||||
logger: nil,
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"test": "data",
|
||||
}
|
||||
|
||||
// 创建可取消的context
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// 立即取消
|
||||
cancel()
|
||||
|
||||
// 应该返回context取消错误
|
||||
_, err := service.CallAPI(ctx, "test_pro_id", params)
|
||||
if err == nil {
|
||||
t.Error("context取消后应该返回错误")
|
||||
}
|
||||
|
||||
// 检查是否是context取消错误
|
||||
if err != context.Canceled && !strings.Contains(err.Error(), "context") {
|
||||
t.Errorf("期望context相关错误,实际: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("Context取消错误: %v", err)
|
||||
}
|
||||
|
||||
func TestCallAPIWithTimeout(t *testing.T) {
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "https://www.zhichajinkong.com/dataMiddle/api/handle",
|
||||
AppID: "4b78fff61ab8426f",
|
||||
AppSecret: "1128f01b94124ae899c2e9f2b1f37681",
|
||||
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
|
||||
},
|
||||
logger: nil,
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"test": "data",
|
||||
}
|
||||
|
||||
// 创建很短的超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
// 应该因为超时而失败
|
||||
_, err := service.CallAPI(ctx, "test_pro_id", params)
|
||||
if err == nil {
|
||||
t.Error("超时后应该返回错误")
|
||||
}
|
||||
|
||||
// 检查是否是超时错误
|
||||
if err != context.DeadlineExceeded && !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "deadline") {
|
||||
t.Errorf("期望超时相关错误,实际: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("超时错误: %v", err)
|
||||
}
|
||||
|
||||
func TestCallAPIRequestHeaders(t *testing.T) {
|
||||
// 这个测试验证请求头是否正确设置
|
||||
// 由于我们不能直接访问HTTP请求,我们通过日志或其他方式来验证
|
||||
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "https://www.zhichajinkong.com/dataMiddle/api/handle",
|
||||
AppID: "test_app_id",
|
||||
AppSecret: "test_app_secret",
|
||||
EncryptKey: "test_encrypt_key",
|
||||
},
|
||||
logger: nil,
|
||||
}
|
||||
|
||||
params := map[string]interface{}{
|
||||
"test": "headers",
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 调用API(可能会失败,但我们主要测试请求头设置)
|
||||
_, err := service.CallAPI(ctx, "test_pro_id", params)
|
||||
|
||||
// 验证签名生成是否正确
|
||||
timestamp := time.Now().Unix()
|
||||
sign := service.generateSign(timestamp)
|
||||
|
||||
if sign == "" {
|
||||
t.Error("签名生成失败")
|
||||
}
|
||||
|
||||
if len(sign) != 32 {
|
||||
t.Errorf("签名长度错误,期望32位,实际%d位", len(sign))
|
||||
}
|
||||
|
||||
t.Logf("签名生成成功: %s", sign)
|
||||
t.Logf("API调用结果: %v", err)
|
||||
}
|
||||
|
||||
func TestZhichaErrorHandling(t *testing.T) {
|
||||
// 测试核心错误类型
|
||||
testCases := []struct {
|
||||
name string
|
||||
code string
|
||||
message string
|
||||
expectedErr *ZhichaError
|
||||
}{
|
||||
{
|
||||
name: "成功状态",
|
||||
code: "200",
|
||||
message: "请求成功",
|
||||
expectedErr: ErrSuccess,
|
||||
},
|
||||
{
|
||||
name: "查询无记录",
|
||||
code: "201",
|
||||
message: "查询无记录",
|
||||
expectedErr: ErrNoRecord,
|
||||
},
|
||||
{
|
||||
name: "手机号错误",
|
||||
code: "306",
|
||||
message: "手机号错误",
|
||||
expectedErr: ErrPhoneError,
|
||||
},
|
||||
{
|
||||
name: "姓名错误",
|
||||
code: "305",
|
||||
message: "姓名错误",
|
||||
expectedErr: ErrNameError,
|
||||
},
|
||||
{
|
||||
name: "身份证号错误",
|
||||
code: "307",
|
||||
message: "身份证号错误",
|
||||
expectedErr: ErrIDCardError,
|
||||
},
|
||||
{
|
||||
name: "余额不足",
|
||||
code: "310",
|
||||
message: "余额不足",
|
||||
expectedErr: ErrInsufficientBalance,
|
||||
},
|
||||
{
|
||||
name: "用户不存在",
|
||||
code: "312",
|
||||
message: "用户不存在",
|
||||
expectedErr: ErrUserNotExist,
|
||||
},
|
||||
{
|
||||
name: "系统异常",
|
||||
code: "500",
|
||||
message: "系统异常,请联系管理员",
|
||||
expectedErr: ErrSystemError,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 测试从状态码创建错误
|
||||
err := NewZhichaErrorFromCode(tc.code)
|
||||
|
||||
if err.Code != tc.expectedErr.Code {
|
||||
t.Errorf("期望错误码 %s,实际 %s", tc.expectedErr.Code, err.Code)
|
||||
}
|
||||
|
||||
if err.Message != tc.expectedErr.Message {
|
||||
t.Errorf("期望错误消息 %s,实际 %s", tc.expectedErr.Message, err.Message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhichaErrorHelpers(t *testing.T) {
|
||||
// 测试错误类型判断函数
|
||||
err := NewZhichaError("302", "业务参数缺失")
|
||||
|
||||
// 测试IsZhichaError
|
||||
if !IsZhichaError(err) {
|
||||
t.Error("IsZhichaError应该返回true")
|
||||
}
|
||||
|
||||
// 测试GetZhichaError
|
||||
zhichaErr := GetZhichaError(err)
|
||||
if zhichaErr == nil {
|
||||
t.Error("GetZhichaError应该返回非nil值")
|
||||
}
|
||||
|
||||
if zhichaErr.Code != "302" {
|
||||
t.Errorf("期望错误码302,实际%s", zhichaErr.Code)
|
||||
}
|
||||
|
||||
// 测试普通错误
|
||||
normalErr := fmt.Errorf("普通错误")
|
||||
if IsZhichaError(normalErr) {
|
||||
t.Error("普通错误不应该被识别为智查金控错误")
|
||||
}
|
||||
|
||||
if GetZhichaError(normalErr) != nil {
|
||||
t.Error("普通错误的GetZhichaError应该返回nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestZhichaErrorString(t *testing.T) {
|
||||
// 测试错误字符串格式
|
||||
err := NewZhichaError("304", "请求头参数缺失")
|
||||
expectedStr := "智查金控错误 [304]: 请求头参数缺失"
|
||||
|
||||
if err.Error() != expectedStr {
|
||||
t.Errorf("期望错误字符串 %s,实际 %s", expectedStr, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorsIsFunctionality(t *testing.T) {
|
||||
// 测试 errors.Is() 功能是否正常工作
|
||||
|
||||
// 创建各种错误
|
||||
testCases := []struct {
|
||||
name string
|
||||
err error
|
||||
expected error
|
||||
shouldMatch bool
|
||||
}{
|
||||
{
|
||||
name: "手机号错误匹配",
|
||||
err: ErrPhoneError,
|
||||
expected: ErrPhoneError,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "姓名错误匹配",
|
||||
err: ErrNameError,
|
||||
expected: ErrNameError,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "身份证号错误匹配",
|
||||
err: ErrIDCardError,
|
||||
expected: ErrIDCardError,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "余额不足错误匹配",
|
||||
err: ErrInsufficientBalance,
|
||||
expected: ErrInsufficientBalance,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "用户不存在错误匹配",
|
||||
err: ErrUserNotExist,
|
||||
expected: ErrUserNotExist,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "系统错误匹配",
|
||||
err: ErrSystemError,
|
||||
expected: ErrSystemError,
|
||||
shouldMatch: true,
|
||||
},
|
||||
{
|
||||
name: "不同错误不匹配",
|
||||
err: ErrPhoneError,
|
||||
expected: ErrNameError,
|
||||
shouldMatch: false,
|
||||
},
|
||||
{
|
||||
name: "手机号错误与身份证号错误不匹配",
|
||||
err: ErrPhoneError,
|
||||
expected: ErrIDCardError,
|
||||
shouldMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// 使用 errors.Is() 进行判断
|
||||
if errors.Is(tc.err, tc.expected) != tc.shouldMatch {
|
||||
if tc.shouldMatch {
|
||||
t.Errorf("期望 errors.Is(%v, %v) 返回 true", tc.err, tc.expected)
|
||||
} else {
|
||||
t.Errorf("期望 errors.Is(%v, %v) 返回 false", tc.err, tc.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorsIsInSwitch(t *testing.T) {
|
||||
// 测试在 switch 语句中使用 errors.Is()
|
||||
|
||||
// 模拟API调用返回手机号错误
|
||||
err := ErrPhoneError
|
||||
|
||||
// 使用 switch 语句进行错误判断
|
||||
var result string
|
||||
switch {
|
||||
case errors.Is(err, ErrSuccess):
|
||||
result = "请求成功"
|
||||
case errors.Is(err, ErrNoRecord):
|
||||
result = "查询无记录"
|
||||
case errors.Is(err, ErrPhoneError):
|
||||
result = "手机号格式错误"
|
||||
case errors.Is(err, ErrNameError):
|
||||
result = "姓名格式错误"
|
||||
case errors.Is(err, ErrIDCardError):
|
||||
result = "身份证号格式错误"
|
||||
case errors.Is(err, ErrHeaderParamMissing):
|
||||
result = "请求头参数缺失"
|
||||
case errors.Is(err, ErrInsufficientBalance):
|
||||
result = "余额不足"
|
||||
case errors.Is(err, ErrUserNotExist):
|
||||
result = "用户不存在"
|
||||
case errors.Is(err, ErrUserUnauthorized):
|
||||
result = "用户未授权"
|
||||
case errors.Is(err, ErrSystemError):
|
||||
result = "系统异常"
|
||||
default:
|
||||
result = "未知错误"
|
||||
}
|
||||
|
||||
// 验证结果
|
||||
expected := "手机号格式错误"
|
||||
if result != expected {
|
||||
t.Errorf("期望结果 %s,实际 %s", expected, result)
|
||||
}
|
||||
|
||||
t.Logf("Switch语句错误判断结果: %s", result)
|
||||
}
|
||||
|
||||
func TestServiceEncryptDecrypt(t *testing.T) {
|
||||
// 创建测试服务
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "https://test.com",
|
||||
AppID: "test_app_id",
|
||||
AppSecret: "test_app_secret",
|
||||
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
|
||||
},
|
||||
logger: nil,
|
||||
}
|
||||
|
||||
// 测试数据
|
||||
testData := "Hello, 智查金控!"
|
||||
|
||||
// 测试加密
|
||||
encrypted, err := service.Encrypt(testData)
|
||||
if err != nil {
|
||||
t.Fatalf("加密失败: %v", err)
|
||||
}
|
||||
|
||||
if encrypted == "" {
|
||||
t.Error("加密结果为空")
|
||||
}
|
||||
|
||||
if encrypted == testData {
|
||||
t.Error("加密结果与原文相同")
|
||||
}
|
||||
|
||||
t.Logf("原文: %s", testData)
|
||||
t.Logf("加密后: %s", encrypted)
|
||||
|
||||
// 测试解密
|
||||
decrypted, err := service.Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("解密失败: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Errorf("解密结果不匹配,期望: %s,实际: %s", testData, decrypted)
|
||||
}
|
||||
|
||||
t.Logf("解密后: %s", decrypted)
|
||||
}
|
||||
|
||||
func TestEncryptWithoutKey(t *testing.T) {
|
||||
// 创建没有加密密钥的服务
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "https://test.com",
|
||||
AppID: "test_app_id",
|
||||
AppSecret: "test_app_secret",
|
||||
// 没有设置 EncryptKey
|
||||
},
|
||||
logger: nil,
|
||||
}
|
||||
|
||||
// 应该返回错误
|
||||
_, err := service.Encrypt("test data")
|
||||
if err == nil {
|
||||
t.Error("没有加密密钥时应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "加密密钥未配置") {
|
||||
t.Errorf("期望错误包含'加密密钥未配置',实际: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("预期的错误: %v", err)
|
||||
}
|
||||
|
||||
func TestDecryptWithoutKey(t *testing.T) {
|
||||
// 创建没有加密密钥的服务
|
||||
service := &ZhichaService{
|
||||
config: ZhichaConfig{
|
||||
URL: "https://test.com",
|
||||
AppID: "test_app_id",
|
||||
AppSecret: "test_app_secret",
|
||||
// 没有设置 EncryptKey
|
||||
},
|
||||
logger: nil,
|
||||
}
|
||||
|
||||
// 应该返回错误
|
||||
_, err := service.Decrypt("test encrypted data")
|
||||
if err == nil {
|
||||
t.Error("没有加密密钥时应该返回错误")
|
||||
}
|
||||
|
||||
if !strings.Contains(err.Error(), "加密密钥未配置") {
|
||||
t.Errorf("期望错误包含'加密密钥未配置',实际: %v", err)
|
||||
}
|
||||
|
||||
t.Logf("预期的错误: %v", err)
|
||||
}
|
||||
168
internal/infrastructure/http/handlers/admin_security_handler.go
Normal file
168
internal/infrastructure/http/handlers/admin_security_handler.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
securityEntities "hyapi-server/internal/domains/security/entities"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
"hyapi-server/internal/shared/ipgeo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// AdminSecurityHandler 管理员安全数据处理器
|
||||
type AdminSecurityHandler struct {
|
||||
db *gorm.DB
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
ipLocator *ipgeo.Locator
|
||||
}
|
||||
|
||||
func NewAdminSecurityHandler(
|
||||
db *gorm.DB,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
ipLocator *ipgeo.Locator,
|
||||
) *AdminSecurityHandler {
|
||||
return &AdminSecurityHandler{
|
||||
db: db,
|
||||
responseBuilder: responseBuilder,
|
||||
logger: logger,
|
||||
ipLocator: ipLocator,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *AdminSecurityHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
if value := c.Query(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func (h *AdminSecurityHandler) parseRange(c *gin.Context) (time.Time, time.Time, bool) {
|
||||
startTime := time.Now().Add(-24 * time.Hour)
|
||||
endTime := time.Now()
|
||||
|
||||
if start := strings.TrimSpace(c.Query("start_time")); start != "" {
|
||||
t, err := time.Parse("2006-01-02 15:04:05", start)
|
||||
if err != nil {
|
||||
h.responseBuilder.BadRequest(c, "start_time格式错误,示例:2026-03-19 10:00:00")
|
||||
return time.Time{}, time.Time{}, false
|
||||
}
|
||||
startTime = t
|
||||
}
|
||||
if end := strings.TrimSpace(c.Query("end_time")); end != "" {
|
||||
t, err := time.Parse("2006-01-02 15:04:05", end)
|
||||
if err != nil {
|
||||
h.responseBuilder.BadRequest(c, "end_time格式错误,示例:2026-03-19 12:00:00")
|
||||
return time.Time{}, time.Time{}, false
|
||||
}
|
||||
endTime = t
|
||||
}
|
||||
return startTime, endTime, true
|
||||
}
|
||||
|
||||
// ListSuspiciousIPs 获取可疑IP列表
|
||||
func (h *AdminSecurityHandler) ListSuspiciousIPs(c *gin.Context) {
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 20)
|
||||
if pageSize > 100 {
|
||||
pageSize = 100
|
||||
}
|
||||
startTime, endTime, ok := h.parseRange(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
ip := strings.TrimSpace(c.Query("ip"))
|
||||
path := strings.TrimSpace(c.Query("path"))
|
||||
|
||||
query := h.db.Model(&securityEntities.SuspiciousIPRecord{}).
|
||||
Where("created_at >= ? AND created_at <= ?", startTime, endTime)
|
||||
if ip != "" {
|
||||
query = query.Where("ip = ?", ip)
|
||||
}
|
||||
if path != "" {
|
||||
query = query.Where("path LIKE ?", "%"+path+"%")
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
h.logger.Error("查询可疑IP总数失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "查询失败")
|
||||
return
|
||||
}
|
||||
|
||||
var items []securityEntities.SuspiciousIPRecord
|
||||
if err := query.Order("created_at DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&items).Error; err != nil {
|
||||
h.logger.Error("查询可疑IP列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "查询失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, gin.H{
|
||||
"items": items,
|
||||
"total": total,
|
||||
}, "获取成功")
|
||||
}
|
||||
|
||||
type geoStreamRow struct {
|
||||
IP string `json:"ip"`
|
||||
Path string `json:"path"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
// GetSuspiciousIPGeoStream 获取地球请求流数据
|
||||
func (h *AdminSecurityHandler) GetSuspiciousIPGeoStream(c *gin.Context) {
|
||||
startTime, endTime, ok := h.parseRange(c)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
topN := h.getIntQuery(c, "top_n", 200)
|
||||
if topN > 1000 {
|
||||
topN = 1000
|
||||
}
|
||||
|
||||
var rows []geoStreamRow
|
||||
err := h.db.Model(&securityEntities.SuspiciousIPRecord{}).
|
||||
Select("ip, path, COUNT(1) as count").
|
||||
Where("created_at >= ? AND created_at <= ?", startTime, endTime).
|
||||
Group("ip, path").
|
||||
Order("count DESC").
|
||||
Limit(topN).
|
||||
Scan(&rows).Error
|
||||
if err != nil {
|
||||
h.logger.Error("查询地球请求流失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "查询失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 目标固定服务器点位(上海)
|
||||
const serverName = "HYAPI-Server"
|
||||
const serverLng = 121.4737
|
||||
const serverLat = 31.2304
|
||||
|
||||
result := make([]gin.H, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
record := securityEntities.SuspiciousIPRecord{IP: row.IP}
|
||||
fromName, fromLng, fromLat := h.ipLocator.ToGeoPoint(record)
|
||||
result = append(result, gin.H{
|
||||
"from_name": fromName,
|
||||
"from_lng": fromLng,
|
||||
"from_lat": fromLat,
|
||||
"to_name": serverName,
|
||||
"to_lng": serverLng,
|
||||
"to_lat": serverLat,
|
||||
"value": row.Count,
|
||||
"path": row.Path,
|
||||
"ip": row.IP,
|
||||
})
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取成功")
|
||||
}
|
||||
411
internal/infrastructure/http/handlers/announcement_handler.go
Normal file
411
internal/infrastructure/http/handlers/announcement_handler.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/application/article"
|
||||
"hyapi-server/internal/application/article/dto/commands"
|
||||
appQueries "hyapi-server/internal/application/article/dto/queries"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AnnouncementHandler 公告HTTP处理器
|
||||
type AnnouncementHandler struct {
|
||||
appService article.AnnouncementApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
validator interfaces.RequestValidator
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAnnouncementHandler 创建公告HTTP处理器
|
||||
func NewAnnouncementHandler(
|
||||
appService article.AnnouncementApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
) *AnnouncementHandler {
|
||||
return &AnnouncementHandler{
|
||||
appService: appService,
|
||||
responseBuilder: responseBuilder,
|
||||
validator: validator,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateAnnouncement 创建公告
|
||||
// @Summary 创建公告
|
||||
// @Description 创建新的公告
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateAnnouncementCommand true "创建公告请求"
|
||||
// @Success 201 {object} map[string]interface{} "公告创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements [post]
|
||||
func (h *AnnouncementHandler) CreateAnnouncement(c *gin.Context) {
|
||||
var cmd commands.CreateAnnouncementCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 验证用户是否已登录
|
||||
if _, exists := c.Get("user_id"); !exists {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.CreateAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("创建公告失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, nil, "公告创建成功")
|
||||
}
|
||||
|
||||
// GetAnnouncementByID 获取公告详情
|
||||
// @Summary 获取公告详情
|
||||
// @Description 根据ID获取公告详情
|
||||
// @Tags 公告管理-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "公告ID"
|
||||
// @Success 200 {object} responses.AnnouncementInfoResponse "获取公告详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 404 {object} map[string]interface{} "公告不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/announcements/{id} [get]
|
||||
func (h *AnnouncementHandler) GetAnnouncementByID(c *gin.Context) {
|
||||
var query appQueries.GetAnnouncementQuery
|
||||
|
||||
// 绑定URI参数(公告ID)
|
||||
if err := h.validator.ValidateParam(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.GetAnnouncementByID(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取公告详情失败", zap.Error(err))
|
||||
h.responseBuilder.NotFound(c, "公告不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取公告详情成功")
|
||||
}
|
||||
|
||||
// ListAnnouncements 获取公告列表
|
||||
// @Summary 获取公告列表
|
||||
// @Description 分页获取公告列表,支持多种筛选条件
|
||||
// @Tags 公告管理-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param status query string false "公告状态"
|
||||
// @Param title query string false "标题关键词"
|
||||
// @Param order_by query string false "排序字段"
|
||||
// @Param order_dir query string false "排序方向"
|
||||
// @Success 200 {object} responses.AnnouncementListResponse "获取公告列表成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/announcements [get]
|
||||
func (h *AnnouncementHandler) ListAnnouncements(c *gin.Context) {
|
||||
var query appQueries.ListAnnouncementQuery
|
||||
if err := h.validator.ValidateQuery(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
if query.PageSize > 100 {
|
||||
query.PageSize = 100
|
||||
}
|
||||
|
||||
response, err := h.appService.ListAnnouncements(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取公告列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取公告列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取公告列表成功")
|
||||
}
|
||||
|
||||
// PublishAnnouncement 发布公告
|
||||
// @Summary 发布公告
|
||||
// @Description 发布指定的公告
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Success 200 {object} map[string]interface{} "发布成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id}/publish [post]
|
||||
func (h *AnnouncementHandler) PublishAnnouncement(c *gin.Context) {
|
||||
var cmd commands.PublishAnnouncementCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.PublishAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("发布公告失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "发布成功")
|
||||
}
|
||||
|
||||
// WithdrawAnnouncement 撤回公告
|
||||
// @Summary 撤回公告
|
||||
// @Description 撤回已发布的公告
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Success 200 {object} map[string]interface{} "撤回成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id}/withdraw [post]
|
||||
func (h *AnnouncementHandler) WithdrawAnnouncement(c *gin.Context) {
|
||||
var cmd commands.WithdrawAnnouncementCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.WithdrawAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("撤回公告失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "撤回成功")
|
||||
}
|
||||
|
||||
// ArchiveAnnouncement 归档公告
|
||||
// @Summary 归档公告
|
||||
// @Description 归档指定的公告
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Success 200 {object} map[string]interface{} "归档成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id}/archive [post]
|
||||
func (h *AnnouncementHandler) ArchiveAnnouncement(c *gin.Context) {
|
||||
var cmd commands.ArchiveAnnouncementCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.ArchiveAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("归档公告失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "归档成功")
|
||||
}
|
||||
|
||||
// UpdateAnnouncement 更新公告
|
||||
// @Summary 更新公告
|
||||
// @Description 更新指定的公告
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Param request body commands.UpdateAnnouncementCommand true "更新公告请求"
|
||||
// @Success 200 {object} map[string]interface{} "更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id} [put]
|
||||
func (h *AnnouncementHandler) UpdateAnnouncement(c *gin.Context) {
|
||||
var cmd commands.UpdateAnnouncementCommand
|
||||
|
||||
// 先绑定URI参数(公告ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(公告信息)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.UpdateAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("更新公告失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "更新成功")
|
||||
}
|
||||
|
||||
// DeleteAnnouncement 删除公告
|
||||
// @Summary 删除公告
|
||||
// @Description 删除指定的公告
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Success 200 {object} map[string]interface{} "删除成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id} [delete]
|
||||
func (h *AnnouncementHandler) DeleteAnnouncement(c *gin.Context) {
|
||||
var cmd commands.DeleteAnnouncementCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.DeleteAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("删除公告失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "删除成功")
|
||||
}
|
||||
|
||||
// SchedulePublishAnnouncement 定时发布公告
|
||||
// @Summary 定时发布公告
|
||||
// @Description 设置公告的定时发布时间
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Param request body commands.SchedulePublishAnnouncementCommand true "定时发布请求"
|
||||
// @Success 200 {object} map[string]interface{} "设置成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id}/schedule-publish [post]
|
||||
func (h *AnnouncementHandler) SchedulePublishAnnouncement(c *gin.Context) {
|
||||
var cmd commands.SchedulePublishAnnouncementCommand
|
||||
|
||||
// 先绑定URI参数(公告ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(定时发布时间)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.SchedulePublishAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("设置定时发布失败", zap.String("id", cmd.ID), zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "设置成功")
|
||||
}
|
||||
|
||||
// UpdateSchedulePublishAnnouncement 更新定时发布公告
|
||||
// @Summary 更新定时发布公告
|
||||
// @Description 修改公告的定时发布时间
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Param request body commands.UpdateSchedulePublishAnnouncementCommand true "更新定时发布请求"
|
||||
// @Success 200 {object} map[string]interface{} "更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id}/update-schedule-publish [post]
|
||||
func (h *AnnouncementHandler) UpdateSchedulePublishAnnouncement(c *gin.Context) {
|
||||
var cmd commands.UpdateSchedulePublishAnnouncementCommand
|
||||
|
||||
// 先绑定URI参数(公告ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(定时发布时间)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.UpdateSchedulePublishAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("更新定时发布时间失败", zap.String("id", cmd.ID), zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "更新成功")
|
||||
}
|
||||
|
||||
// CancelSchedulePublishAnnouncement 取消定时发布公告
|
||||
// @Summary 取消定时发布公告
|
||||
// @Description 取消公告的定时发布
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "公告ID"
|
||||
// @Success 200 {object} map[string]interface{} "取消成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/{id}/cancel-schedule [post]
|
||||
func (h *AnnouncementHandler) CancelSchedulePublishAnnouncement(c *gin.Context) {
|
||||
var cmd commands.CancelSchedulePublishAnnouncementCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.CancelSchedulePublishAnnouncement(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("取消定时发布失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "取消成功")
|
||||
}
|
||||
|
||||
// GetAnnouncementStats 获取公告统计信息
|
||||
// @Summary 获取公告统计信息
|
||||
// @Description 获取公告的统计数据
|
||||
// @Tags 公告管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.AnnouncementStatsResponse "获取统计信息成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/announcements/stats [get]
|
||||
func (h *AnnouncementHandler) GetAnnouncementStats(c *gin.Context) {
|
||||
response, err := h.appService.GetAnnouncementStats(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取公告统计信息失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取统计信息失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取统计信息成功")
|
||||
}
|
||||
666
internal/infrastructure/http/handlers/api_handler.go
Normal file
666
internal/infrastructure/http/handlers/api_handler.go
Normal file
@@ -0,0 +1,666 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
"hyapi-server/internal/application/api"
|
||||
"hyapi-server/internal/application/api/commands"
|
||||
"hyapi-server/internal/application/api/dto"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ApiHandler API调用HTTP处理器
|
||||
type ApiHandler struct {
|
||||
appService api.ApiApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
validator interfaces.RequestValidator
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewApiHandler 创建API调用HTTP处理器
|
||||
func NewApiHandler(
|
||||
appService api.ApiApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
) *ApiHandler {
|
||||
return &ApiHandler{
|
||||
appService: appService,
|
||||
responseBuilder: responseBuilder,
|
||||
validator: validator,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// HandleApiCall 统一API调用入口
|
||||
// @Summary API调用
|
||||
// @Description 统一API调用入口,参数加密传输
|
||||
// @Tags API调用
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.ApiCallCommand true "API调用请求"
|
||||
// @Success 200 {object} dto.ApiCallResponse "调用成功"
|
||||
// @Failure 400 {object} dto.ApiCallResponse "请求参数错误"
|
||||
// @Failure 401 {object} dto.ApiCallResponse "未授权"
|
||||
// @Failure 429 {object} dto.ApiCallResponse "请求过于频繁"
|
||||
// @Failure 500 {object} dto.ApiCallResponse "服务器内部错误"
|
||||
// @Router /api/v1/:api_name [post]
|
||||
func (h *ApiHandler) HandleApiCall(c *gin.Context) {
|
||||
// 1. 基础参数校验
|
||||
accessId := c.GetHeader("Access-Id")
|
||||
if accessId == "" {
|
||||
response := dto.NewErrorResponse(1005, "缺少Access-Id", "")
|
||||
c.JSON(200, response)
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 绑定和校验请求参数
|
||||
var cmd commands.ApiCallCommand
|
||||
cmd.ClientIP = c.ClientIP()
|
||||
cmd.AccessId = accessId
|
||||
cmd.ApiName = c.Param("api_name")
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
response := dto.NewErrorResponse(1003, "请求参数结构不正确", "")
|
||||
c.JSON(200, response)
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 调用应用服务
|
||||
transactionId, encryptedResp, err := h.appService.CallApi(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
// 根据错误类型返回对应的错误码和预定义错误消息
|
||||
errorCode := api.GetErrorCode(err)
|
||||
errorMessage := api.GetErrorMessage(err)
|
||||
response := dto.NewErrorResponse(errorCode, errorMessage, transactionId)
|
||||
c.JSON(200, response) // API调用接口统一返回200状态码
|
||||
return
|
||||
}
|
||||
|
||||
// 4. 返回成功响应
|
||||
response := dto.NewSuccessResponse(transactionId, encryptedResp)
|
||||
c.JSON(200, response)
|
||||
}
|
||||
|
||||
// GetUserApiKeys 获取用户API密钥
|
||||
func (h *ApiHandler) GetUserApiKeys(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserApiKeys(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户API密钥失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取API密钥成功")
|
||||
}
|
||||
|
||||
// GetUserWhiteList 获取用户白名单列表
|
||||
func (h *ApiHandler) GetUserWhiteList(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取查询参数
|
||||
remarkKeyword := c.Query("remark") // 备注模糊查询关键词
|
||||
|
||||
result, err := h.appService.GetUserWhiteList(c.Request.Context(), userID, remarkKeyword)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户白名单失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取白名单成功")
|
||||
}
|
||||
|
||||
// AddWhiteListIP 添加白名单IP
|
||||
func (h *ApiHandler) AddWhiteListIP(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var req dto.WhiteListRequest
|
||||
if err := h.validator.BindAndValidate(c, &req); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.appService.AddWhiteListIP(c.Request.Context(), userID, req.IPAddress, req.Remark)
|
||||
if err != nil {
|
||||
h.logger.Error("添加白名单IP失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "添加白名单IP成功")
|
||||
}
|
||||
|
||||
// DeleteWhiteListIP 删除白名单IP
|
||||
// @Summary 删除白名单IP
|
||||
// @Description 从当前用户的白名单中删除指定IP地址
|
||||
// @Tags API管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param ip path string true "IP地址"
|
||||
// @Success 200 {object} map[string]interface{} "删除白名单IP成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/my/whitelist/{ip} [delete]
|
||||
func (h *ApiHandler) DeleteWhiteListIP(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
ipAddress := c.Param("ip")
|
||||
if ipAddress == "" {
|
||||
h.responseBuilder.BadRequest(c, "IP地址不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.appService.DeleteWhiteListIP(c.Request.Context(), userID, ipAddress)
|
||||
if err != nil {
|
||||
h.logger.Error("删除白名单IP失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "删除白名单IP成功")
|
||||
}
|
||||
|
||||
// EncryptParams 加密参数接口(用于前端调试)
|
||||
// @Summary 加密参数
|
||||
// @Description 用于前端调试时加密API调用参数
|
||||
// @Tags API调试
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.EncryptCommand true "加密请求"
|
||||
// @Success 200 {object} dto.EncryptResponse "加密成功"
|
||||
// @Failure 400 {object} dto.EncryptResponse "请求参数错误"
|
||||
// @Failure 401 {object} dto.EncryptResponse "未授权"
|
||||
// @Router /api/v1/encrypt [post]
|
||||
func (h *ApiHandler) EncryptParams(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.EncryptCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用应用服务层进行加密
|
||||
encryptedData, err := h.appService.EncryptParams(c.Request.Context(), userID, &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("加密参数失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "加密参数失败")
|
||||
return
|
||||
}
|
||||
|
||||
response := dto.EncryptResponse{
|
||||
EncryptedData: encryptedData,
|
||||
}
|
||||
h.responseBuilder.Success(c, response, "加密成功")
|
||||
}
|
||||
|
||||
// DecryptParams 解密参数
|
||||
// @Summary 解密参数
|
||||
// @Description 使用密钥解密加密的数据
|
||||
// @Tags API调试
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.DecryptCommand true "解密请求"
|
||||
// @Success 200 {object} map[string]interface{} "解密成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未授权"
|
||||
// @Failure 500 {object} map[string]interface{} "解密失败"
|
||||
// @Router /api/v1/decrypt [post]
|
||||
func (h *ApiHandler) DecryptParams(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.DecryptCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用应用服务层进行解密
|
||||
decryptedData, err := h.appService.DecryptParams(c.Request.Context(), userID, &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("解密参数失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "解密参数失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, decryptedData, "解密成功")
|
||||
}
|
||||
|
||||
// GetFormConfig 获取指定API的表单配置
|
||||
// @Summary 获取表单配置
|
||||
// @Description 获取指定API的表单配置,用于前端动态生成表单
|
||||
// @Tags API调试
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param api_code path string true "API代码"
|
||||
// @Success 200 {object} map[string]interface{} "获取成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未授权"
|
||||
// @Failure 404 {object} map[string]interface{} "API接口不存在"
|
||||
// @Router /api/v1/form-config/{api_code} [get]
|
||||
func (h *ApiHandler) GetFormConfig(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
apiCode := c.Param("api_code")
|
||||
if apiCode == "" {
|
||||
h.responseBuilder.BadRequest(c, "API代码不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("获取表单配置", zap.String("api_code", apiCode), zap.String("user_id", userID))
|
||||
|
||||
// 获取表单配置
|
||||
config, err := h.appService.GetFormConfig(c.Request.Context(), apiCode)
|
||||
if err != nil {
|
||||
h.logger.Error("获取表单配置失败", zap.String("api_code", apiCode), zap.String("user_id", userID), zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取表单配置失败")
|
||||
return
|
||||
}
|
||||
|
||||
if config == nil {
|
||||
h.responseBuilder.BadRequest(c, "API接口不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("获取表单配置成功", zap.String("api_code", apiCode), zap.String("user_id", userID), zap.Int("field_count", len(config.Fields)))
|
||||
h.responseBuilder.Success(c, config, "获取表单配置成功")
|
||||
}
|
||||
|
||||
// getCurrentUserID 获取当前用户ID
|
||||
func (h *ApiHandler) getCurrentUserID(c *gin.Context) string {
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetUserApiCalls 获取用户API调用记录
|
||||
// @Summary 获取用户API调用记录
|
||||
// @Description 获取当前用户的API调用记录列表,支持分页和筛选
|
||||
// @Tags API管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param start_time query string false "开始时间 (格式: 2006-01-02 15:04:05)"
|
||||
// @Param end_time query string false "结束时间 (格式: 2006-01-02 15:04:05)"
|
||||
// @Param transaction_id query string false "交易ID"
|
||||
// @Param product_name query string false "产品名称"
|
||||
// @Param status query string false "状态 (pending/success/failed)"
|
||||
// @Success 200 {object} dto.ApiCallListResponse "获取成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/my/api-calls [get]
|
||||
func (h *ApiHandler) GetUserApiCalls(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
|
||||
filters["start_time"] = t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
|
||||
filters["end_time"] = t
|
||||
}
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId := c.Query("transaction_id"); transactionId != "" {
|
||||
filters["transaction_id"] = transactionId
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName := c.Query("product_name"); productName != "" {
|
||||
filters["product_name"] = productName
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status := c.Query("status"); status != "" {
|
||||
filters["status"] = status
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: "created_at",
|
||||
Order: "desc",
|
||||
}
|
||||
|
||||
result, err := h.appService.GetUserApiCalls(c.Request.Context(), userID, filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户API调用记录失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取API调用记录失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取API调用记录成功")
|
||||
}
|
||||
|
||||
// GetAdminApiCalls 获取管理端API调用记录
|
||||
// @Summary 获取管理端API调用记录
|
||||
// @Description 管理员获取API调用记录,支持筛选和分页
|
||||
// @Tags API管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param user_id query string false "用户ID"
|
||||
// @Param transaction_id query string false "交易ID"
|
||||
// @Param product_name query string false "产品名称"
|
||||
// @Param status query string false "状态"
|
||||
// @Param start_time query string false "开始时间" format(date-time)
|
||||
// @Param end_time query string false "结束时间" format(date-time)
|
||||
// @Param sort_by query string false "排序字段"
|
||||
// @Param sort_order query string false "排序方向" Enums(asc, desc)
|
||||
// @Success 200 {object} dto.ApiCallListResponse "获取API调用记录成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/api-calls [get]
|
||||
func (h *ApiHandler) GetAdminApiCalls(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
page := h.getIntQuery(c, "page", 1)
|
||||
pageSize := h.getIntQuery(c, "page_size", 10)
|
||||
|
||||
// 构建筛选条件
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 用户ID筛选
|
||||
if userId := c.Query("user_id"); userId != "" {
|
||||
filters["user_id"] = userId
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
|
||||
filters["start_time"] = t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
|
||||
filters["end_time"] = t
|
||||
}
|
||||
}
|
||||
|
||||
// 交易ID筛选
|
||||
if transactionId := c.Query("transaction_id"); transactionId != "" {
|
||||
filters["transaction_id"] = transactionId
|
||||
}
|
||||
|
||||
// 产品名称筛选
|
||||
if productName := c.Query("product_name"); productName != "" {
|
||||
filters["product_name"] = productName
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if status := c.Query("status"); status != "" {
|
||||
filters["status"] = status
|
||||
}
|
||||
|
||||
// 构建分页选项
|
||||
options := interfaces.ListOptions{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Sort: "created_at",
|
||||
Order: "desc",
|
||||
}
|
||||
|
||||
result, err := h.appService.GetAdminApiCalls(c.Request.Context(), filters, options)
|
||||
if err != nil {
|
||||
h.logger.Error("获取管理端API调用记录失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "获取API调用记录失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取API调用记录成功")
|
||||
}
|
||||
|
||||
// ExportAdminApiCalls 导出管理端API调用记录
|
||||
// @Summary 导出管理端API调用记录
|
||||
// @Description 管理员导出API调用记录,支持Excel和CSV格式
|
||||
// @Tags API调用管理
|
||||
// @Accept json
|
||||
// @Produce application/vnd.openxmlformats-officedocument.spreadsheetml.sheet,text/csv
|
||||
// @Security Bearer
|
||||
// @Param user_ids query string false "用户ID列表,逗号分隔"
|
||||
// @Param product_ids query string false "产品ID列表,逗号分隔"
|
||||
// @Param start_time query string false "开始时间" format(date-time)
|
||||
// @Param end_time query string false "结束时间" format(date-time)
|
||||
// @Param format query string false "导出格式" Enums(excel, csv) default(excel)
|
||||
// @Success 200 {file} file "导出文件"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/api-calls/export [get]
|
||||
func (h *ApiHandler) ExportAdminApiCalls(c *gin.Context) {
|
||||
// 解析查询参数
|
||||
filters := make(map[string]interface{})
|
||||
|
||||
// 用户ID筛选
|
||||
if userIds := c.Query("user_ids"); userIds != "" {
|
||||
filters["user_ids"] = userIds
|
||||
}
|
||||
|
||||
// 产品ID筛选
|
||||
if productIds := c.Query("product_ids"); productIds != "" {
|
||||
filters["product_ids"] = productIds
|
||||
}
|
||||
|
||||
// 时间范围筛选
|
||||
if startTime := c.Query("start_time"); startTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
|
||||
filters["start_time"] = t
|
||||
}
|
||||
}
|
||||
if endTime := c.Query("end_time"); endTime != "" {
|
||||
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
|
||||
filters["end_time"] = t
|
||||
}
|
||||
}
|
||||
|
||||
// 获取导出格式,默认为excel
|
||||
format := c.DefaultQuery("format", "excel")
|
||||
if format != "excel" && format != "csv" {
|
||||
h.responseBuilder.BadRequest(c, "不支持的导出格式")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用应用服务导出数据
|
||||
fileData, err := h.appService.ExportAdminApiCalls(c.Request.Context(), filters, format)
|
||||
if err != nil {
|
||||
h.logger.Error("导出API调用记录失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, "导出API调用记录失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
contentType := "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
filename := "API调用记录.xlsx"
|
||||
if format == "csv" {
|
||||
contentType = "text/csv;charset=utf-8"
|
||||
filename = "API调用记录.csv"
|
||||
}
|
||||
|
||||
c.Header("Content-Type", contentType)
|
||||
c.Header("Content-Disposition", "attachment; filename="+filename)
|
||||
c.Data(200, contentType, fileData)
|
||||
}
|
||||
|
||||
// getIntQuery 获取整数查询参数
|
||||
func (h *ApiHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
|
||||
if value := c.Query(key); value != "" {
|
||||
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
|
||||
return intValue
|
||||
}
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// GetUserBalanceAlertSettings 获取用户余额预警设置
|
||||
// @Summary 获取用户余额预警设置
|
||||
// @Description 获取当前用户的余额预警配置
|
||||
// @Tags 用户设置
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{} "获取成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未授权"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/user/balance-alert/settings [get]
|
||||
func (h *ApiHandler) GetUserBalanceAlertSettings(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.appService.GetUserBalanceAlertSettings(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户余额预警设置失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取预警设置失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, settings, "获取成功")
|
||||
}
|
||||
|
||||
// UpdateUserBalanceAlertSettings 更新用户余额预警设置
|
||||
// @Summary 更新用户余额预警设置
|
||||
// @Description 更新当前用户的余额预警配置
|
||||
// @Tags 用户设置
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body map[string]interface{} true "预警设置"
|
||||
// @Success 200 {object} map[string]interface{} "更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未授权"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/user/balance-alert/settings [put]
|
||||
func (h *ApiHandler) UpdateUserBalanceAlertSettings(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Enabled bool `json:"enabled" binding:"required"`
|
||||
Threshold float64 `json:"threshold" binding:"required,min=0"`
|
||||
AlertPhone string `json:"alert_phone" binding:"required"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err := h.appService.UpdateUserBalanceAlertSettings(c.Request.Context(), userID, request.Enabled, request.Threshold, request.AlertPhone)
|
||||
if err != nil {
|
||||
h.logger.Error("更新用户余额预警设置失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "更新预警设置失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, gin.H{}, "更新成功")
|
||||
}
|
||||
|
||||
// TestBalanceAlertSms 测试余额预警短信
|
||||
// @Summary 测试余额预警短信
|
||||
// @Description 发送测试预警短信到指定手机号
|
||||
// @Tags 用户设置
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body map[string]interface{} true "测试参数"
|
||||
// @Success 200 {object} map[string]interface{} "发送成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未授权"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/user/balance-alert/test-sms [post]
|
||||
func (h *ApiHandler) TestBalanceAlertSms(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var request struct {
|
||||
Phone string `json:"phone" binding:"required,len=11"`
|
||||
Balance float64 `json:"balance" binding:"required"`
|
||||
AlertType string `json:"alert_type" binding:"required,oneof=low_balance arrears"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&request); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
err := h.appService.TestBalanceAlertSms(c.Request.Context(), userID, request.Phone, request.Balance, request.AlertType)
|
||||
if err != nil {
|
||||
h.logger.Error("发送测试预警短信失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "发送测试短信失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, gin.H{}, "测试短信发送成功")
|
||||
}
|
||||
776
internal/infrastructure/http/handlers/article_handler.go
Normal file
776
internal/infrastructure/http/handlers/article_handler.go
Normal file
@@ -0,0 +1,776 @@
|
||||
//nolint:unused
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"hyapi-server/internal/application/article"
|
||||
"hyapi-server/internal/application/article/dto/commands"
|
||||
appQueries "hyapi-server/internal/application/article/dto/queries"
|
||||
_ "hyapi-server/internal/application/article/dto/responses"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ArticleHandler 文章HTTP处理器
|
||||
type ArticleHandler struct {
|
||||
appService article.ArticleApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
validator interfaces.RequestValidator
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewArticleHandler 创建文章HTTP处理器
|
||||
func NewArticleHandler(
|
||||
appService article.ArticleApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
) *ArticleHandler {
|
||||
return &ArticleHandler{
|
||||
appService: appService,
|
||||
responseBuilder: responseBuilder,
|
||||
validator: validator,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateArticle 创建文章
|
||||
// @Summary 创建文章
|
||||
// @Description 创建新的文章
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateArticleCommand true "创建文章请求"
|
||||
// @Success 201 {object} map[string]interface{} "文章创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles [post]
|
||||
func (h *ArticleHandler) CreateArticle(c *gin.Context) {
|
||||
var cmd commands.CreateArticleCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 验证用户是否已登录
|
||||
if _, exists := c.Get("user_id"); !exists {
|
||||
h.responseBuilder.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.CreateArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("创建文章失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, nil, "文章创建成功")
|
||||
}
|
||||
|
||||
// GetArticleByID 获取文章详情
|
||||
// @Summary 获取文章详情
|
||||
// @Description 根据ID获取文章详情
|
||||
// @Tags 文章管理-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "文章ID"
|
||||
// @Success 200 {object} responses.ArticleInfoResponse "获取文章详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/articles/{id} [get]
|
||||
func (h *ArticleHandler) GetArticleByID(c *gin.Context) {
|
||||
var query appQueries.GetArticleQuery
|
||||
|
||||
// 绑定URI参数(文章ID)
|
||||
if err := h.validator.ValidateParam(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.GetArticleByID(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取文章详情失败", zap.Error(err))
|
||||
h.responseBuilder.NotFound(c, "文章不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取文章详情成功")
|
||||
}
|
||||
|
||||
// ListArticles 获取文章列表
|
||||
// @Summary 获取文章列表
|
||||
// @Description 分页获取文章列表,支持多种筛选条件
|
||||
// @Tags 文章管理-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param status query string false "文章状态"
|
||||
// @Param category_id query string false "分类ID"
|
||||
// @Param tag_id query string false "标签ID"
|
||||
// @Param title query string false "标题关键词"
|
||||
// @Param summary query string false "摘要关键词"
|
||||
// @Param is_featured query bool false "是否推荐"
|
||||
// @Param order_by query string false "排序字段"
|
||||
// @Param order_dir query string false "排序方向"
|
||||
// @Success 200 {object} responses.ArticleListResponse "获取文章列表成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/articles [get]
|
||||
func (h *ArticleHandler) ListArticles(c *gin.Context) {
|
||||
var query appQueries.ListArticleQuery
|
||||
if err := h.validator.ValidateQuery(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
if query.PageSize > 100 {
|
||||
query.PageSize = 100
|
||||
}
|
||||
|
||||
response, err := h.appService.ListArticles(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取文章列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取文章列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取文章列表成功")
|
||||
}
|
||||
|
||||
// ListArticlesForAdmin 获取文章列表(管理员端)
|
||||
// @Summary 获取文章列表(管理员端)
|
||||
// @Description 分页获取文章列表,支持多种筛选条件,包含所有状态的文章
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param status query string false "文章状态"
|
||||
// @Param category_id query string false "分类ID"
|
||||
// @Param tag_id query string false "标签ID"
|
||||
// @Param title query string false "标题关键词"
|
||||
// @Param summary query string false "摘要关键词"
|
||||
// @Param is_featured query bool false "是否推荐"
|
||||
// @Param order_by query string false "排序字段"
|
||||
// @Param order_dir query string false "排序方向"
|
||||
// @Success 200 {object} responses.ArticleListResponse "获取文章列表成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles [get]
|
||||
func (h *ArticleHandler) ListArticlesForAdmin(c *gin.Context) {
|
||||
var query appQueries.ListArticleQuery
|
||||
if err := h.validator.ValidateQuery(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
if query.Page <= 0 {
|
||||
query.Page = 1
|
||||
}
|
||||
if query.PageSize <= 0 {
|
||||
query.PageSize = 10
|
||||
}
|
||||
if query.PageSize > 100 {
|
||||
query.PageSize = 100
|
||||
}
|
||||
|
||||
response, err := h.appService.ListArticlesForAdmin(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取文章列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取文章列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取文章列表成功")
|
||||
}
|
||||
|
||||
// UpdateArticle 更新文章
|
||||
// @Summary 更新文章
|
||||
// @Description 更新文章信息
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Param request body commands.UpdateArticleCommand true "更新文章请求"
|
||||
// @Success 200 {object} map[string]interface{} "文章更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id} [put]
|
||||
func (h *ArticleHandler) UpdateArticle(c *gin.Context) {
|
||||
var cmd commands.UpdateArticleCommand
|
||||
|
||||
// 先绑定URI参数(文章ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(文章信息)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.UpdateArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("更新文章失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "文章更新成功")
|
||||
}
|
||||
|
||||
// DeleteArticle 删除文章
|
||||
// @Summary 删除文章
|
||||
// @Description 删除指定文章
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Success 200 {object} map[string]interface{} "文章删除成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id} [delete]
|
||||
func (h *ArticleHandler) DeleteArticle(c *gin.Context) {
|
||||
var cmd commands.DeleteArticleCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.DeleteArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("删除文章失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "文章删除成功")
|
||||
}
|
||||
|
||||
// PublishArticle 发布文章
|
||||
// @Summary 发布文章
|
||||
// @Description 将草稿文章发布
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Success 200 {object} map[string]interface{} "文章发布成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id}/publish [post]
|
||||
func (h *ArticleHandler) PublishArticle(c *gin.Context) {
|
||||
var cmd commands.PublishArticleCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.PublishArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("发布文章失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "文章发布成功")
|
||||
}
|
||||
|
||||
// SchedulePublishArticle 定时发布文章
|
||||
// @Summary 定时发布文章
|
||||
// @Description 设置文章的定时发布时间,支持格式:YYYY-MM-DD HH:mm:ss
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Param request body commands.SchedulePublishCommand true "定时发布请求"
|
||||
// @Success 200 {object} map[string]interface{} "定时发布设置成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id}/schedule-publish [post]
|
||||
func (h *ArticleHandler) SchedulePublishArticle(c *gin.Context) {
|
||||
var cmd commands.SchedulePublishCommand
|
||||
|
||||
// 先绑定URI参数(文章ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(定时发布时间)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.SchedulePublishArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("设置定时发布失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "定时发布设置成功")
|
||||
}
|
||||
|
||||
// CancelSchedulePublishArticle 取消定时发布文章
|
||||
// @Summary 取消定时发布文章
|
||||
// @Description 取消文章的定时发布设置
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Success 200 {object} map[string]interface{} "取消定时发布成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id}/cancel-schedule [post]
|
||||
func (h *ArticleHandler) CancelSchedulePublishArticle(c *gin.Context) {
|
||||
var cmd commands.CancelScheduleCommand
|
||||
|
||||
// 绑定URI参数(文章ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.CancelSchedulePublishArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("取消定时发布失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "取消定时发布成功")
|
||||
}
|
||||
|
||||
// ArchiveArticle 归档文章
|
||||
// @Summary 归档文章
|
||||
// @Description 将已发布文章归档
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Success 200 {object} map[string]interface{} "文章归档成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id}/archive [post]
|
||||
func (h *ArticleHandler) ArchiveArticle(c *gin.Context) {
|
||||
var cmd commands.ArchiveArticleCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.ArchiveArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("归档文章失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "文章归档成功")
|
||||
}
|
||||
|
||||
// SetFeatured 设置推荐状态
|
||||
// @Summary 设置推荐状态
|
||||
// @Description 设置文章的推荐状态
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Param request body commands.SetFeaturedCommand true "设置推荐状态请求"
|
||||
// @Success 200 {object} map[string]interface{} "设置推荐状态成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id}/featured [put]
|
||||
func (h *ArticleHandler) SetFeatured(c *gin.Context) {
|
||||
var cmd commands.SetFeaturedCommand
|
||||
|
||||
// 先绑定URI参数(文章ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(推荐状态)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.SetFeatured(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("设置推荐状态失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "设置推荐状态成功")
|
||||
}
|
||||
|
||||
// GetArticleStats 获取文章统计
|
||||
// @Summary 获取文章统计
|
||||
// @Description 获取文章相关统计数据
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.ArticleStatsResponse "获取统计成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/stats [get]
|
||||
func (h *ArticleHandler) GetArticleStats(c *gin.Context) {
|
||||
response, err := h.appService.GetArticleStats(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取文章统计失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取文章统计失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取统计成功")
|
||||
}
|
||||
|
||||
// UpdateSchedulePublishArticle 修改定时发布时间
|
||||
// @Summary 修改定时发布时间
|
||||
// @Description 修改文章的定时发布时间
|
||||
// @Tags 文章管理-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "文章ID"
|
||||
// @Param request body commands.SchedulePublishCommand true "修改定时发布请求"
|
||||
// @Success 200 {object} map[string]interface{} "修改定时发布时间成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "文章不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/articles/{id}/update-schedule-publish [post]
|
||||
func (h *ArticleHandler) UpdateSchedulePublishArticle(c *gin.Context) {
|
||||
var cmd commands.SchedulePublishCommand
|
||||
|
||||
// 先绑定URI参数(文章ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(定时发布时间)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.UpdateSchedulePublishArticle(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("修改定时发布时间失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "修改定时发布时间成功")
|
||||
}
|
||||
|
||||
// ==================== 分类相关方法 ====================
|
||||
|
||||
// ListCategories 获取分类列表
|
||||
// @Summary 获取分类列表
|
||||
// @Description 获取所有文章分类
|
||||
// @Tags 文章分类-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} responses.CategoryListResponse "获取分类列表成功"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/article-categories [get]
|
||||
func (h *ArticleHandler) ListCategories(c *gin.Context) {
|
||||
response, err := h.appService.ListCategories(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取分类列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取分类列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取分类列表成功")
|
||||
}
|
||||
|
||||
// GetCategoryByID 获取分类详情
|
||||
// @Summary 获取分类详情
|
||||
// @Description 根据ID获取分类详情
|
||||
// @Tags 文章分类-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "分类ID"
|
||||
// @Success 200 {object} responses.CategoryInfoResponse "获取分类详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 404 {object} map[string]interface{} "分类不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/article-categories/{id} [get]
|
||||
func (h *ArticleHandler) GetCategoryByID(c *gin.Context) {
|
||||
var query appQueries.GetCategoryQuery
|
||||
|
||||
// 绑定URI参数(分类ID)
|
||||
if err := h.validator.ValidateParam(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.GetCategoryByID(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取分类详情失败", zap.Error(err))
|
||||
h.responseBuilder.NotFound(c, "分类不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取分类详情成功")
|
||||
}
|
||||
|
||||
// CreateCategory 创建分类
|
||||
// @Summary 创建分类
|
||||
// @Description 创建新的文章分类
|
||||
// @Tags 文章分类-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateCategoryCommand true "创建分类请求"
|
||||
// @Success 201 {object} map[string]interface{} "分类创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/article-categories [post]
|
||||
func (h *ArticleHandler) CreateCategory(c *gin.Context) {
|
||||
var cmd commands.CreateCategoryCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.CreateCategory(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("创建分类失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, nil, "分类创建成功")
|
||||
}
|
||||
|
||||
// UpdateCategory 更新分类
|
||||
// @Summary 更新分类
|
||||
// @Description 更新分类信息
|
||||
// @Tags 文章分类-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "分类ID"
|
||||
// @Param request body commands.UpdateCategoryCommand true "更新分类请求"
|
||||
// @Success 200 {object} map[string]interface{} "分类更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "分类不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/article-categories/{id} [put]
|
||||
func (h *ArticleHandler) UpdateCategory(c *gin.Context) {
|
||||
var cmd commands.UpdateCategoryCommand
|
||||
|
||||
// 先绑定URI参数(分类ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(分类信息)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.UpdateCategory(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("更新分类失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "分类更新成功")
|
||||
}
|
||||
|
||||
// DeleteCategory 删除分类
|
||||
// @Summary 删除分类
|
||||
// @Description 删除指定分类
|
||||
// @Tags 文章分类-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "分类ID"
|
||||
// @Success 200 {object} map[string]interface{} "分类删除成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "分类不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/article-categories/{id} [delete]
|
||||
func (h *ArticleHandler) DeleteCategory(c *gin.Context) {
|
||||
var cmd commands.DeleteCategoryCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.DeleteCategory(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("删除分类失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "分类删除成功")
|
||||
}
|
||||
|
||||
// ==================== 标签相关方法 ====================
|
||||
|
||||
// ListTags 获取标签列表
|
||||
// @Summary 获取标签列表
|
||||
// @Description 获取所有文章标签
|
||||
// @Tags 文章标签-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Success 200 {object} responses.TagListResponse "获取标签列表成功"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/article-tags [get]
|
||||
func (h *ArticleHandler) ListTags(c *gin.Context) {
|
||||
response, err := h.appService.ListTags(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取标签列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取标签列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取标签列表成功")
|
||||
}
|
||||
|
||||
// GetTagByID 获取标签详情
|
||||
// @Summary 获取标签详情
|
||||
// @Description 根据ID获取标签详情
|
||||
// @Tags 文章标签-用户端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param id path string true "标签ID"
|
||||
// @Success 200 {object} responses.TagInfoResponse "获取标签详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 404 {object} map[string]interface{} "标签不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/article-tags/{id} [get]
|
||||
func (h *ArticleHandler) GetTagByID(c *gin.Context) {
|
||||
var query appQueries.GetTagQuery
|
||||
|
||||
// 绑定URI参数(标签ID)
|
||||
if err := h.validator.ValidateParam(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.GetTagByID(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取标签详情失败", zap.Error(err))
|
||||
h.responseBuilder.NotFound(c, "标签不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取标签详情成功")
|
||||
}
|
||||
|
||||
// CreateTag 创建标签
|
||||
// @Summary 创建标签
|
||||
// @Description 创建新的文章标签
|
||||
// @Tags 文章标签-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateTagCommand true "创建标签请求"
|
||||
// @Success 201 {object} map[string]interface{} "标签创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/article-tags [post]
|
||||
func (h *ArticleHandler) CreateTag(c *gin.Context) {
|
||||
var cmd commands.CreateTagCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.CreateTag(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("创建标签失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, nil, "标签创建成功")
|
||||
}
|
||||
|
||||
// UpdateTag 更新标签
|
||||
// @Summary 更新标签
|
||||
// @Description 更新标签信息
|
||||
// @Tags 文章标签-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "标签ID"
|
||||
// @Param request body commands.UpdateTagCommand true "更新标签请求"
|
||||
// @Success 200 {object} map[string]interface{} "标签更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "标签不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/article-tags/{id} [put]
|
||||
func (h *ArticleHandler) UpdateTag(c *gin.Context) {
|
||||
var cmd commands.UpdateTagCommand
|
||||
|
||||
// 先绑定URI参数(标签ID)
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 再绑定JSON请求体(标签信息)
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.UpdateTag(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("更新标签失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "标签更新成功")
|
||||
}
|
||||
|
||||
// DeleteTag 删除标签
|
||||
// @Summary 删除标签
|
||||
// @Description 删除指定标签
|
||||
// @Tags 文章标签-管理端
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "标签ID"
|
||||
// @Success 200 {object} map[string]interface{} "标签删除成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "标签不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/article-tags/{id} [delete]
|
||||
func (h *ArticleHandler) DeleteTag(c *gin.Context) {
|
||||
var cmd commands.DeleteTagCommand
|
||||
if err := h.validator.ValidateParam(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.appService.DeleteTag(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("删除标签失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "标签删除成功")
|
||||
}
|
||||
92
internal/infrastructure/http/handlers/captcha_handler.go
Normal file
92
internal/infrastructure/http/handlers/captcha_handler.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/config"
|
||||
"hyapi-server/internal/infrastructure/external/captcha"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// CaptchaHandler 验证码(滑块)HTTP 处理器
|
||||
type CaptchaHandler struct {
|
||||
captchaService *captcha.CaptchaService
|
||||
response interfaces.ResponseBuilder
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCaptchaHandler 创建验证码处理器
|
||||
func NewCaptchaHandler(
|
||||
captchaService *captcha.CaptchaService,
|
||||
response interfaces.ResponseBuilder,
|
||||
cfg *config.Config,
|
||||
logger *zap.Logger,
|
||||
) *CaptchaHandler {
|
||||
return &CaptchaHandler{
|
||||
captchaService: captchaService,
|
||||
response: response,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// EncryptedSceneIdReq 获取加密场景 ID 的请求(可选参数)
|
||||
type EncryptedSceneIdReq struct {
|
||||
ExpireSeconds *int `form:"expire_seconds" json:"expire_seconds"` // 有效期秒数,1~86400,默认 3600
|
||||
}
|
||||
|
||||
// GetEncryptedSceneId 获取加密场景 ID,供前端加密模式初始化阿里云验证码
|
||||
// @Summary 获取验证码加密场景ID
|
||||
// @Description 用于加密模式下发 EncryptedSceneId,前端用此初始化滑块验证码
|
||||
// @Tags 验证码
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param body body EncryptedSceneIdReq false "可选:expire_seconds 有效期(1-86400),默认3600"
|
||||
// @Success 200 {object} map[string]interface{} "encryptedSceneId"
|
||||
// @Failure 400 {object} map[string]interface{} "配置未启用或参数错误"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/captcha/encryptedSceneId [post]
|
||||
func (h *CaptchaHandler) GetEncryptedSceneId(c *gin.Context) {
|
||||
expireSec := 3600
|
||||
if c.Request.ContentLength > 0 {
|
||||
var req EncryptedSceneIdReq
|
||||
if err := c.ShouldBindJSON(&req); err == nil && req.ExpireSeconds != nil {
|
||||
expireSec = *req.ExpireSeconds
|
||||
}
|
||||
}
|
||||
if expireSec <= 0 || expireSec > 86400 {
|
||||
h.response.BadRequest(c, "expire_seconds 必须在 1~86400 之间")
|
||||
return
|
||||
}
|
||||
|
||||
encrypted, err := h.captchaService.GetEncryptedSceneId(expireSec)
|
||||
if err != nil {
|
||||
if err == captcha.ErrCaptchaEncryptMissing || err == captcha.ErrCaptchaConfig {
|
||||
h.logger.Warn("验证码加密场景ID生成失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "验证码加密模式未配置或配置错误")
|
||||
return
|
||||
}
|
||||
h.logger.Error("验证码加密场景ID生成失败", zap.Error(err))
|
||||
h.response.InternalError(c, "生成失败,请稍后重试")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, map[string]string{"encryptedSceneId": encrypted}, "ok")
|
||||
}
|
||||
|
||||
// GetConfig 获取验证码前端配置(是否启用、场景ID等),便于前端决定是否展示滑块
|
||||
// @Summary 获取验证码配置
|
||||
// @Description 返回是否启用滑块、场景ID(非加密模式用)
|
||||
// @Tags 验证码
|
||||
// @Produce json
|
||||
// @Success 200 {object} map[string]interface{} "captchaEnabled, sceneId"
|
||||
// @Router /api/v1/captcha/config [get]
|
||||
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
|
||||
data := map[string]interface{}{
|
||||
"captchaEnabled": h.config.SMS.CaptchaEnabled,
|
||||
"sceneId": h.config.SMS.SceneID,
|
||||
}
|
||||
h.response.Success(c, data, "ok")
|
||||
}
|
||||
729
internal/infrastructure/http/handlers/certification_handler.go
Normal file
729
internal/infrastructure/http/handlers/certification_handler.go
Normal file
@@ -0,0 +1,729 @@
|
||||
//nolint:unused
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/application/certification"
|
||||
"hyapi-server/internal/application/certification/dto/commands"
|
||||
"hyapi-server/internal/application/certification/dto/queries"
|
||||
_ "hyapi-server/internal/application/certification/dto/responses"
|
||||
"hyapi-server/internal/infrastructure/external/storage"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
"hyapi-server/internal/shared/middleware"
|
||||
)
|
||||
|
||||
// CertificationHandler 认证HTTP处理器
|
||||
type CertificationHandler struct {
|
||||
appService certification.CertificationApplicationService
|
||||
response interfaces.ResponseBuilder
|
||||
validator interfaces.RequestValidator
|
||||
logger *zap.Logger
|
||||
jwtAuth *middleware.JWTAuthMiddleware
|
||||
storageService *storage.QiNiuStorageService
|
||||
}
|
||||
|
||||
// NewCertificationHandler 创建认证处理器
|
||||
func NewCertificationHandler(
|
||||
appService certification.CertificationApplicationService,
|
||||
response interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
jwtAuth *middleware.JWTAuthMiddleware,
|
||||
storageService *storage.QiNiuStorageService,
|
||||
) *CertificationHandler {
|
||||
return &CertificationHandler{
|
||||
appService: appService,
|
||||
response: response,
|
||||
validator: validator,
|
||||
logger: logger,
|
||||
jwtAuth: jwtAuth,
|
||||
storageService: storageService,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 认证申请管理 ================
|
||||
// GetCertification 获取认证详情
|
||||
// @Summary 获取认证详情
|
||||
// @Description 根据认证ID获取认证详情
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.CertificationResponse "获取认证详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/details [get]
|
||||
func (h *CertificationHandler) GetCertification(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
query := &queries.GetCertificationQuery{
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
result, err := h.appService.GetCertification(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证详情失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取认证详情成功")
|
||||
}
|
||||
|
||||
// ================ 企业信息管理 ================
|
||||
|
||||
// SubmitEnterpriseInfo 提交企业信息
|
||||
// @Summary 提交企业信息
|
||||
// @Description 提交企业认证所需的企业信息
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.SubmitEnterpriseInfoCommand true "提交企业信息请求"
|
||||
// @Success 200 {object} responses.CertificationResponse "企业信息提交成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/enterprise-info [post]
|
||||
func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.SubmitEnterpriseInfoCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.SubmitEnterpriseInfo(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("提交企业信息失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "企业信息提交成功")
|
||||
}
|
||||
|
||||
// ConfirmAuth 前端确认是否完成认证
|
||||
// @Summary 前端确认认证状态
|
||||
// @Description 前端轮询确认企业认证是否完成
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body queries.ConfirmAuthCommand true "确认状态请求"
|
||||
// @Success 200 {object} responses.ConfirmAuthResponse "状态确认成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/confirm-auth [post]
|
||||
func (h *CertificationHandler) ConfirmAuth(c *gin.Context) {
|
||||
var cmd queries.ConfirmAuthCommand
|
||||
cmd.UserID = h.getCurrentUserID(c)
|
||||
if cmd.UserID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.ConfirmAuth(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("确认认证/签署状态失败", zap.Error(err), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "状态确认成功")
|
||||
}
|
||||
|
||||
// ConfirmSign 前端确认是否完成签署
|
||||
// @Summary 前端确认签署状态
|
||||
// @Description 前端轮询确认合同签署是否完成
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body queries.ConfirmSignCommand true "确认状态请求"
|
||||
// @Success 200 {object} responses.ConfirmSignResponse "状态确认成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/confirm-sign [post]
|
||||
func (h *CertificationHandler) ConfirmSign(c *gin.Context) {
|
||||
var cmd queries.ConfirmSignCommand
|
||||
cmd.UserID = h.getCurrentUserID(c)
|
||||
if cmd.UserID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
result, err := h.appService.ConfirmSign(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("确认认证/签署状态失败", zap.Error(err), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "状态确认成功")
|
||||
}
|
||||
|
||||
// ================ 合同管理 ================
|
||||
|
||||
// ApplyContract 申请合同签署
|
||||
// @Summary 申请合同签署
|
||||
// @Description 申请企业认证合同签署
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.ApplyContractCommand true "申请合同请求"
|
||||
// @Success 200 {object} responses.ContractSignUrlResponse "合同申请成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/apply-contract [post]
|
||||
func (h *CertificationHandler) ApplyContract(c *gin.Context) {
|
||||
var cmd commands.ApplyContractCommand
|
||||
cmd.UserID = h.getCurrentUserID(c)
|
||||
if cmd.UserID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.ApplyContract(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("申请合同失败", zap.Error(err), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "合同申请成功")
|
||||
}
|
||||
|
||||
// RecognizeBusinessLicense OCR识别营业执照
|
||||
// @Summary OCR识别营业执照
|
||||
// @Description 上传营业执照图片进行OCR识别,自动填充企业信息
|
||||
// @Tags 认证管理
|
||||
// @Accept multipart/form-data
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param image formData file true "营业执照图片文件"
|
||||
// @Success 200 {object} responses.BusinessLicenseResult "营业执照识别成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/ocr/business-license [post]
|
||||
func (h *CertificationHandler) RecognizeBusinessLicense(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取上传的文件
|
||||
file, err := c.FormFile("image")
|
||||
if err != nil {
|
||||
h.logger.Error("获取上传文件失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, "请选择要上传的营业执照图片")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证文件类型
|
||||
allowedTypes := map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/png": true,
|
||||
"image/webp": true,
|
||||
}
|
||||
if !allowedTypes[file.Header.Get("Content-Type")] {
|
||||
h.response.BadRequest(c, "只支持JPG、PNG、WEBP格式的图片")
|
||||
return
|
||||
}
|
||||
|
||||
// 验证文件大小(限制为5MB)
|
||||
if file.Size > 5*1024*1024 {
|
||||
h.response.BadRequest(c, "图片大小不能超过5MB")
|
||||
return
|
||||
}
|
||||
|
||||
// 打开文件
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
h.logger.Error("打开上传文件失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, "文件读取失败")
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
// 读取文件内容
|
||||
imageBytes, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
h.logger.Error("读取文件内容失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, "文件读取失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用OCR服务识别营业执照
|
||||
result, err := h.appService.RecognizeBusinessLicense(c.Request.Context(), imageBytes)
|
||||
if err != nil {
|
||||
h.logger.Error("营业执照OCR识别失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, "营业执照识别失败:"+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("营业执照OCR识别成功",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("company_name", result.CompanyName),
|
||||
zap.Float64("confidence", result.Confidence),
|
||||
)
|
||||
|
||||
h.response.Success(c, result, "营业执照识别成功")
|
||||
}
|
||||
|
||||
// UploadCertificationFile 上传认证相关图片到七牛云(企业信息中的营业执照、办公场地、场景附件、授权代表身份证等)
|
||||
// @Summary 上传认证图片
|
||||
// @Description 上传企业信息中使用的图片到七牛云,返回可访问的 URL
|
||||
// @Tags 认证管理
|
||||
// @Accept multipart/form-data
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param file formData file true "图片文件"
|
||||
// @Success 200 {object} map[string]string "上传成功,返回 url 与 key"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/upload [post]
|
||||
func (h *CertificationHandler) UploadCertificationFile(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
h.logger.Error("获取上传文件失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, "请选择要上传的图片文件")
|
||||
return
|
||||
}
|
||||
|
||||
allowedTypes := map[string]bool{
|
||||
"image/jpeg": true,
|
||||
"image/jpg": true,
|
||||
"image/png": true,
|
||||
"image/webp": true,
|
||||
}
|
||||
contentType := file.Header.Get("Content-Type")
|
||||
if !allowedTypes[contentType] {
|
||||
h.response.BadRequest(c, "只支持 JPG、PNG、WEBP 格式的图片")
|
||||
return
|
||||
}
|
||||
|
||||
if file.Size > 5*1024*1024 {
|
||||
h.response.BadRequest(c, "图片大小不能超过 5MB")
|
||||
return
|
||||
}
|
||||
|
||||
src, err := file.Open()
|
||||
if err != nil {
|
||||
h.logger.Error("打开上传文件失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, "文件读取失败")
|
||||
return
|
||||
}
|
||||
defer src.Close()
|
||||
|
||||
fileBytes, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
h.logger.Error("读取文件内容失败", zap.Error(err), zap.String("user_id", userID))
|
||||
h.response.BadRequest(c, "文件读取失败")
|
||||
return
|
||||
}
|
||||
|
||||
uploadResult, err := h.storageService.UploadFile(c.Request.Context(), fileBytes, file.Filename)
|
||||
if err != nil {
|
||||
h.logger.Error("上传文件到七牛云失败", zap.Error(err), zap.String("user_id", userID), zap.String("file_name", file.Filename))
|
||||
h.response.BadRequest(c, "图片上传失败,请稍后重试")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, map[string]string{
|
||||
"url": uploadResult.URL,
|
||||
"key": uploadResult.Key,
|
||||
}, "上传成功")
|
||||
}
|
||||
|
||||
// ListCertifications 获取认证列表(管理员)
|
||||
// @Summary 获取认证列表
|
||||
// @Description 管理员获取认证申请列表
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param page_size query int false "每页数量" default(10)
|
||||
// @Param sort_by query string false "排序字段"
|
||||
// @Param sort_order query string false "排序方向" Enums(asc, desc)
|
||||
// @Param status query string false "认证状态"
|
||||
// @Param user_id query string false "用户ID"
|
||||
// @Param company_name query string false "公司名称"
|
||||
// @Param legal_person_name query string false "法人姓名"
|
||||
// @Param search_keyword query string false "搜索关键词"
|
||||
// @Success 200 {object} responses.CertificationListResponse "获取认证列表成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications [get]
|
||||
func (h *CertificationHandler) ListCertifications(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var query queries.ListCertificationsQuery
|
||||
if err := h.validator.BindAndValidate(c, &query); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.ListCertifications(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证列表失败", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取认证列表成功")
|
||||
}
|
||||
|
||||
// AdminCompleteCertificationWithoutContract 管理员代用户完成认证(暂不关联合同)
|
||||
// @Summary 管理员代用户完成认证(暂不关联合同)
|
||||
// @Description 后台补充企业信息并直接完成认证,暂时不要求上传合同
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.AdminCompleteCertificationCommand true "管理员代用户完成认证请求"
|
||||
// @Success 200 {object} responses.CertificationResponse "代用户完成认证成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certifications/admin/complete-without-contract [post]
|
||||
func (h *CertificationHandler) AdminCompleteCertificationWithoutContract(c *gin.Context) {
|
||||
adminID := h.getCurrentUserID(c)
|
||||
if adminID == "" {
|
||||
h.response.Unauthorized(c, "用户未登录")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.AdminCompleteCertificationCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
cmd.AdminID = adminID
|
||||
|
||||
result, err := h.appService.AdminCompleteCertificationWithoutContract(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("管理员代用户完成认证失败", zap.Error(err), zap.String("admin_id", adminID), zap.String("user_id", cmd.UserID))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "代用户完成认证成功")
|
||||
}
|
||||
|
||||
// AdminListSubmitRecords 管理端分页查询企业信息提交记录
|
||||
// @Summary 管理端企业审核列表
|
||||
// @Tags 认证管理
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码"
|
||||
// @Param page_size query int false "每页条数"
|
||||
// @Param certification_status query string false "按状态机筛选:info_pending_review/info_submitted/info_rejected,空为全部"
|
||||
// @Success 200 {object} responses.AdminSubmitRecordsListResponse
|
||||
// @Router /api/v1/certifications/admin/submit-records [get]
|
||||
func (h *CertificationHandler) AdminListSubmitRecords(c *gin.Context) {
|
||||
query := &queries.AdminListSubmitRecordsQuery{}
|
||||
if err := c.ShouldBindQuery(query); err != nil {
|
||||
h.response.BadRequest(c, "参数错误")
|
||||
return
|
||||
}
|
||||
result, err := h.appService.AdminListSubmitRecords(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.response.Success(c, result, "获取成功")
|
||||
}
|
||||
|
||||
// AdminGetSubmitRecordByID 管理端获取单条提交记录详情
|
||||
// @Summary 管理端企业审核详情
|
||||
// @Tags 认证管理
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "记录ID"
|
||||
// @Success 200 {object} responses.AdminSubmitRecordDetail
|
||||
// @Router /api/v1/certifications/admin/submit-records/{id} [get]
|
||||
func (h *CertificationHandler) AdminGetSubmitRecordByID(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
h.response.BadRequest(c, "记录ID不能为空")
|
||||
return
|
||||
}
|
||||
result, err := h.appService.AdminGetSubmitRecordByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.response.Success(c, result, "获取成功")
|
||||
}
|
||||
|
||||
// AdminApproveSubmitRecord 管理端审核通过
|
||||
// @Summary 管理端企业审核通过
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "记录ID"
|
||||
// @Param request body object true "可选 remark"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/certifications/admin/submit-records/{id}/approve [post]
|
||||
func (h *CertificationHandler) AdminApproveSubmitRecord(c *gin.Context) {
|
||||
adminID := h.getCurrentUserID(c)
|
||||
if adminID == "" {
|
||||
h.response.Unauthorized(c, "未登录")
|
||||
return
|
||||
}
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
h.response.BadRequest(c, "记录ID不能为空")
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Remark string `json:"remark"`
|
||||
}
|
||||
_ = c.ShouldBindJSON(&body)
|
||||
if err := h.appService.AdminApproveSubmitRecord(c.Request.Context(), id, adminID, body.Remark); err != nil {
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.response.Success(c, nil, "审核通过")
|
||||
}
|
||||
|
||||
// AdminRejectSubmitRecord 管理端审核拒绝
|
||||
// @Summary 管理端企业审核拒绝
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "记录ID"
|
||||
// @Param request body object true "remark 必填"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/certifications/admin/submit-records/{id}/reject [post]
|
||||
func (h *CertificationHandler) AdminRejectSubmitRecord(c *gin.Context) {
|
||||
adminID := h.getCurrentUserID(c)
|
||||
if adminID == "" {
|
||||
h.response.Unauthorized(c, "未登录")
|
||||
return
|
||||
}
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
h.response.BadRequest(c, "记录ID不能为空")
|
||||
return
|
||||
}
|
||||
var body struct {
|
||||
Remark string `json:"remark" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
h.response.BadRequest(c, "请填写拒绝原因(remark)")
|
||||
return
|
||||
}
|
||||
if err := h.appService.AdminRejectSubmitRecord(c.Request.Context(), id, adminID, body.Remark); err != nil {
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.response.Success(c, nil, "已拒绝")
|
||||
}
|
||||
|
||||
// AdminTransitionCertificationStatus 管理端按用户变更认证状态(以状态机为准)
|
||||
// @Summary 管理端变更认证状态
|
||||
// @Tags 认证管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.AdminTransitionCertificationStatusCommand true "user_id, target_status(info_submitted/info_rejected), remark"
|
||||
// @Success 200 {object} map[string]interface{}
|
||||
// @Router /api/v1/certifications/admin/transition-status [post]
|
||||
func (h *CertificationHandler) AdminTransitionCertificationStatus(c *gin.Context) {
|
||||
adminID := h.getCurrentUserID(c)
|
||||
if adminID == "" {
|
||||
h.response.Unauthorized(c, "未登录")
|
||||
return
|
||||
}
|
||||
var cmd commands.AdminTransitionCertificationStatusCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.response.BadRequest(c, "参数错误")
|
||||
return
|
||||
}
|
||||
cmd.AdminID = adminID
|
||||
if err := h.appService.AdminTransitionCertificationStatus(c.Request.Context(), &cmd); err != nil {
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
h.response.Success(c, nil, "状态已更新")
|
||||
}
|
||||
|
||||
// ================ 回调处理 ================
|
||||
|
||||
// HandleEsignCallback 处理e签宝回调
|
||||
// @Summary 处理e签宝回调
|
||||
// @Description 处理e签宝的异步回调通知
|
||||
// @Tags 认证管理
|
||||
// @Accept application/json
|
||||
// @Produce text/plain
|
||||
// @Success 200 {string} string "success"
|
||||
// @Failure 400 {string} string "fail"
|
||||
// @Router /api/v1/certifications/esign/callback [post]
|
||||
func (h *CertificationHandler) HandleEsignCallback(c *gin.Context) {
|
||||
// 记录请求基本信息
|
||||
h.logger.Info("收到e签宝回调请求",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("url", c.Request.URL.String()),
|
||||
zap.String("remote_addr", c.ClientIP()),
|
||||
zap.String("user_agent", c.GetHeader("User-Agent")),
|
||||
)
|
||||
|
||||
// 记录所有请求头
|
||||
headers := make(map[string]string)
|
||||
for key, values := range c.Request.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
h.logger.Info("回调请求头信息", zap.Any("headers", headers))
|
||||
|
||||
// 记录URL查询参数
|
||||
queryParams := make(map[string]string)
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
queryParams[key] = values[0]
|
||||
}
|
||||
}
|
||||
if len(queryParams) > 0 {
|
||||
h.logger.Info("回调URL查询参数", zap.Any("query_params", queryParams))
|
||||
}
|
||||
|
||||
// 读取并记录请求体
|
||||
var callbackData *commands.EsignCallbackData
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err := c.GetRawData()
|
||||
if err != nil {
|
||||
h.logger.Error("读取回调请求体失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "读取请求体失败")
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(bodyBytes, &callbackData); err != nil {
|
||||
h.logger.Error("回调请求体不是有效的JSON格式", zap.Error(err))
|
||||
h.response.BadRequest(c, "请求体格式错误")
|
||||
return
|
||||
}
|
||||
h.logger.Info("回调请求体内容", zap.Any("body", callbackData))
|
||||
|
||||
// 如果后续还需要用 c.Request.Body
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
}
|
||||
|
||||
// 记录Content-Type
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
h.logger.Info("回调请求Content-Type", zap.String("content_type", contentType))
|
||||
|
||||
// 记录Content-Length
|
||||
contentLength := c.GetHeader("Content-Length")
|
||||
if contentLength != "" {
|
||||
h.logger.Info("回调请求Content-Length", zap.String("content_length", contentLength))
|
||||
}
|
||||
|
||||
// 记录时间戳
|
||||
h.logger.Info("回调请求时间",
|
||||
zap.Time("request_time", time.Now()),
|
||||
zap.String("request_id", c.GetHeader("X-Request-ID")),
|
||||
)
|
||||
|
||||
// 记录完整的请求信息摘要
|
||||
h.logger.Info("e签宝回调完整信息摘要",
|
||||
zap.String("method", c.Request.Method),
|
||||
zap.String("url", c.Request.URL.String()),
|
||||
zap.String("client_ip", c.ClientIP()),
|
||||
zap.String("content_type", contentType),
|
||||
zap.Any("headers", headers),
|
||||
zap.Any("query_params", queryParams),
|
||||
zap.Any("body", callbackData),
|
||||
)
|
||||
|
||||
// 处理回调数据
|
||||
if callbackData != nil {
|
||||
// 构建请求头映射
|
||||
headers := make(map[string]string)
|
||||
for key, values := range c.Request.Header {
|
||||
if len(values) > 0 {
|
||||
headers[key] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
// 构建查询参数映射
|
||||
queryParams := make(map[string]string)
|
||||
for key, values := range c.Request.URL.Query() {
|
||||
if len(values) > 0 {
|
||||
queryParams[key] = values[0]
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.appService.HandleEsignCallback(c.Request.Context(), &commands.EsignCallbackCommand{
|
||||
Data: callbackData,
|
||||
Headers: headers,
|
||||
QueryParams: queryParams,
|
||||
}); err != nil {
|
||||
h.logger.Error("处理e签宝回调失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "回调处理失败: "+err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 返回成功响应
|
||||
c.JSON(200, map[string]interface{}{
|
||||
"code": "200",
|
||||
"msg": "success",
|
||||
})
|
||||
}
|
||||
|
||||
// ================ 辅助方法 ================
|
||||
|
||||
// getCurrentUserID 获取当前用户ID
|
||||
func (h *CertificationHandler) getCurrentUserID(c *gin.Context) string {
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,471 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/application/product"
|
||||
"hyapi-server/internal/config"
|
||||
financeRepositories "hyapi-server/internal/domains/finance/repositories"
|
||||
)
|
||||
|
||||
// ComponentReportOrderHandler 组件报告订单处理器
|
||||
type ComponentReportOrderHandler struct {
|
||||
service *product.ComponentReportOrderService
|
||||
purchaseOrderRepo financeRepositories.PurchaseOrderRepository
|
||||
config *config.Config
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewComponentReportOrderHandler 创建组件报告订单处理器
|
||||
func NewComponentReportOrderHandler(
|
||||
service *product.ComponentReportOrderService,
|
||||
purchaseOrderRepo financeRepositories.PurchaseOrderRepository,
|
||||
config *config.Config,
|
||||
logger *zap.Logger,
|
||||
) *ComponentReportOrderHandler {
|
||||
return &ComponentReportOrderHandler{
|
||||
service: service,
|
||||
purchaseOrderRepo: purchaseOrderRepo,
|
||||
config: config,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CheckDownloadAvailability 检查下载可用性
|
||||
// GET /api/v1/products/:id/component-report/check
|
||||
func (h *ComponentReportOrderHandler) CheckDownloadAvailability(c *gin.Context) {
|
||||
h.logger.Info("开始检查下载可用性")
|
||||
|
||||
productID := c.Param("id")
|
||||
h.logger.Info("获取产品ID", zap.String("product_id", productID))
|
||||
|
||||
if productID == "" {
|
||||
h.logger.Error("产品ID不能为空")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "产品ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
h.logger.Info("获取用户ID", zap.String("user_id", userID))
|
||||
|
||||
if userID == "" {
|
||||
h.logger.Error("用户未登录")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "用户未登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 调用服务获取订单信息,检查是否可以下载
|
||||
orderInfo, err := h.service.GetOrderInfo(c.Request.Context(), userID, productID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取订单信息失败", zap.Error(err), zap.String("product_id", productID), zap.String("user_id", userID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "获取订单信息失败",
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("获取订单信息成功", zap.Bool("can_download", orderInfo.CanDownload), zap.Bool("is_package", orderInfo.IsPackage))
|
||||
|
||||
// 返回检查结果
|
||||
message := "需要购买"
|
||||
if orderInfo.CanDownload {
|
||||
message = "可以下载"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"data": gin.H{
|
||||
"can_download": orderInfo.CanDownload,
|
||||
"is_package": orderInfo.IsPackage,
|
||||
"message": message,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// GetDownloadInfo 获取下载信息和价格计算
|
||||
// GET /api/v1/products/:id/component-report/info
|
||||
func (h *ComponentReportOrderHandler) GetDownloadInfo(c *gin.Context) {
|
||||
h.logger.Info("开始获取下载信息和价格计算")
|
||||
|
||||
productID := c.Param("id")
|
||||
h.logger.Info("获取产品ID", zap.String("product_id", productID))
|
||||
|
||||
if productID == "" {
|
||||
h.logger.Error("产品ID不能为空")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "产品ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
h.logger.Info("获取用户ID", zap.String("user_id", userID))
|
||||
|
||||
if userID == "" {
|
||||
h.logger.Error("用户未登录")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "用户未登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
orderInfo, err := h.service.GetOrderInfo(c.Request.Context(), userID, productID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取订单信息失败", zap.Error(err), zap.String("product_id", productID), zap.String("user_id", userID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "获取订单信息失败",
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录详细的订单信息
|
||||
h.logger.Info("获取订单信息成功",
|
||||
zap.String("product_id", orderInfo.ProductID),
|
||||
zap.String("product_code", orderInfo.ProductCode),
|
||||
zap.String("product_name", orderInfo.ProductName),
|
||||
zap.Bool("is_package", orderInfo.IsPackage),
|
||||
zap.Int("sub_products_count", len(orderInfo.SubProducts)),
|
||||
zap.String("price", orderInfo.Price),
|
||||
zap.Strings("purchased_product_codes", orderInfo.PurchasedProductCodes),
|
||||
zap.Bool("can_download", orderInfo.CanDownload),
|
||||
)
|
||||
|
||||
// 记录子产品详情
|
||||
for i, subProduct := range orderInfo.SubProducts {
|
||||
h.logger.Info("子产品信息",
|
||||
zap.Int("index", i),
|
||||
zap.String("sub_product_id", subProduct.ProductID),
|
||||
zap.String("sub_product_code", subProduct.ProductCode),
|
||||
zap.String("sub_product_name", subProduct.ProductName),
|
||||
zap.String("price", subProduct.Price),
|
||||
zap.Bool("is_purchased", subProduct.IsPurchased),
|
||||
)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"data": orderInfo,
|
||||
})
|
||||
}
|
||||
|
||||
// CreatePaymentOrder 创建支付订单
|
||||
// POST /api/v1/products/:id/component-report/create-order
|
||||
func (h *ComponentReportOrderHandler) CreatePaymentOrder(c *gin.Context) {
|
||||
h.logger.Info("开始创建支付订单")
|
||||
|
||||
productID := c.Param("id")
|
||||
h.logger.Info("获取产品ID", zap.String("product_id", productID))
|
||||
|
||||
if productID == "" {
|
||||
h.logger.Error("产品ID不能为空")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "产品ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
h.logger.Info("获取用户ID", zap.String("user_id", userID))
|
||||
|
||||
if userID == "" {
|
||||
h.logger.Error("用户未登录")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "用户未登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req product.CreatePaymentOrderRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
h.logger.Error("请求参数错误", zap.Error(err))
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "请求参数错误",
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录请求参数
|
||||
h.logger.Info("支付订单请求参数",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("product_id", productID),
|
||||
zap.String("payment_type", req.PaymentType),
|
||||
zap.String("platform", req.Platform),
|
||||
zap.Strings("sub_product_codes", req.SubProductCodes),
|
||||
)
|
||||
|
||||
// 设置用户ID和产品ID
|
||||
req.UserID = userID
|
||||
req.ProductID = productID
|
||||
|
||||
// 如果未指定支付平台,根据User-Agent判断
|
||||
if req.Platform == "" {
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
req.Platform = h.detectPlatform(userAgent)
|
||||
h.logger.Info("根据User-Agent检测平台", zap.String("user_agent", userAgent), zap.String("detected_platform", req.Platform))
|
||||
}
|
||||
|
||||
response, err := h.service.CreatePaymentOrder(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
h.logger.Error("创建支付订单失败", zap.Error(err),
|
||||
zap.String("product_id", productID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("payment_type", req.PaymentType))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "创建支付订单失败",
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 记录创建订单成功响应
|
||||
h.logger.Info("创建支付订单成功",
|
||||
zap.String("order_id", response.OrderID),
|
||||
zap.String("order_no", response.OrderNo),
|
||||
zap.String("payment_type", response.PaymentType),
|
||||
zap.String("amount", response.Amount),
|
||||
zap.String("code_url", response.CodeURL),
|
||||
zap.String("pay_url", response.PayURL),
|
||||
)
|
||||
|
||||
// 开发环境下,自动将订单状态设置为已支付
|
||||
if h.config != nil && h.config.App.IsDevelopment() {
|
||||
h.logger.Info("开发环境:自动设置订单为已支付状态",
|
||||
zap.String("order_id", response.OrderID),
|
||||
zap.String("order_no", response.OrderNo))
|
||||
|
||||
// 获取订单信息
|
||||
purchaseOrder, err := h.purchaseOrderRepo.GetByID(c.Request.Context(), response.OrderID)
|
||||
if err != nil {
|
||||
h.logger.Error("开发环境:获取订单信息失败", zap.Error(err), zap.String("order_id", response.OrderID))
|
||||
} else {
|
||||
// 解析金额
|
||||
amount, err := decimal.NewFromString(response.Amount)
|
||||
if err != nil {
|
||||
h.logger.Error("开发环境:解析订单金额失败", zap.Error(err), zap.String("amount", response.Amount))
|
||||
} else {
|
||||
// 标记为已支付(使用开发环境的模拟交易号)
|
||||
tradeNo := "DEV_" + response.OrderNo
|
||||
purchaseOrder.MarkPaid(tradeNo, "", "", amount, amount)
|
||||
|
||||
// 更新订单状态
|
||||
err = h.purchaseOrderRepo.Update(c.Request.Context(), purchaseOrder)
|
||||
if err != nil {
|
||||
h.logger.Error("开发环境:更新订单状态失败", zap.Error(err), zap.String("order_id", response.OrderID))
|
||||
} else {
|
||||
h.logger.Info("开发环境:订单状态已自动设置为已支付",
|
||||
zap.String("order_id", response.OrderID),
|
||||
zap.String("order_no", response.OrderNo),
|
||||
zap.String("trade_no", tradeNo))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"data": response,
|
||||
})
|
||||
}
|
||||
|
||||
// CheckPaymentStatus 检查支付状态
|
||||
// GET /api/v1/component-report/check-payment/:orderId
|
||||
func (h *ComponentReportOrderHandler) CheckPaymentStatus(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "用户未登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
orderID := c.Param("orderId")
|
||||
if orderID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "订单ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.service.CheckPaymentStatus(c.Request.Context(), orderID)
|
||||
if err != nil {
|
||||
h.logger.Error("检查支付状态失败", zap.Error(err), zap.String("order_id", orderID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "检查支付状态失败",
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"data": response,
|
||||
"message": "查询支付状态成功",
|
||||
})
|
||||
}
|
||||
|
||||
// DownloadFile 下载文件
|
||||
// GET /api/v1/component-report/download/:orderId
|
||||
func (h *ComponentReportOrderHandler) DownloadFile(c *gin.Context) {
|
||||
h.logger.Info("开始处理文件下载请求")
|
||||
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.logger.Error("用户未登录")
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "用户未登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("获取用户ID", zap.String("user_id", userID))
|
||||
|
||||
orderID := c.Param("orderId")
|
||||
if orderID == "" {
|
||||
h.logger.Error("订单ID不能为空")
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"code": 400,
|
||||
"message": "订单ID不能为空",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("获取订单ID", zap.String("order_id", orderID))
|
||||
|
||||
filePath, err := h.service.DownloadFile(c.Request.Context(), orderID)
|
||||
if err != nil {
|
||||
h.logger.Error("下载文件失败", zap.Error(err), zap.String("order_id", orderID), zap.String("user_id", userID))
|
||||
|
||||
// 根据错误类型返回不同的状态码和消息
|
||||
errorMessage := err.Error()
|
||||
statusCode := http.StatusInternalServerError
|
||||
|
||||
// 根据错误消息判断具体错误类型
|
||||
if strings.Contains(errorMessage, "购买订单不存在") {
|
||||
statusCode = http.StatusNotFound
|
||||
} else if strings.Contains(errorMessage, "订单未支付") || strings.Contains(errorMessage, "已过期") {
|
||||
statusCode = http.StatusForbidden
|
||||
} else if strings.Contains(errorMessage, "生成报告文件失败") {
|
||||
statusCode = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
c.JSON(statusCode, gin.H{
|
||||
"code": statusCode,
|
||||
"message": "下载文件失败",
|
||||
"error": errorMessage,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("成功获取文件路径",
|
||||
zap.String("order_id", orderID),
|
||||
zap.String("user_id", userID),
|
||||
zap.String("file_path", filePath))
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Type", "application/zip")
|
||||
c.Header("Content-Disposition", "attachment; filename=component_report.zip")
|
||||
|
||||
// 发送文件
|
||||
h.logger.Info("开始发送文件", zap.String("file_path", filePath))
|
||||
c.File(filePath)
|
||||
h.logger.Info("文件发送成功", zap.String("file_path", filePath))
|
||||
}
|
||||
|
||||
// GetUserOrders 获取用户订单列表
|
||||
// GET /api/v1/component-report/orders
|
||||
func (h *ComponentReportOrderHandler) GetUserOrders(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{
|
||||
"code": 401,
|
||||
"message": "用户未登录",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 解析分页参数
|
||||
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
|
||||
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
|
||||
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize < 1 || pageSize > 100 {
|
||||
pageSize = 10
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
orders, total, err := h.service.GetUserOrders(c.Request.Context(), userID, pageSize, offset)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户订单列表失败", zap.Error(err), zap.String("user_id", userID))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": 500,
|
||||
"message": "获取用户订单列表失败",
|
||||
"error": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 200,
|
||||
"data": gin.H{
|
||||
"list": orders,
|
||||
"total": total,
|
||||
"page": page,
|
||||
"page_size": pageSize,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// detectPlatform 根据 User-Agent 检测支付平台类型
|
||||
func (h *ComponentReportOrderHandler) detectPlatform(userAgent string) string {
|
||||
if userAgent == "" {
|
||||
return "h5" // 默认 H5
|
||||
}
|
||||
|
||||
ua := strings.ToLower(userAgent)
|
||||
|
||||
// 检测移动设备
|
||||
if strings.Contains(ua, "mobile") || strings.Contains(ua, "android") ||
|
||||
strings.Contains(ua, "iphone") || strings.Contains(ua, "ipad") {
|
||||
// 检测是否是支付宝或微信内置浏览器
|
||||
if strings.Contains(ua, "alipay") {
|
||||
return "app" // 支付宝 APP
|
||||
}
|
||||
if strings.Contains(ua, "micromessenger") {
|
||||
return "h5" // 微信 H5
|
||||
}
|
||||
return "h5" // 移动端默认 H5
|
||||
}
|
||||
|
||||
// PC 端
|
||||
return "pc"
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"hyapi-server/internal/application/product"
|
||||
"hyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// FileDownloadHandler 文件下载处理器
|
||||
type FileDownloadHandler struct {
|
||||
uiComponentAppService product.UIComponentApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewFileDownloadHandler 创建文件下载处理器
|
||||
func NewFileDownloadHandler(
|
||||
uiComponentAppService product.UIComponentApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
) *FileDownloadHandler {
|
||||
return &FileDownloadHandler{
|
||||
uiComponentAppService: uiComponentAppService,
|
||||
responseBuilder: responseBuilder,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// DownloadUIComponentFile 下载UI组件文件
|
||||
// @Summary 下载UI组件文件
|
||||
// @Description 下载UI组件文件
|
||||
// @Tags 文件下载
|
||||
// @Accept json
|
||||
// @Produce application/octet-stream
|
||||
// @Param id path string true "UI组件ID"
|
||||
// @Success 200 {file} file "文件内容"
|
||||
// @Failure 400 {object} interfaces.Response "请求参数错误"
|
||||
// @Failure 404 {object} interfaces.Response "UI组件不存在或文件不存在"
|
||||
// @Failure 500 {object} interfaces.Response "服务器内部错误"
|
||||
// @Router /api/v1/ui-components/{id}/download [get]
|
||||
func (h *FileDownloadHandler) DownloadUIComponentFile(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
if id == "" {
|
||||
h.responseBuilder.BadRequest(c, "UI组件ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取UI组件信息
|
||||
component, err := h.uiComponentAppService.GetUIComponentByID(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取UI组件失败", zap.Error(err), zap.String("id", id))
|
||||
h.responseBuilder.InternalError(c, "获取UI组件失败")
|
||||
return
|
||||
}
|
||||
|
||||
if component == nil {
|
||||
h.responseBuilder.NotFound(c, "UI组件不存在")
|
||||
return
|
||||
}
|
||||
|
||||
if component.FilePath == nil {
|
||||
h.responseBuilder.NotFound(c, "UI组件文件不存在")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取文件路径
|
||||
filePath, err := h.uiComponentAppService.DownloadUIComponentFile(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
h.logger.Error("获取UI组件文件路径失败", zap.Error(err), zap.String("id", id))
|
||||
h.responseBuilder.InternalError(c, "获取UI组件文件路径失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 设置下载文件名
|
||||
fileName := component.ComponentName
|
||||
if !strings.HasSuffix(strings.ToLower(fileName), ".zip") {
|
||||
fileName += ".zip"
|
||||
}
|
||||
|
||||
// 设置响应头
|
||||
c.Header("Content-Description", "File Transfer")
|
||||
c.Header("Content-Transfer-Encoding", "binary")
|
||||
c.Header("Content-Disposition", "attachment; filename="+fileName)
|
||||
c.Header("Content-Type", "application/octet-stream")
|
||||
|
||||
// 发送文件
|
||||
c.File(filePath)
|
||||
}
|
||||
1297
internal/infrastructure/http/handlers/finance_handler.go
Normal file
1297
internal/infrastructure/http/handlers/finance_handler.go
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user