This commit is contained in:
2026-04-21 22:36:48 +08:00
commit 488c695fdf
748 changed files with 266838 additions and 0 deletions

View File

@@ -0,0 +1,385 @@
package cache
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/redis/go-redis/v9"
"go.uber.org/zap"
"hyapi-server/internal/shared/interfaces"
)
// RedisCache Redis缓存实现
type RedisCache struct {
client *redis.Client
logger *zap.Logger
prefix string
// 统计信息
hits int64
misses int64
}
// NewRedisCache 创建Redis缓存实例
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
return &RedisCache{
client: client,
logger: logger,
prefix: prefix,
}
}
// Name 返回服务名称
func (r *RedisCache) Name() string {
return "redis-cache"
}
// Initialize 初始化服务
func (r *RedisCache) Initialize(ctx context.Context) error {
// 测试连接
_, err := r.client.Ping(ctx).Result()
if err != nil {
r.logger.Error("Failed to connect to Redis", zap.Error(err))
return fmt.Errorf("redis connection failed: %w", err)
}
r.logger.Info("Redis cache service initialized")
return nil
}
// HealthCheck 健康检查
func (r *RedisCache) HealthCheck(ctx context.Context) error {
_, err := r.client.Ping(ctx).Result()
return err
}
// Shutdown 关闭服务
func (r *RedisCache) Shutdown(ctx context.Context) error {
return r.client.Close()
}
// Get 获取缓存值
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
fullKey := r.getFullKey(key)
val, err := r.client.Get(ctx, fullKey).Result()
if err != nil {
if err == redis.Nil {
r.misses++
return fmt.Errorf("cache miss: key %s not found", key)
}
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
return err
}
r.hits++
return json.Unmarshal([]byte(val), dest)
}
// Set 设置缓存值
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
fullKey := r.getFullKey(key)
data, err := json.Marshal(value)
if err != nil {
r.logger.Error("序列化缓存数据失败", zap.String("key", key), zap.Error(err))
return fmt.Errorf("failed to marshal value: %w", err)
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
case string:
expiration, _ = time.ParseDuration(v)
default:
expiration = 24 * time.Hour // 默认24小时
}
} else {
expiration = 24 * time.Hour // 默认24小时
}
err = r.client.Set(ctx, fullKey, data, expiration).Err()
if err != nil {
r.logger.Error("设置缓存失败", zap.String("key", key), zap.Error(err))
return err
}
r.logger.Debug("设置缓存成功", zap.String("key", key), zap.Duration("ttl", expiration))
return nil
}
// Delete 删除缓存
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
if len(keys) == 0 {
return nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
err := r.client.Del(ctx, fullKeys...).Err()
if err != nil {
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
return err
}
return nil
}
// Exists 检查键是否存在
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
fullKey := r.getFullKey(key)
count, err := r.client.Exists(ctx, fullKey).Result()
if err != nil {
return false, err
}
return count > 0, nil
}
// GetMultiple 批量获取
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
if len(keys) == 0 {
return make(map[string]interface{}), nil
}
fullKeys := make([]string, len(keys))
for i, key := range keys {
fullKeys[i] = r.getFullKey(key)
}
values, err := r.client.MGet(ctx, fullKeys...).Result()
if err != nil {
r.logger.Error("批量获取缓存失败", zap.Strings("keys", keys), zap.Error(err))
return nil, err
}
result := make(map[string]interface{})
for i, val := range values {
if val != nil {
var data interface{}
// 修复改进JSON反序列化错误处理
if err := json.Unmarshal([]byte(val.(string)), &data); err != nil {
r.logger.Warn("反序列化缓存数据失败",
zap.String("key", keys[i]),
zap.String("value", val.(string)),
zap.Error(err))
continue
}
result[keys[i]] = data
}
}
return result, nil
}
// SetMultiple 批量设置
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
if len(data) == 0 {
return nil
}
var expiration time.Duration
if len(ttl) > 0 {
switch v := ttl[0].(type) {
case time.Duration:
expiration = v
case int:
expiration = time.Duration(v) * time.Second
default:
expiration = 24 * time.Hour
}
} else {
expiration = 24 * time.Hour
}
pipe := r.client.Pipeline()
for key, value := range data {
fullKey := r.getFullKey(key)
jsonData, err := json.Marshal(value)
if err != nil {
continue
}
pipe.Set(ctx, fullKey, jsonData, expiration)
}
_, err := pipe.Exec(ctx)
return err
}
// DeletePattern 按模式删除
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
// 修复:避免重复添加前缀
var fullPattern string
if strings.HasPrefix(pattern, r.prefix+":") {
fullPattern = pattern
} else {
fullPattern = r.getFullKey(pattern)
}
// 检查上下文是否已取消
if ctx.Err() != nil {
return ctx.Err()
}
var cursor uint64
var totalDeleted int64
maxIterations := 100 // 防止无限循环
iteration := 0
for {
// 检查迭代次数限制
iteration++
if iteration > maxIterations {
r.logger.Warn("缓存删除操作达到最大迭代次数限制",
zap.String("pattern", fullPattern),
zap.Int("max_iterations", maxIterations),
zap.Int64("total_deleted", totalDeleted),
)
break
}
// 检查上下文是否已取消
if ctx.Err() != nil {
r.logger.Warn("缓存删除操作被取消",
zap.String("pattern", fullPattern),
zap.Int64("total_deleted", totalDeleted),
zap.Error(ctx.Err()),
)
return ctx.Err()
}
// 执行SCAN操作
keys, next, err := r.client.Scan(ctx, cursor, fullPattern, 1000).Result()
if err != nil {
// 如果是上下文取消错误,直接返回
if err == context.Canceled || err == context.DeadlineExceeded {
r.logger.Warn("缓存删除操作被取消",
zap.String("pattern", fullPattern),
zap.Int64("total_deleted", totalDeleted),
zap.Error(err),
)
return err
}
r.logger.Error("扫描缓存键失败",
zap.String("pattern", fullPattern),
zap.Error(err))
return err
}
// 批量删除找到的键
if len(keys) > 0 {
// 使用pipeline批量删除提高性能
pipe := r.client.Pipeline()
pipe.Del(ctx, keys...)
cmds, err := pipe.Exec(ctx)
if err != nil {
r.logger.Error("批量删除缓存键失败",
zap.Strings("keys", keys),
zap.Error(err))
return err
}
// 统计删除的键数量
for _, cmd := range cmds {
if delCmd, ok := cmd.(*redis.IntCmd); ok {
if deleted, err := delCmd.Result(); err == nil {
totalDeleted += deleted
}
}
}
r.logger.Debug("批量删除缓存键",
zap.Strings("keys", keys),
zap.Int("batch_size", len(keys)),
zap.Int64("total_deleted", totalDeleted),
)
}
cursor = next
if cursor == 0 {
break
}
}
r.logger.Debug("缓存模式删除完成",
zap.String("pattern", fullPattern),
zap.Int64("total_deleted", totalDeleted),
zap.Int("iterations", iteration),
)
return nil
}
// Keys 获取匹配的键
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
fullPattern := r.getFullKey(pattern)
keys, err := r.client.Keys(ctx, fullPattern).Result()
if err != nil {
return nil, err
}
// 移除前缀
result := make([]string, len(keys))
prefixLen := len(r.prefix) + 1 // +1 for ":"
for i, key := range keys {
if len(key) > prefixLen {
result[i] = key[prefixLen:]
} else {
result[i] = key
}
}
return result, nil
}
// Stats 获取缓存统计
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
dbSize, _ := r.client.DBSize(ctx).Result()
return interfaces.CacheStats{
Hits: r.hits,
Misses: r.misses,
Keys: dbSize,
Memory: 0, // 暂时设为0后续可解析Redis info
Connections: 0, // 暂时设为0后续可解析Redis info
}, nil
}
// getFullKey 获取完整键名
func (r *RedisCache) getFullKey(key string) string {
if r.prefix == "" {
return key
}
return fmt.Sprintf("%s:%s", r.prefix, key)
}
// Flush 清空所有缓存
func (r *RedisCache) Flush(ctx context.Context) error {
if r.prefix == "" {
return r.client.FlushDB(ctx).Err()
}
// 只删除带前缀的键
return r.DeletePattern(ctx, "*")
}
// GetClient 获取原始Redis客户端
func (r *RedisCache) GetClient() *redis.Client {
return r.client
}

View File

@@ -0,0 +1,160 @@
package database
import (
"context"
"fmt"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
)
// Config 数据库配置
type Config struct {
Host string
Port string
User string
Password string
Name string
SSLMode string
Timezone string
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
}
// DB 数据库包装器
type DB struct {
*gorm.DB
config Config
}
// NewConnection 创建新的数据库连接
func NewConnection(config Config) (*DB, error) {
// 构建DSN
dsn := buildDSN(config)
// 配置GORM
gormConfig := &gorm.Config{
Logger: logger.Default.LogMode(logger.Info),
NamingStrategy: schema.NamingStrategy{
SingularTable: true, // 使用单数表名
},
DisableForeignKeyConstraintWhenMigrating: true,
NowFunc: func() time.Time {
return time.Now().In(time.FixedZone("CST", 8*3600)) // 强制使用北京时间
},
PrepareStmt: true,
DisableAutomaticPing: false,
}
// 连接数据库
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
if err != nil {
return nil, fmt.Errorf("连接数据库失败: %w", err)
}
// 获取底层sql.DB
sqlDB, err := db.DB()
if err != nil {
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
}
// 配置连接池
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
// 测试连接
if err := sqlDB.Ping(); err != nil {
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
}
return &DB{
DB: db,
config: config,
}, nil
}
// buildDSN 构建数据库连接字符串
func buildDSN(config Config) string {
return fmt.Sprintf(
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s options='-c timezone=%s'",
config.Host,
config.User,
config.Password,
config.Name,
config.Port,
config.SSLMode,
config.Timezone,
config.Timezone,
)
}
// Close 关闭数据库连接
func (db *DB) Close() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Close()
}
// Ping 检查数据库连接
func (db *DB) Ping() error {
sqlDB, err := db.DB.DB()
if err != nil {
return err
}
return sqlDB.Ping()
}
// GetStats 获取连接池统计信息
func (db *DB) GetStats() (map[string]interface{}, error) {
sqlDB, err := db.DB.DB()
if err != nil {
return nil, err
}
stats := sqlDB.Stats()
return map[string]interface{}{
"max_open_connections": stats.MaxOpenConnections,
"open_connections": stats.OpenConnections,
"in_use": stats.InUse,
"idle": stats.Idle,
"wait_count": stats.WaitCount,
"wait_duration": stats.WaitDuration,
"max_idle_closed": stats.MaxIdleClosed,
"max_idle_time_closed": stats.MaxIdleTimeClosed,
"max_lifetime_closed": stats.MaxLifetimeClosed,
}, nil
}
// BeginTx 开始事务已废弃请使用shared/database.TransactionManager
// @deprecated 请使用 shared/database.TransactionManager
func (db *DB) BeginTx() *gorm.DB {
return db.DB.Begin()
}
// Migrate 执行数据库迁移
func (db *DB) Migrate(models ...interface{}) error {
return db.DB.AutoMigrate(models...)
}
// IsHealthy 检查数据库健康状态
func (db *DB) IsHealthy() bool {
return db.Ping() == nil
}
// WithContext 返回带上下文的数据库实例
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
if c, ok := ctx.(context.Context); ok {
return db.DB.WithContext(c)
}
return db.DB
}
// 注意:事务相关功能已迁移到 shared/database.TransactionManager
// 请使用 TransactionManager 进行事务管理

View File

@@ -0,0 +1,556 @@
package api
import (
"context"
"fmt"
"strings"
"time"
"hyapi-server/internal/domains/api/entities"
"hyapi-server/internal/domains/api/repositories"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ApiCallsTable = "api_calls"
ApiCallCacheTTL = 10 * time.Minute
)
// ApiCallWithProduct 包含产品名称的API调用记录
type ApiCallWithProduct struct {
entities.ApiCall
ProductName string `json:"product_name" gorm:"column:product_name"`
}
type GormApiCallRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ApiCallRepository = (*GormApiCallRepository)(nil)
func NewGormApiCallRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiCallRepository {
return &GormApiCallRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiCallsTable),
}
}
func (r *GormApiCallRepository) Create(ctx context.Context, call *entities.ApiCall) error {
return r.CreateEntity(ctx, call)
}
func (r *GormApiCallRepository) Update(ctx context.Context, call *entities.ApiCall) error {
return r.UpdateEntity(ctx, call)
}
func (r *GormApiCallRepository) FindById(ctx context.Context, id string) (*entities.ApiCall, error) {
var call entities.ApiCall
err := r.SmartGetByID(ctx, id, &call)
if err != nil {
return nil, err
}
return &call, nil
}
func (r *GormApiCallRepository) FindByUserId(ctx context.Context, userId string, limit, offset int) ([]*entities.ApiCall, error) {
var calls []*entities.ApiCall
options := database.CacheListOptions{
Where: "user_id = ?",
Args: []interface{}{userId},
Order: "created_at DESC",
Limit: limit,
Offset: offset,
}
err := r.ListWithCache(ctx, &calls, ApiCallCacheTTL, options)
return calls, err
}
func (r *GormApiCallRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
var calls []*entities.ApiCall
var total int64
// 构建查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 获取总数
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
return calls, total, err
}
func (r *GormApiCallRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.ApiCall, int64, error) {
var calls []*entities.ApiCall
var total int64
// 构建基础查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// TransactionID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品ID筛选
if productId, ok := filters["product_id"].(string); ok && productId != "" {
whereCondition += " AND product_id = ?"
whereArgs = append(whereArgs, productId)
}
// 状态筛选
if status, ok := filters["status"].(string); ok && status != "" {
whereCondition += " AND status = ?"
whereArgs = append(whereArgs, status)
}
}
// 获取总数
count, err := r.CountWhere(ctx, &entities.ApiCall{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.ApiCall{}, &calls, options)
return calls, total, err
}
// ListByUserIdWithFiltersAndProductName 根据用户ID和筛选条件获取API调用记录包含产品名称
func (r *GormApiCallRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
var callsWithProduct []*ApiCallWithProduct
var total int64
// 构建基础查询条件
whereCondition := "ac.user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND ac.created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND ac.created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// TransactionID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND ac.transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereCondition += " AND p.name LIKE ?"
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 状态筛选
if status, ok := filters["status"].(string); ok && status != "" {
whereCondition += " AND ac.status = ?"
whereArgs = append(whereArgs, status)
}
}
// 构建JOIN查询
query := r.GetDB(ctx).Table("api_calls ac").
Select("ac.*, p.name as product_name").
Joins("LEFT JOIN product p ON ac.product_id = p.id").
Where(whereCondition, whereArgs...)
// 获取总数
var count int64
err := query.Count(&count).Error
if err != nil {
return nil, nil, 0, err
}
total = count
// 应用排序和分页
if options.Sort != "" {
query = query.Order("ac." + options.Sort + " " + options.Order)
} else {
query = query.Order("ac.created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
// 执行查询
err = query.Find(&callsWithProduct).Error
if err != nil {
return nil, nil, 0, err
}
// 转换为entities.ApiCall并构建产品名称映射
var calls []*entities.ApiCall
productNameMap := make(map[string]string)
for _, c := range callsWithProduct {
call := c.ApiCall
calls = append(calls, &call)
// 构建产品ID到产品名称的映射
if c.ProductName != "" {
productNameMap[call.ID] = c.ProductName
}
}
return productNameMap, calls, total, nil
}
func (r *GormApiCallRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ?", userId)
}
// CountByUserIdAndProductId 按用户ID和产品ID统计API调用次数
func (r *GormApiCallRepository) CountByUserIdAndProductId(ctx context.Context, userId string, productId string) (int64, error) {
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ? AND product_id = ?", userId, productId)
}
// CountByUserIdAndDateRange 按用户ID和日期范围统计API调用次数
func (r *GormApiCallRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
return r.CountWhere(ctx, &entities.ApiCall{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
}
// GetDailyStatsByUserId 获取用户每日API调用统计
func (r *GormApiCallRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
DATE(created_at) as date,
COUNT(*) as calls
FROM api_calls
WHERE user_id = $1
AND DATE(created_at) >= $2
AND DATE(created_at) <= $3
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetMonthlyStatsByUserId 获取用户每月API调用统计
func (r *GormApiCallRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COUNT(*) as calls
FROM api_calls
WHERE user_id = $1
AND created_at >= $2
AND created_at <= $3
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
func (r *GormApiCallRepository) FindByTransactionId(ctx context.Context, transactionId string) (*entities.ApiCall, error) {
var call entities.ApiCall
err := r.FindOne(ctx, &call, "transaction_id = ?", transactionId)
if err != nil {
return nil, err
}
return &call, nil
}
// ListWithFiltersAndProductName 管理端根据条件筛选所有API调用记录包含产品名称
func (r *GormApiCallRepository) ListWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.ApiCall, int64, error) {
var callsWithProduct []*ApiCallWithProduct
var total int64
// 构建基础查询条件
whereCondition := "1=1"
whereArgs := []interface{}{}
// 应用筛选条件
if filters != nil {
// 用户ID筛选支持单个user_id和多个user_ids
// 如果同时存在优先使用user_ids批量查询
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
// 解析逗号分隔的用户ID列表
userIdsList := strings.Split(userIds, ",")
// 去除空白字符
var cleanUserIds []string
for _, id := range userIdsList {
id = strings.TrimSpace(id)
if id != "" {
cleanUserIds = append(cleanUserIds, id)
}
}
if len(cleanUserIds) > 0 {
placeholders := strings.Repeat("?,", len(cleanUserIds))
placeholders = placeholders[:len(placeholders)-1] // 移除最后一个逗号
whereCondition += " AND ac.user_id IN (" + placeholders + ")"
for _, id := range cleanUserIds {
whereArgs = append(whereArgs, id)
}
}
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
// 单个用户ID筛选
whereCondition += " AND ac.user_id = ?"
whereArgs = append(whereArgs, userId)
}
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND ac.created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND ac.created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// TransactionID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND ac.transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereCondition += " AND p.name LIKE ?"
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 企业名称筛选
if companyName, ok := filters["company_name"].(string); ok && companyName != "" {
whereCondition += " AND ei.company_name LIKE ?"
whereArgs = append(whereArgs, "%"+companyName+"%")
}
// 状态筛选
if status, ok := filters["status"].(string); ok && status != "" {
whereCondition += " AND ac.status = ?"
whereArgs = append(whereArgs, status)
}
}
// 构建JOIN查询
// 需要JOIN product表获取产品名称JOIN users和enterprise_infos表获取企业名称
query := r.GetDB(ctx).Table("api_calls ac").
Select("ac.*, p.name as product_name").
Joins("LEFT JOIN product p ON ac.product_id = p.id").
Joins("LEFT JOIN users u ON ac.user_id = u.id").
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id").
Where(whereCondition, whereArgs...)
// 获取总数
var count int64
err := query.Count(&count).Error
if err != nil {
return nil, nil, 0, err
}
total = count
// 应用排序和分页
if options.Sort != "" {
query = query.Order("ac." + options.Sort + " " + options.Order)
} else {
query = query.Order("ac.created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
// 执行查询
err = query.Find(&callsWithProduct).Error
if err != nil {
return nil, nil, 0, err
}
// 转换为entities.ApiCall并构建产品名称映射
var calls []*entities.ApiCall
productNameMap := make(map[string]string)
for _, c := range callsWithProduct {
call := c.ApiCall
calls = append(calls, &call)
// 构建产品ID到产品名称的映射
if c.ProductName != "" {
productNameMap[call.ID] = c.ProductName
}
}
return productNameMap, calls, total, nil
}
// GetSystemTotalCalls 获取系统总API调用次数
func (r *GormApiCallRepository) GetSystemTotalCalls(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ApiCall{}).Count(&count).Error
return count, err
}
// GetSystemCallsByDateRange 获取系统指定时间范围内的API调用次数
// endDate 应该是结束日期当天的次日00:00:00日统计或下个月1号00:00:00月统计使用 < 而不是 <=
func (r *GormApiCallRepository) GetSystemCallsByDateRange(ctx context.Context, startDate, endDate time.Time) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ApiCall{}).
Where("created_at >= ? AND created_at < ?", startDate, endDate).
Count(&count).Error
return count, err
}
// GetSystemDailyStats 获取系统每日API调用统计
func (r *GormApiCallRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COUNT(*) as calls
FROM api_calls
WHERE DATE(created_at) >= $1
AND DATE(created_at) <= $2
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyStats 获取系统每月API调用统计
func (r *GormApiCallRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COUNT(*) as calls
FROM api_calls
WHERE created_at >= $1
AND created_at < $2
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetApiPopularityRanking 获取API受欢迎程度排行榜
func (r *GormApiCallRepository) GetApiPopularityRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
p.id as product_id,
p.name as api_name,
p.description as api_description,
COUNT(ac.id) as call_count
FROM product p
LEFT JOIN api_calls ac ON p.id = ac.product_id
AND DATE(ac.created_at) = CURRENT_DATE
WHERE p.deleted_at IS NULL
GROUP BY p.id, p.name, p.description
HAVING COUNT(ac.id) > 0
ORDER BY call_count DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
p.id as product_id,
p.name as api_name,
p.description as api_description,
COUNT(ac.id) as call_count
FROM product p
LEFT JOIN api_calls ac ON p.id = ac.product_id
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE p.deleted_at IS NULL
GROUP BY p.id, p.name, p.description
HAVING COUNT(ac.id) > 0
ORDER BY call_count DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
p.id as product_id,
p.name as api_name,
p.description as api_description,
COUNT(ac.id) as call_count
FROM product p
LEFT JOIN api_calls ac ON p.id = ac.product_id
WHERE p.deleted_at IS NULL
GROUP BY p.id, p.name, p.description
HAVING COUNT(ac.id) > 0
ORDER BY call_count DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@@ -0,0 +1,56 @@
package api
import (
"context"
"hyapi-server/internal/domains/api/entities"
"hyapi-server/internal/domains/api/repositories"
"hyapi-server/internal/shared/database"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ApiUsersTable = "api_users"
ApiUserCacheTTL = 30 * time.Minute
)
type GormApiUserRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ApiUserRepository = (*GormApiUserRepository)(nil)
func NewGormApiUserRepository(db *gorm.DB, logger *zap.Logger) repositories.ApiUserRepository {
return &GormApiUserRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ApiUsersTable),
}
}
func (r *GormApiUserRepository) Create(ctx context.Context, user *entities.ApiUser) error {
return r.CreateEntity(ctx, user)
}
func (r *GormApiUserRepository) Update(ctx context.Context, user *entities.ApiUser) error {
return r.UpdateEntity(ctx, user)
}
func (r *GormApiUserRepository) FindByAccessId(ctx context.Context, accessId string) (*entities.ApiUser, error) {
var user entities.ApiUser
err := r.SmartGetByField(ctx, &user, "access_id", accessId, ApiUserCacheTTL)
if err != nil {
return nil, err
}
return &user, nil
}
func (r *GormApiUserRepository) FindByUserId(ctx context.Context, userId string) (*entities.ApiUser, error) {
var user entities.ApiUser
err := r.SmartGetByField(ctx, &user, "user_id", userId, ApiUserCacheTTL)
if err != nil {
return nil, err
}
return &user, nil
}

View File

@@ -0,0 +1,44 @@
package api
import (
"context"
"hyapi-server/internal/domains/api/entities"
"hyapi-server/internal/domains/api/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ReportsTable = "reports"
)
// GormReportRepository 报告记录 GORM 仓储实现
type GormReportRepository struct {
*database.BaseRepositoryImpl
}
var _ repositories.ReportRepository = (*GormReportRepository)(nil)
// NewGormReportRepository 创建报告记录仓储实现
func NewGormReportRepository(db *gorm.DB, logger *zap.Logger) repositories.ReportRepository {
return &GormReportRepository{
BaseRepositoryImpl: database.NewBaseRepositoryImpl(db, logger),
}
}
// Create 创建报告记录
func (r *GormReportRepository) Create(ctx context.Context, report *entities.Report) error {
return r.CreateEntity(ctx, report)
}
// FindByReportID 根据报告编号查询记录
func (r *GormReportRepository) FindByReportID(ctx context.Context, reportID string) (*entities.Report, error) {
var report entities.Report
if err := r.FindOneByField(ctx, &report, "report_id", reportID); err != nil {
return nil, err
}
return &report, nil
}

View File

@@ -0,0 +1,328 @@
package repositories
import (
"context"
"fmt"
"strings"
"time"
"hyapi-server/internal/domains/article/entities"
"hyapi-server/internal/domains/article/repositories"
repoQueries "hyapi-server/internal/domains/article/repositories/queries"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// GormAnnouncementRepository GORM公告仓储实现
type GormAnnouncementRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.AnnouncementRepository = (*GormAnnouncementRepository)(nil)
// NewGormAnnouncementRepository 创建GORM公告仓储
func NewGormAnnouncementRepository(db *gorm.DB, logger *zap.Logger) *GormAnnouncementRepository {
return &GormAnnouncementRepository{
db: db,
logger: logger,
}
}
// Create 创建公告
func (r *GormAnnouncementRepository) Create(ctx context.Context, entity entities.Announcement) (entities.Announcement, error) {
r.logger.Info("创建公告", zap.String("id", entity.ID), zap.String("title", entity.Title))
err := r.db.WithContext(ctx).Create(&entity).Error
if err != nil {
r.logger.Error("创建公告失败", zap.Error(err))
return entity, err
}
return entity, nil
}
// GetByID 根据ID获取公告
func (r *GormAnnouncementRepository) GetByID(ctx context.Context, id string) (entities.Announcement, error) {
var entity entities.Announcement
err := r.db.WithContext(ctx).
Where("id = ?", id).
First(&entity).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return entity, fmt.Errorf("公告不存在")
}
r.logger.Error("获取公告失败", zap.String("id", id), zap.Error(err))
return entity, err
}
return entity, nil
}
// Update 更新公告
func (r *GormAnnouncementRepository) Update(ctx context.Context, entity entities.Announcement) error {
r.logger.Info("更新公告", zap.String("id", entity.ID))
err := r.db.WithContext(ctx).Save(&entity).Error
if err != nil {
r.logger.Error("更新公告失败", zap.String("id", entity.ID), zap.Error(err))
return err
}
return nil
}
// Delete 删除公告
func (r *GormAnnouncementRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除公告", zap.String("id", id))
err := r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id = ?", id).Error
if err != nil {
r.logger.Error("删除公告失败", zap.String("id", id), zap.Error(err))
return err
}
return nil
}
// FindByStatus 根据状态查找公告
func (r *GormAnnouncementRepository) FindByStatus(ctx context.Context, status entities.AnnouncementStatus) ([]*entities.Announcement, error) {
var announcements []entities.Announcement
err := r.db.WithContext(ctx).
Where("status = ?", status).
Order("created_at DESC").
Find(&announcements).Error
if err != nil {
r.logger.Error("根据状态查找公告失败", zap.String("status", string(status)), zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Announcement, len(announcements))
for i := range announcements {
result[i] = &announcements[i]
}
return result, nil
}
// FindScheduled 查找定时发布的公告
func (r *GormAnnouncementRepository) FindScheduled(ctx context.Context) ([]*entities.Announcement, error) {
var announcements []entities.Announcement
now := time.Now()
err := r.db.WithContext(ctx).
Where("status = ? AND scheduled_at IS NOT NULL AND scheduled_at <= ?", entities.AnnouncementStatusDraft, now).
Order("scheduled_at ASC").
Find(&announcements).Error
if err != nil {
r.logger.Error("查找定时发布公告失败", zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Announcement, len(announcements))
for i := range announcements {
result[i] = &announcements[i]
}
return result, nil
}
// ListAnnouncements 获取公告列表
func (r *GormAnnouncementRepository) ListAnnouncements(ctx context.Context, query *repoQueries.ListAnnouncementQuery) ([]*entities.Announcement, int64, error) {
var announcements []entities.Announcement
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
// 应用筛选条件
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.Title != "" {
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
r.logger.Error("获取公告列表总数失败", zap.Error(err))
return nil, 0, err
}
// 应用排序
if query.OrderBy != "" {
orderDir := "DESC"
if query.OrderDir != "" {
orderDir = strings.ToUpper(query.OrderDir)
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 获取数据
if err := dbQuery.Find(&announcements).Error; err != nil {
r.logger.Error("获取公告列表失败", zap.Error(err))
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Announcement, len(announcements))
for i := range announcements {
result[i] = &announcements[i]
}
return result, total, nil
}
// CountByStatus 根据状态统计公告数量
func (r *GormAnnouncementRepository) CountByStatus(ctx context.Context, status entities.AnnouncementStatus) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Announcement{}).
Where("status = ?", status).
Count(&count).Error
if err != nil {
r.logger.Error("统计公告数量失败", zap.String("status", string(status)), zap.Error(err))
return 0, err
}
return count, nil
}
// UpdateStatistics 更新统计信息
// 注意:公告实体目前没有统计字段,此方法预留扩展
func (r *GormAnnouncementRepository) UpdateStatistics(ctx context.Context, announcementID string) error {
r.logger.Info("更新公告统计信息", zap.String("announcement_id", announcementID))
// TODO: 如果将来需要统计字段(如阅读量等),可以在这里实现
return nil
}
// ================ 实现 BaseRepository 接口的其他方法 ================
// Count 统计数量
func (r *GormAnnouncementRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
}
var count int64
err := dbQuery.Count(&count).Error
return count, err
}
// Exists 检查是否存在
func (r *GormAnnouncementRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Announcement{}).
Where("id = ?", id).
Count(&count).Error
return count > 0, err
}
// SoftDelete 软删除
func (r *GormAnnouncementRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id = ?", id).Error
}
// Restore 恢复软删除
func (r *GormAnnouncementRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.Announcement{}).
Where("id = ?", id).
Update("deleted_at", nil).Error
}
// CreateBatch 批量创建
func (r *GormAnnouncementRepository) CreateBatch(ctx context.Context, entities []entities.Announcement) error {
return r.db.WithContext(ctx).Create(&entities).Error
}
// GetByIDs 根据ID列表获取
func (r *GormAnnouncementRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Announcement, error) {
var announcements []entities.Announcement
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&announcements).Error
return announcements, err
}
// UpdateBatch 批量更新
func (r *GormAnnouncementRepository) UpdateBatch(ctx context.Context, entities []entities.Announcement) error {
return r.db.WithContext(ctx).Save(&entities).Error
}
// DeleteBatch 批量删除
func (r *GormAnnouncementRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.Announcement{}, "id IN ?", ids).Error
}
// List 列表查询
func (r *GormAnnouncementRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Announcement, error) {
var announcements []entities.Announcement
dbQuery := r.db.WithContext(ctx).Model(&entities.Announcement{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
}
// 应用排序
if options.Sort != "" {
order := "DESC"
if options.Order != "" {
order = strings.ToUpper(options.Order)
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
}
// 预加载关联数据
if len(options.Include) > 0 {
for _, include := range options.Include {
dbQuery = dbQuery.Preload(include)
}
}
err := dbQuery.Find(&announcements).Error
return announcements, err
}

View File

@@ -0,0 +1,592 @@
package repositories
import (
"context"
"fmt"
"strings"
"hyapi-server/internal/domains/article/entities"
"hyapi-server/internal/domains/article/repositories"
repoQueries "hyapi-server/internal/domains/article/repositories/queries"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// GormArticleRepository GORM文章仓储实现
type GormArticleRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.ArticleRepository = (*GormArticleRepository)(nil)
// NewGormArticleRepository 创建GORM文章仓储
func NewGormArticleRepository(db *gorm.DB, logger *zap.Logger) *GormArticleRepository {
return &GormArticleRepository{
db: db,
logger: logger,
}
}
// Create 创建文章
func (r *GormArticleRepository) Create(ctx context.Context, entity entities.Article) (entities.Article, error) {
r.logger.Info("创建文章", zap.String("id", entity.ID), zap.String("title", entity.Title))
err := r.db.WithContext(ctx).Create(&entity).Error
if err != nil {
r.logger.Error("创建文章失败", zap.Error(err))
return entity, err
}
return entity, nil
}
// GetByID 根据ID获取文章
func (r *GormArticleRepository) GetByID(ctx context.Context, id string) (entities.Article, error) {
var entity entities.Article
err := r.db.WithContext(ctx).
Preload("Category").
Preload("Tags").
Where("id = ?", id).
First(&entity).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return entity, fmt.Errorf("文章不存在")
}
r.logger.Error("获取文章失败", zap.String("id", id), zap.Error(err))
return entity, err
}
return entity, nil
}
// Update 更新文章
func (r *GormArticleRepository) Update(ctx context.Context, entity entities.Article) error {
r.logger.Info("更新文章", zap.String("id", entity.ID))
err := r.db.WithContext(ctx).Save(&entity).Error
if err != nil {
r.logger.Error("更新文章失败", zap.String("id", entity.ID), zap.Error(err))
return err
}
return nil
}
// Delete 删除文章
func (r *GormArticleRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除文章", zap.String("id", id))
err := r.db.WithContext(ctx).Delete(&entities.Article{}, "id = ?", id).Error
if err != nil {
r.logger.Error("删除文章失败", zap.String("id", id), zap.Error(err))
return err
}
return nil
}
// FindByAuthorID 根据作者ID查找文章
func (r *GormArticleRepository) FindByAuthorID(ctx context.Context, authorID string) ([]*entities.Article, error) {
var articles []entities.Article
err := r.db.WithContext(ctx).
Preload("Category").
Preload("Tags").
Where("author_id = ?", authorID).
Order("created_at DESC").
Find(&articles).Error
if err != nil {
r.logger.Error("根据作者ID查找文章失败", zap.String("author_id", authorID), zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Article, len(articles))
for i := range articles {
result[i] = &articles[i]
}
return result, nil
}
// FindByCategoryID 根据分类ID查找文章
func (r *GormArticleRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Article, error) {
var articles []entities.Article
err := r.db.WithContext(ctx).
Preload("Category").
Preload("Tags").
Where("category_id = ?", categoryID).
Order("created_at DESC").
Find(&articles).Error
if err != nil {
r.logger.Error("根据分类ID查找文章失败", zap.String("category_id", categoryID), zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Article, len(articles))
for i := range articles {
result[i] = &articles[i]
}
return result, nil
}
// FindByStatus 根据状态查找文章
func (r *GormArticleRepository) FindByStatus(ctx context.Context, status entities.ArticleStatus) ([]*entities.Article, error) {
var articles []entities.Article
err := r.db.WithContext(ctx).
Preload("Category").
Preload("Tags").
Where("status = ?", status).
Order("created_at DESC").
Find(&articles).Error
if err != nil {
r.logger.Error("根据状态查找文章失败", zap.String("status", string(status)), zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Article, len(articles))
for i := range articles {
result[i] = &articles[i]
}
return result, nil
}
// FindFeatured 查找推荐文章
func (r *GormArticleRepository) FindFeatured(ctx context.Context) ([]*entities.Article, error) {
var articles []entities.Article
err := r.db.WithContext(ctx).
Preload("Category").
Preload("Tags").
Where("is_featured = ? AND status = ?", true, entities.ArticleStatusPublished).
Order("published_at DESC").
Find(&articles).Error
if err != nil {
r.logger.Error("查找推荐文章失败", zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Article, len(articles))
for i := range articles {
result[i] = &articles[i]
}
return result, nil
}
// Search 搜索文章
func (r *GormArticleRepository) Search(ctx context.Context, query *repoQueries.SearchArticleQuery) ([]*entities.Article, int64, error) {
var articles []entities.Article
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
// 应用搜索条件
if query.Keyword != "" {
keyword := "%" + query.Keyword + "%"
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ? OR summary LIKE ?", keyword, keyword, keyword)
}
if query.CategoryID != "" {
// 如果指定了分类ID只查询该分类的文章包括没有分类的文章当CategoryID为空字符串时
if query.CategoryID == "null" || query.CategoryID == "" {
// 查询没有分类的文章
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
} else {
// 查询指定分类的文章
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
}
}
if query.AuthorID != "" {
dbQuery = dbQuery.Where("author_id = ?", query.AuthorID)
}
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
r.logger.Error("获取搜索结果总数失败", zap.Error(err))
return nil, 0, err
}
// 应用排序
if query.OrderBy != "" {
orderDir := "DESC"
if query.OrderDir != "" {
orderDir = strings.ToUpper(query.OrderDir)
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 预加载关联数据
dbQuery = dbQuery.Preload("Category").Preload("Tags")
// 获取数据
if err := dbQuery.Find(&articles).Error; err != nil {
r.logger.Error("搜索文章失败", zap.Error(err))
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Article, len(articles))
for i := range articles {
result[i] = &articles[i]
}
return result, total, nil
}
// ListArticles 获取文章列表(用户端)
func (r *GormArticleRepository) ListArticles(ctx context.Context, query *repoQueries.ListArticleQuery) ([]*entities.Article, int64, error) {
var articles []entities.Article
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
// 用户端不显示归档文章
dbQuery = dbQuery.Where("status != ?", entities.ArticleStatusArchived)
// 应用筛选条件
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.CategoryID != "" {
// 如果指定了分类ID只查询该分类的文章包括没有分类的文章当CategoryID为空字符串时
if query.CategoryID == "null" || query.CategoryID == "" {
// 查询没有分类的文章
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
} else {
// 查询指定分类的文章
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
}
}
if query.TagID != "" {
// 如果指定了标签ID只查询有关联该标签的文章
// 使用子查询而不是JOIN避免影响其他查询条件
subQuery := r.db.WithContext(ctx).Table("article_tag_relations").
Select("article_id").
Where("tag_id = ?", query.TagID)
dbQuery = dbQuery.Where("id IN (?)", subQuery)
}
if query.Title != "" {
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
}
if query.Summary != "" {
dbQuery = dbQuery.Where("summary ILIKE ?", "%"+query.Summary+"%")
}
if query.IsFeatured != nil {
dbQuery = dbQuery.Where("is_featured = ?", *query.IsFeatured)
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
r.logger.Error("获取文章列表总数失败", zap.Error(err))
return nil, 0, err
}
// 应用排序
if query.OrderBy != "" {
orderDir := "DESC"
if query.OrderDir != "" {
orderDir = strings.ToUpper(query.OrderDir)
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 预加载关联数据
dbQuery = dbQuery.Preload("Category").Preload("Tags")
// 获取数据
if err := dbQuery.Find(&articles).Error; err != nil {
r.logger.Error("获取文章列表失败", zap.Error(err))
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Article, len(articles))
for i := range articles {
result[i] = &articles[i]
}
return result, total, nil
}
// ListArticlesForAdmin 获取文章列表(管理员端)
func (r *GormArticleRepository) ListArticlesForAdmin(ctx context.Context, query *repoQueries.ListArticleQuery) ([]*entities.Article, int64, error) {
var articles []entities.Article
var total int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
// 应用筛选条件
if query.Status != "" {
dbQuery = dbQuery.Where("status = ?", query.Status)
}
if query.CategoryID != "" {
// 如果指定了分类ID只查询该分类的文章包括没有分类的文章当CategoryID为空字符串时
if query.CategoryID == "null" || query.CategoryID == "" {
// 查询没有分类的文章
dbQuery = dbQuery.Where("category_id IS NULL OR category_id = ''")
} else {
// 查询指定分类的文章
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
}
}
if query.TagID != "" {
// 如果指定了标签ID只查询有关联该标签的文章
// 使用子查询而不是JOIN避免影响其他查询条件
subQuery := r.db.WithContext(ctx).Table("article_tag_relations").
Select("article_id").
Where("tag_id = ?", query.TagID)
dbQuery = dbQuery.Where("id IN (?)", subQuery)
}
if query.Title != "" {
dbQuery = dbQuery.Where("title ILIKE ?", "%"+query.Title+"%")
}
if query.Summary != "" {
dbQuery = dbQuery.Where("summary ILIKE ?", "%"+query.Summary+"%")
}
if query.IsFeatured != nil {
dbQuery = dbQuery.Where("is_featured = ?", *query.IsFeatured)
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
r.logger.Error("获取文章列表总数失败", zap.Error(err))
return nil, 0, err
}
// 应用排序
if query.OrderBy != "" {
orderDir := "DESC"
if query.OrderDir != "" {
orderDir = strings.ToUpper(query.OrderDir)
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", query.OrderBy, orderDir))
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 预加载关联数据
dbQuery = dbQuery.Preload("Category").Preload("Tags")
// 获取数据
if err := dbQuery.Find(&articles).Error; err != nil {
r.logger.Error("获取文章列表失败", zap.Error(err))
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Article, len(articles))
for i := range articles {
result[i] = &articles[i]
}
return result, total, nil
}
// CountByCategoryID 统计分类文章数量
func (r *GormArticleRepository) CountByCategoryID(ctx context.Context, categoryID string) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Article{}).
Where("category_id = ?", categoryID).
Count(&count).Error
if err != nil {
r.logger.Error("统计分类文章数量失败", zap.String("category_id", categoryID), zap.Error(err))
return 0, err
}
return count, nil
}
// CountByStatus 统计状态文章数量
func (r *GormArticleRepository) CountByStatus(ctx context.Context, status entities.ArticleStatus) (int64, error) {
var count int64
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
if status != "" {
dbQuery = dbQuery.Where("status = ?", status)
}
err := dbQuery.Count(&count).Error
if err != nil {
r.logger.Error("统计状态文章数量失败", zap.String("status", string(status)), zap.Error(err))
return 0, err
}
return count, nil
}
// IncrementViewCount 增加阅读量
func (r *GormArticleRepository) IncrementViewCount(ctx context.Context, articleID string) error {
err := r.db.WithContext(ctx).Model(&entities.Article{}).
Where("id = ?", articleID).
UpdateColumn("view_count", gorm.Expr("view_count + ?", 1)).Error
if err != nil {
r.logger.Error("增加阅读量失败", zap.String("article_id", articleID), zap.Error(err))
return err
}
return nil
}
// 实现 BaseRepository 接口的其他方法
func (r *GormArticleRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
}
var count int64
err := dbQuery.Count(&count).Error
return count, err
}
func (r *GormArticleRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Article{}).
Where("id = ?", id).
Count(&count).Error
return count > 0, err
}
func (r *GormArticleRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.Article{}, "id = ?", id).Error
}
func (r *GormArticleRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.Article{}).
Where("id = ?", id).
Update("deleted_at", nil).Error
}
func (r *GormArticleRepository) CreateBatch(ctx context.Context, entities []entities.Article) error {
return r.db.WithContext(ctx).Create(&entities).Error
}
func (r *GormArticleRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Article, error) {
var articles []entities.Article
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&articles).Error
return articles, err
}
func (r *GormArticleRepository) UpdateBatch(ctx context.Context, entities []entities.Article) error {
return r.db.WithContext(ctx).Save(&entities).Error
}
func (r *GormArticleRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.Article{}, "id IN ?", ids).Error
}
func (r *GormArticleRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Article, error) {
var articles []entities.Article
dbQuery := r.db.WithContext(ctx).Model(&entities.Article{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("title LIKE ? OR content LIKE ?", search, search)
}
// 应用排序
if options.Sort != "" {
order := "DESC"
if options.Order != "" {
order = strings.ToUpper(options.Order)
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
}
// 预加载关联数据
if len(options.Include) > 0 {
for _, include := range options.Include {
dbQuery = dbQuery.Preload(include)
}
}
err := dbQuery.Find(&articles).Error
return articles, err
}

View File

@@ -0,0 +1,247 @@
package repositories
import (
"context"
"fmt"
"hyapi-server/internal/domains/article/entities"
"hyapi-server/internal/domains/article/repositories"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// GormCategoryRepository GORM分类仓储实现
type GormCategoryRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.CategoryRepository = (*GormCategoryRepository)(nil)
// NewGormCategoryRepository 创建GORM分类仓储
func NewGormCategoryRepository(db *gorm.DB, logger *zap.Logger) *GormCategoryRepository {
return &GormCategoryRepository{
db: db,
logger: logger,
}
}
// Create 创建分类
func (r *GormCategoryRepository) Create(ctx context.Context, entity entities.Category) (entities.Category, error) {
r.logger.Info("创建分类", zap.String("id", entity.ID), zap.String("name", entity.Name))
err := r.db.WithContext(ctx).Create(&entity).Error
if err != nil {
r.logger.Error("创建分类失败", zap.Error(err))
return entity, err
}
return entity, nil
}
// GetByID 根据ID获取分类
func (r *GormCategoryRepository) GetByID(ctx context.Context, id string) (entities.Category, error) {
var entity entities.Category
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return entity, fmt.Errorf("分类不存在")
}
r.logger.Error("获取分类失败", zap.String("id", id), zap.Error(err))
return entity, err
}
return entity, nil
}
// Update 更新分类
func (r *GormCategoryRepository) Update(ctx context.Context, entity entities.Category) error {
r.logger.Info("更新分类", zap.String("id", entity.ID))
err := r.db.WithContext(ctx).Save(&entity).Error
if err != nil {
r.logger.Error("更新分类失败", zap.String("id", entity.ID), zap.Error(err))
return err
}
return nil
}
// Delete 删除分类
func (r *GormCategoryRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除分类", zap.String("id", id))
err := r.db.WithContext(ctx).Delete(&entities.Category{}, "id = ?", id).Error
if err != nil {
r.logger.Error("删除分类失败", zap.String("id", id), zap.Error(err))
return err
}
return nil
}
// FindActive 查找启用的分类
func (r *GormCategoryRepository) FindActive(ctx context.Context) ([]*entities.Category, error) {
var categories []entities.Category
err := r.db.WithContext(ctx).
Where("active = ?", true).
Order("sort_order ASC, created_at ASC").
Find(&categories).Error
if err != nil {
r.logger.Error("查找启用分类失败", zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Category, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}
// FindBySortOrder 按排序查找分类
func (r *GormCategoryRepository) FindBySortOrder(ctx context.Context) ([]*entities.Category, error) {
var categories []entities.Category
err := r.db.WithContext(ctx).
Order("sort_order ASC, created_at ASC").
Find(&categories).Error
if err != nil {
r.logger.Error("按排序查找分类失败", zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Category, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}
// CountActive 统计启用分类数量
func (r *GormCategoryRepository) CountActive(ctx context.Context) (int64, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Category{}).
Where("active = ?", true).
Count(&count).Error
if err != nil {
r.logger.Error("统计启用分类数量失败", zap.Error(err))
return 0, err
}
return count, nil
}
// 实现 BaseRepository 接口的其他方法
func (r *GormCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
dbQuery := r.db.WithContext(ctx).Model(&entities.Category{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ?", search, search)
}
var count int64
err := dbQuery.Count(&count).Error
return count, err
}
func (r *GormCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Category{}).
Where("id = ?", id).
Count(&count).Error
return count > 0, err
}
func (r *GormCategoryRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.Category{}, "id = ?", id).Error
}
func (r *GormCategoryRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.Category{}).
Where("id = ?", id).
Update("deleted_at", nil).Error
}
func (r *GormCategoryRepository) CreateBatch(ctx context.Context, entities []entities.Category) error {
return r.db.WithContext(ctx).Create(&entities).Error
}
func (r *GormCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Category, error) {
var categories []entities.Category
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&categories).Error
return categories, err
}
func (r *GormCategoryRepository) UpdateBatch(ctx context.Context, entities []entities.Category) error {
return r.db.WithContext(ctx).Save(&entities).Error
}
func (r *GormCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.Category{}, "id IN ?", ids).Error
}
func (r *GormCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Category, error) {
var categories []entities.Category
dbQuery := r.db.WithContext(ctx).Model(&entities.Category{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ?", search, search)
}
// 应用排序
if options.Sort != "" {
order := "DESC"
if options.Order != "" {
order = options.Order
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
} else {
dbQuery = dbQuery.Order("sort_order ASC, created_at ASC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
}
// 预加载关联数据
if len(options.Include) > 0 {
for _, include := range options.Include {
dbQuery = dbQuery.Preload(include)
}
}
err := dbQuery.Find(&categories).Error
return categories, err
}

View File

@@ -0,0 +1,168 @@
package repositories
import (
"context"
"fmt"
"time"
"hyapi-server/internal/domains/article/entities"
"hyapi-server/internal/domains/article/repositories"
"go.uber.org/zap"
"gorm.io/gorm"
)
// GormScheduledTaskRepository GORM定时任务仓储实现
type GormScheduledTaskRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.ScheduledTaskRepository = (*GormScheduledTaskRepository)(nil)
// NewGormScheduledTaskRepository 创建GORM定时任务仓储
func NewGormScheduledTaskRepository(db *gorm.DB, logger *zap.Logger) *GormScheduledTaskRepository {
return &GormScheduledTaskRepository{
db: db,
logger: logger,
}
}
// Create 创建定时任务记录
func (r *GormScheduledTaskRepository) Create(ctx context.Context, task entities.ScheduledTask) (entities.ScheduledTask, error) {
r.logger.Info("创建定时任务记录", zap.String("task_id", task.TaskID), zap.String("article_id", task.ArticleID))
err := r.db.WithContext(ctx).Create(&task).Error
if err != nil {
r.logger.Error("创建定时任务记录失败", zap.Error(err))
return task, err
}
return task, nil
}
// GetByTaskID 根据Asynq任务ID获取任务记录
func (r *GormScheduledTaskRepository) GetByTaskID(ctx context.Context, taskID string) (entities.ScheduledTask, error) {
var task entities.ScheduledTask
err := r.db.WithContext(ctx).
Preload("Article").
Where("task_id = ?", taskID).
First(&task).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return task, fmt.Errorf("定时任务不存在")
}
r.logger.Error("获取定时任务失败", zap.String("task_id", taskID), zap.Error(err))
return task, err
}
return task, nil
}
// GetByArticleID 根据文章ID获取任务记录
func (r *GormScheduledTaskRepository) GetByArticleID(ctx context.Context, articleID string) (entities.ScheduledTask, error) {
var task entities.ScheduledTask
err := r.db.WithContext(ctx).
Preload("Article").
Where("article_id = ? AND status IN (?)", articleID, []string{"pending", "running"}).
First(&task).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return task, fmt.Errorf("文章没有活动的定时任务")
}
r.logger.Error("获取文章定时任务失败", zap.String("article_id", articleID), zap.Error(err))
return task, err
}
return task, nil
}
// Update 更新任务记录
func (r *GormScheduledTaskRepository) Update(ctx context.Context, task entities.ScheduledTask) error {
r.logger.Info("更新定时任务记录", zap.String("task_id", task.TaskID), zap.String("status", string(task.Status)))
err := r.db.WithContext(ctx).Save(&task).Error
if err != nil {
r.logger.Error("更新定时任务记录失败", zap.String("task_id", task.TaskID), zap.Error(err))
return err
}
return nil
}
// Delete 删除任务记录
func (r *GormScheduledTaskRepository) Delete(ctx context.Context, taskID string) error {
r.logger.Info("删除定时任务记录", zap.String("task_id", taskID))
err := r.db.WithContext(ctx).Where("task_id = ?", taskID).Delete(&entities.ScheduledTask{}).Error
if err != nil {
r.logger.Error("删除定时任务记录失败", zap.String("task_id", taskID), zap.Error(err))
return err
}
return nil
}
// MarkAsCancelled 标记任务为已取消
func (r *GormScheduledTaskRepository) MarkAsCancelled(ctx context.Context, taskID string) error {
r.logger.Info("标记定时任务为已取消", zap.String("task_id", taskID))
result := r.db.WithContext(ctx).
Model(&entities.ScheduledTask{}).
Where("task_id = ? AND status IN (?)", taskID, []string{"pending", "running"}).
Updates(map[string]interface{}{
"status": entities.TaskStatusCancelled,
"completed_at": time.Now(),
})
if result.Error != nil {
r.logger.Error("标记定时任务为已取消失败", zap.String("task_id", taskID), zap.Error(result.Error))
return result.Error
}
if result.RowsAffected == 0 {
r.logger.Warn("没有找到需要取消的定时任务", zap.String("task_id", taskID))
}
return nil
}
// GetActiveTasks 获取活动状态的任务列表
func (r *GormScheduledTaskRepository) GetActiveTasks(ctx context.Context) ([]entities.ScheduledTask, error) {
var tasks []entities.ScheduledTask
err := r.db.WithContext(ctx).
Preload("Article").
Where("status IN (?)", []string{"pending", "running"}).
Order("scheduled_at ASC").
Find(&tasks).Error
if err != nil {
r.logger.Error("获取活动定时任务列表失败", zap.Error(err))
return nil, err
}
return tasks, nil
}
// GetExpiredTasks 获取过期的任务列表
func (r *GormScheduledTaskRepository) GetExpiredTasks(ctx context.Context) ([]entities.ScheduledTask, error) {
var tasks []entities.ScheduledTask
err := r.db.WithContext(ctx).
Preload("Article").
Where("status = ? AND scheduled_at < ?", entities.TaskStatusPending, time.Now()).
Order("scheduled_at ASC").
Find(&tasks).Error
if err != nil {
r.logger.Error("获取过期定时任务列表失败", zap.Error(err))
return nil, err
}
return tasks, nil
}

View File

@@ -0,0 +1,279 @@
package repositories
import (
"context"
"fmt"
"hyapi-server/internal/domains/article/entities"
"hyapi-server/internal/domains/article/repositories"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// GormTagRepository GORM标签仓储实现
type GormTagRepository struct {
db *gorm.DB
logger *zap.Logger
}
// 编译时检查接口实现
var _ repositories.TagRepository = (*GormTagRepository)(nil)
// NewGormTagRepository 创建GORM标签仓储
func NewGormTagRepository(db *gorm.DB, logger *zap.Logger) *GormTagRepository {
return &GormTagRepository{
db: db,
logger: logger,
}
}
// Create 创建标签
func (r *GormTagRepository) Create(ctx context.Context, entity entities.Tag) (entities.Tag, error) {
r.logger.Info("创建标签", zap.String("id", entity.ID), zap.String("name", entity.Name))
err := r.db.WithContext(ctx).Create(&entity).Error
if err != nil {
r.logger.Error("创建标签失败", zap.Error(err))
return entity, err
}
return entity, nil
}
// GetByID 根据ID获取标签
func (r *GormTagRepository) GetByID(ctx context.Context, id string) (entities.Tag, error) {
var entity entities.Tag
err := r.db.WithContext(ctx).Where("id = ?", id).First(&entity).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return entity, fmt.Errorf("标签不存在")
}
r.logger.Error("获取标签失败", zap.String("id", id), zap.Error(err))
return entity, err
}
return entity, nil
}
// Update 更新标签
func (r *GormTagRepository) Update(ctx context.Context, entity entities.Tag) error {
r.logger.Info("更新标签", zap.String("id", entity.ID))
err := r.db.WithContext(ctx).Save(&entity).Error
if err != nil {
r.logger.Error("更新标签失败", zap.String("id", entity.ID), zap.Error(err))
return err
}
return nil
}
// Delete 删除标签
func (r *GormTagRepository) Delete(ctx context.Context, id string) error {
r.logger.Info("删除标签", zap.String("id", id))
err := r.db.WithContext(ctx).Delete(&entities.Tag{}, "id = ?", id).Error
if err != nil {
r.logger.Error("删除标签失败", zap.String("id", id), zap.Error(err))
return err
}
return nil
}
// FindByArticleID 根据文章ID查找标签
func (r *GormTagRepository) FindByArticleID(ctx context.Context, articleID string) ([]*entities.Tag, error) {
var tags []entities.Tag
err := r.db.WithContext(ctx).
Joins("JOIN article_tag_relations ON article_tag_relations.tag_id = tags.id").
Where("article_tag_relations.article_id = ?", articleID).
Find(&tags).Error
if err != nil {
r.logger.Error("根据文章ID查找标签失败", zap.String("article_id", articleID), zap.Error(err))
return nil, err
}
// 转换为指针切片
result := make([]*entities.Tag, len(tags))
for i := range tags {
result[i] = &tags[i]
}
return result, nil
}
// FindByName 根据名称查找标签
func (r *GormTagRepository) FindByName(ctx context.Context, name string) (*entities.Tag, error) {
var tag entities.Tag
err := r.db.WithContext(ctx).Where("name = ?", name).First(&tag).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
r.logger.Error("根据名称查找标签失败", zap.String("name", name), zap.Error(err))
return nil, err
}
return &tag, nil
}
// AddTagToArticle 为文章添加标签
func (r *GormTagRepository) AddTagToArticle(ctx context.Context, articleID string, tagID string) error {
// 检查关联是否已存在
var count int64
err := r.db.WithContext(ctx).Table("article_tag_relations").
Where("article_id = ? AND tag_id = ?", articleID, tagID).
Count(&count).Error
if err != nil {
r.logger.Error("检查标签关联失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
return err
}
if count > 0 {
// 关联已存在,不需要重复添加
return nil
}
// 创建关联
err = r.db.WithContext(ctx).Exec(`
INSERT INTO article_tag_relations (article_id, tag_id)
VALUES (?, ?)
`, articleID, tagID).Error
if err != nil {
r.logger.Error("添加标签到文章失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
return err
}
r.logger.Info("添加标签到文章成功", zap.String("article_id", articleID), zap.String("tag_id", tagID))
return nil
}
// RemoveTagFromArticle 从文章移除标签
func (r *GormTagRepository) RemoveTagFromArticle(ctx context.Context, articleID string, tagID string) error {
err := r.db.WithContext(ctx).Exec(`
DELETE FROM article_tag_relations
WHERE article_id = ? AND tag_id = ?
`, articleID, tagID).Error
if err != nil {
r.logger.Error("从文章移除标签失败", zap.String("article_id", articleID), zap.String("tag_id", tagID), zap.Error(err))
return err
}
r.logger.Info("从文章移除标签成功", zap.String("article_id", articleID), zap.String("tag_id", tagID))
return nil
}
// GetArticleTags 获取文章的所有标签
func (r *GormTagRepository) GetArticleTags(ctx context.Context, articleID string) ([]*entities.Tag, error) {
return r.FindByArticleID(ctx, articleID)
}
// 实现 BaseRepository 接口的其他方法
func (r *GormTagRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
dbQuery := r.db.WithContext(ctx).Model(&entities.Tag{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("name LIKE ?", search)
}
var count int64
err := dbQuery.Count(&count).Error
return count, err
}
func (r *GormTagRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.Tag{}).
Where("id = ?", id).
Count(&count).Error
return count > 0, err
}
func (r *GormTagRepository) SoftDelete(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Delete(&entities.Tag{}, "id = ?", id).Error
}
func (r *GormTagRepository) Restore(ctx context.Context, id string) error {
return r.db.WithContext(ctx).Unscoped().Model(&entities.Tag{}).
Where("id = ?", id).
Update("deleted_at", nil).Error
}
func (r *GormTagRepository) CreateBatch(ctx context.Context, entities []entities.Tag) error {
return r.db.WithContext(ctx).Create(&entities).Error
}
func (r *GormTagRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Tag, error) {
var tags []entities.Tag
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&tags).Error
return tags, err
}
func (r *GormTagRepository) UpdateBatch(ctx context.Context, entities []entities.Tag) error {
return r.db.WithContext(ctx).Save(&entities).Error
}
func (r *GormTagRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.Tag{}, "id IN ?", ids).Error
}
func (r *GormTagRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Tag, error) {
var tags []entities.Tag
dbQuery := r.db.WithContext(ctx).Model(&entities.Tag{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
dbQuery = dbQuery.Where(key+" = ?", value)
}
}
if options.Search != "" {
search := "%" + options.Search + "%"
dbQuery = dbQuery.Where("name LIKE ?", search)
}
// 应用排序
if options.Sort != "" {
order := "DESC"
if options.Order != "" {
order = options.Order
}
dbQuery = dbQuery.Order(fmt.Sprintf("%s %s", options.Sort, order))
} else {
dbQuery = dbQuery.Order("created_at ASC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
dbQuery = dbQuery.Offset(offset).Limit(options.PageSize)
}
// 预加载关联数据
if len(options.Include) > 0 {
for _, include := range options.Include {
dbQuery = dbQuery.Preload(include)
}
}
err := dbQuery.Find(&tags).Error
return tags, err
}

View File

@@ -0,0 +1,370 @@
package certification
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"hyapi-server/internal/domains/certification/entities"
"hyapi-server/internal/domains/certification/enums"
"hyapi-server/internal/domains/certification/repositories"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
)
// ================ 常量定义 ================
const (
// 表名常量
CertificationsTable = "certifications"
// 缓存时间常量
CacheTTLPrimaryQuery = 30 * time.Minute // 主键查询缓存时间
CacheTTLBusinessQuery = 15 * time.Minute // 业务查询缓存时间
CacheTTLUserQuery = 10 * time.Minute // 用户相关查询缓存时间
CacheTTLWarmupLong = 30 * time.Minute // 预热长期缓存
CacheTTLWarmupMedium = 15 * time.Minute // 预热中期缓存
// 缓存键模式常量
CachePatternTable = "gorm_cache:certifications:*"
CachePatternUser = "certification:user_id:*"
)
// ================ Repository 实现 ================
// GormCertificationCommandRepository 认证命令仓储GORM实现
//
// 特性说明:
// - 基于 CachedBaseRepositoryImpl 实现自动缓存管理
// - 支持多级缓存策略主键查询30分钟业务查询15分钟
// - 自动缓存失效:写操作时自动清理相关缓存
// - 智能缓存选择:根据查询复杂度自动选择缓存策略
// - 内置监控支持:提供缓存统计和性能监控
type GormCertificationCommandRepository struct {
*database.CachedBaseRepositoryImpl
}
// 编译时检查接口实现
var _ repositories.CertificationCommandRepository = (*GormCertificationCommandRepository)(nil)
// NewGormCertificationCommandRepository 创建认证命令仓储
//
// 参数:
// - db: GORM数据库连接实例
// - logger: 日志记录器
//
// 返回:
// - repositories.CertificationCommandRepository: 仓储接口实现
func NewGormCertificationCommandRepository(db *gorm.DB, logger *zap.Logger) repositories.CertificationCommandRepository {
return &GormCertificationCommandRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, CertificationsTable),
}
}
// ================ 基础CRUD操作 ================
// Create 创建认证
//
// 业务说明:
// - 创建新的认证申请
// - 自动触发相关缓存失效
//
// 参数:
// - ctx: 上下文
// - cert: 认证实体
//
// 返回:
// - error: 创建失败时的错误信息
func (r *GormCertificationCommandRepository) Create(ctx context.Context, cert entities.Certification) error {
r.GetLogger().Info("创建认证申请",
zap.String("user_id", cert.UserID),
zap.String("status", string(cert.Status)))
return r.CreateEntity(ctx, &cert)
}
// Update 更新认证
//
// 缓存影响:
// - GORM缓存插件会自动失效相关缓存
// - 无需手动管理缓存一致性
//
// 参数:
// - ctx: 上下文
// - cert: 认证实体
//
// 返回:
// - error: 更新失败时的错误信息
func (r *GormCertificationCommandRepository) Update(ctx context.Context, cert entities.Certification) error {
r.GetLogger().Info("更新认证",
zap.String("id", cert.ID),
zap.String("status", string(cert.Status)))
return r.UpdateEntity(ctx, &cert)
}
// Delete 删除认证
//
// 参数:
// - ctx: 上下文
// - id: 认证ID
//
// 返回:
// - error: 删除失败时的错误信息
func (r *GormCertificationCommandRepository) Delete(ctx context.Context, id string) error {
r.GetLogger().Info("删除认证", zap.String("id", id))
return r.DeleteEntity(ctx, id, &entities.Certification{})
}
// ================ 业务特定的更新操作 ================
// UpdateStatus 更新认证状态
//
// 业务说明:
// - 更新认证的状态
// - 自动更新时间戳
//
// 缓存影响:
// - GORM缓存插件会自动失效表相关的缓存
// - 状态更新会影响列表查询和统计结果
//
// 参数:
// - ctx: 上下文
// - id: 认证ID
// - status: 新状态
//
// 返回:
// - error: 更新失败时的错误信息
func (r *GormCertificationCommandRepository) UpdateStatus(ctx context.Context, id string, status enums.CertificationStatus) error {
r.GetLogger().Info("更新认证状态",
zap.String("id", id),
zap.String("status", string(status)))
updates := map[string]interface{}{
"status": status,
"updated_at": time.Now(),
}
return r.GetDB(ctx).Model(&entities.Certification{}).
Where("id = ?", id).
Updates(updates).Error
}
// UpdateAuthFlowID 更新认证流程ID
//
// 业务说明:
// - 记录e签宝企业认证流程ID
// - 用于回调处理和状态跟踪
//
// 参数:
// - ctx: 上下文
// - id: 认证ID
// - authFlowID: 认证流程ID
//
// 返回:
// - error: 更新失败时的错误信息
func (r *GormCertificationCommandRepository) UpdateAuthFlowID(ctx context.Context, id string, authFlowID string) error {
r.GetLogger().Info("更新认证流程ID",
zap.String("id", id),
zap.String("auth_flow_id", authFlowID))
updates := map[string]interface{}{
"auth_flow_id": authFlowID,
"updated_at": time.Now(),
}
return r.GetDB(ctx).Model(&entities.Certification{}).
Where("id = ?", id).
Updates(updates).Error
}
// UpdateContractInfo 更新合同信息
//
// 业务说明:
// - 记录合同相关的ID和URL信息
// - 用于合同管理和用户下载
//
// 参数:
// - ctx: 上下文
// - id: 认证ID
// - contractFileID: 合同文件ID
// - esignFlowID: e签宝流程ID
// - contractURL: 合同URL
// - contractSignURL: 合同签署URL
//
// 返回:
// - error: 更新失败时的错误信息
func (r *GormCertificationCommandRepository) UpdateContractInfo(ctx context.Context, id string, contractFileID, esignFlowID, contractURL, contractSignURL string) error {
r.GetLogger().Info("更新合同信息",
zap.String("id", id),
zap.String("contract_file_id", contractFileID),
zap.String("esign_flow_id", esignFlowID))
updates := map[string]interface{}{
"contract_file_id": contractFileID,
"esign_flow_id": esignFlowID,
"contract_url": contractURL,
"contract_sign_url": contractSignURL,
"updated_at": time.Now(),
}
return r.GetDB(ctx).Model(&entities.Certification{}).
Where("id = ?", id).
Updates(updates).Error
}
// UpdateFailureInfo 更新失败信息
//
// 业务说明:
// - 记录认证失败的原因和详细信息
// - 用于错误分析和用户提示
//
// 参数:
// - ctx: 上下文
// - id: 认证ID
// - reason: 失败原因
// - message: 失败详细信息
//
// 返回:
// - error: 更新失败时的错误信息
func (r *GormCertificationCommandRepository) UpdateFailureInfo(ctx context.Context, id string, reason enums.FailureReason, message string) error {
r.GetLogger().Info("更新失败信息",
zap.String("id", id),
zap.String("reason", string(reason)),
zap.String("message", message))
updates := map[string]interface{}{
"failure_reason": reason,
"failure_message": message,
"updated_at": time.Now(),
}
return r.GetDB(ctx).Model(&entities.Certification{}).
Where("id = ?", id).
Updates(updates).Error
}
// ================ 批量操作 ================
// BatchUpdateStatus 批量更新状态
//
// 业务说明:
// - 批量更新多个认证的状态
// - 适用于管理员批量操作
//
// 参数:
// - ctx: 上下文
// - ids: 认证ID列表
// - status: 新状态
//
// 返回:
// - error: 更新失败时的错误信息
func (r *GormCertificationCommandRepository) BatchUpdateStatus(ctx context.Context, ids []string, status enums.CertificationStatus) error {
if len(ids) == 0 {
return fmt.Errorf("批量更新状态ID列表不能为空")
}
r.GetLogger().Info("批量更新认证状态",
zap.Strings("ids", ids),
zap.String("status", string(status)))
updates := map[string]interface{}{
"status": status,
"updated_at": time.Now(),
}
result := r.GetDB(ctx).Model(&entities.Certification{}).
Where("id IN ?", ids).
Updates(updates)
if result.Error != nil {
return fmt.Errorf("批量更新认证状态失败: %w", result.Error)
}
r.GetLogger().Info("批量更新完成", zap.Int64("affected_rows", result.RowsAffected))
return nil
}
// ================ 事务支持 ================
// WithTx 使用事务
//
// 业务说明:
// - 返回支持事务的仓储实例
// - 用于复杂业务操作的事务一致性保证
//
// 参数:
// - tx: 事务对象
//
// 返回:
// - repositories.CertificationCommandRepository: 支持事务的仓储实例
func (r *GormCertificationCommandRepository) WithTx(tx interfaces.Transaction) repositories.CertificationCommandRepository {
// 获取事务的底层*gorm.DB
txDB := tx.GetDB()
if gormDB, ok := txDB.(*gorm.DB); ok {
return &GormCertificationCommandRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormDB, r.GetLogger(), CertificationsTable),
}
}
r.GetLogger().Warn("不支持的事务类型,返回原始仓储")
return r
}
// ================ 缓存管理方法 ================
// WarmupCache 预热认证缓存
//
// 业务说明:
// - 系统启动时预热常用查询的缓存
// - 提升首次访问的响应速度
//
// 预热策略:
// - 活跃认证30分钟长期缓存
// - 最近创建15分钟中期缓存
func (r *GormCertificationCommandRepository) WarmupCache(ctx context.Context) error {
r.GetLogger().Info("开始预热认证缓存")
queries := []database.WarmupQuery{
{
Name: "active_certifications",
TTL: CacheTTLWarmupLong,
Dest: &[]entities.Certification{},
},
{
Name: "recent_certifications",
TTL: CacheTTLWarmupMedium,
Dest: &[]entities.Certification{},
},
}
return r.WarmupCommonQueries(ctx, queries)
}
// RefreshCache 刷新认证缓存
//
// 业务说明:
// - 手动刷新认证相关的所有缓存
// - 适用于数据迁移或批量更新后的缓存清理
func (r *GormCertificationCommandRepository) RefreshCache(ctx context.Context) error {
r.GetLogger().Info("刷新认证缓存")
return r.CachedBaseRepositoryImpl.RefreshCache(ctx, CachePatternTable)
}
// GetCacheStats 获取缓存统计信息
//
// 返回当前Repository的缓存使用统计包括
// - 基础缓存信息(命中率、键数量等)
// - 特定的缓存模式列表
// - 性能指标
func (r *GormCertificationCommandRepository) GetCacheStats() map[string]interface{} {
stats := r.GetCacheInfo()
stats["specific_patterns"] = []string{
CachePatternTable,
CachePatternUser,
}
return stats
}

View File

@@ -0,0 +1,469 @@
package certification
import (
"context"
"fmt"
"strings"
"time"
"hyapi-server/internal/domains/certification/entities"
"hyapi-server/internal/domains/certification/enums"
"hyapi-server/internal/domains/certification/repositories"
"hyapi-server/internal/domains/certification/repositories/queries"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
// ================ 常量定义 ================
const (
// 缓存时间常量
QueryCacheTTLPrimaryQuery = 30 * time.Minute // 主键查询缓存时间
QueryCacheTTLBusinessQuery = 15 * time.Minute // 业务查询缓存时间
QueryCacheTTLUserQuery = 10 * time.Minute // 用户相关查询缓存时间
QueryCacheTTLSearchQuery = 2 * time.Minute // 搜索查询缓存时间
QueryCacheTTLActiveRecords = 5 * time.Minute // 活跃记录查询缓存时间
QueryCacheTTLWarmupLong = 30 * time.Minute // 预热长期缓存
QueryCacheTTLWarmupMedium = 15 * time.Minute // 预热中期缓存
// 缓存键模式常量
QueryCachePatternTable = "gorm_cache:certifications:*"
QueryCachePatternUser = "certification:user_id:*"
)
// ================ Repository 实现 ================
// GormCertificationQueryRepository 认证查询仓储GORM实现
//
// 特性说明:
// - 基于 CachedBaseRepositoryImpl 实现自动缓存管理
// - 支持多级缓存策略主键查询30分钟业务查询15分钟搜索2分钟
// - 自动缓存失效:写操作时自动清理相关缓存
// - 智能缓存选择:根据查询复杂度自动选择缓存策略
// - 内置监控支持:提供缓存统计和性能监控
type GormCertificationQueryRepository struct {
*database.CachedBaseRepositoryImpl
}
// 编译时检查接口实现
var _ repositories.CertificationQueryRepository = (*GormCertificationQueryRepository)(nil)
// NewGormCertificationQueryRepository 创建认证查询仓储
//
// 参数:
// - db: GORM数据库连接实例
// - logger: 日志记录器
//
// 返回:
// - repositories.CertificationQueryRepository: 仓储接口实现
func NewGormCertificationQueryRepository(
db *gorm.DB,
logger *zap.Logger,
) repositories.CertificationQueryRepository {
return &GormCertificationQueryRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, CertificationsTable),
}
}
// ================ 基础查询操作 ================
// GetByID 根据ID获取认证
//
// 缓存策略:
// - 使用智能主键查询自动缓存30分钟
// - 主键查询命中率高,适合长期缓存
//
// 参数:
// - ctx: 上下文
// - id: 认证ID
//
// 返回:
// - *entities.Certification: 查询到的认证未找到时返回nil
// - error: 查询失败时的错误信息
func (r *GormCertificationQueryRepository) GetByID(ctx context.Context, id string) (*entities.Certification, error) {
var cert entities.Certification
if err := r.SmartGetByID(ctx, id, &cert); err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("认证记录不存在")
}
return nil, fmt.Errorf("查询认证记录失败: %w", err)
}
return &cert, nil
}
// GetByUserID 根据用户ID获取认证
//
// 缓存策略:
// - 业务查询缓存15分钟
// - 用户查询频率较高,适合中期缓存
//
// 参数:
// - ctx: 上下文
// - userID: 用户ID
//
// 返回:
// - *entities.Certification: 查询到的认证未找到时返回nil
// - error: 查询失败时的错误信息
func (r *GormCertificationQueryRepository) GetByUserID(ctx context.Context, userID string) (*entities.Certification, error) {
var cert entities.Certification
err := r.SmartGetByField(ctx, &cert, "user_id", userID, QueryCacheTTLUserQuery)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("用户尚未创建认证申请")
}
return nil, fmt.Errorf("查询用户认证记录失败: %w", err)
}
return &cert, nil
}
// Exists 检查认证是否存在
func (r *GormCertificationQueryRepository) Exists(ctx context.Context, id string) (bool, error) {
return r.ExistsEntity(ctx, id, &entities.Certification{})
}
func (r *GormCertificationQueryRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Certification{}).Where("user_id = ?", userID).Count(&count).Error
if err != nil {
return false, fmt.Errorf("查询用户认证是否存在失败: %w", err)
}
return count > 0, nil
}
// ================ 列表查询 ================
// List 分页列表查询
//
// 缓存策略:
// - 搜索查询短期缓存2分钟避免频繁数据库查询但保证实时性
// - 常规列表:智能缓存(根据查询复杂度自动选择缓存策略)
//
// 参数:
// - ctx: 上下文
// - query: 列表查询条件
//
// 返回:
// - []*entities.Certification: 查询结果列表
// - int64: 总记录数
// - error: 查询失败时的错误信息
func (r *GormCertificationQueryRepository) List(ctx context.Context, query *queries.ListCertificationsQuery) ([]*entities.Certification, int64, error) {
db := r.GetDB(ctx).Model(&entities.Certification{})
// 应用过滤条件
if query.UserID != "" {
db = db.Where("user_id = ?", query.UserID)
}
if query.Status != "" {
db = db.Where("status = ?", query.Status)
}
if len(query.Statuses) > 0 {
db = db.Where("status IN ?", query.Statuses)
}
// 获取总数
var total int64
if err := db.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("查询认证总数失败: %w", err)
}
// 应用排序和分页
if query.SortBy != "" {
orderClause := query.SortBy
if query.SortOrder != "" {
orderClause += " " + strings.ToUpper(query.SortOrder)
}
db = db.Order(orderClause)
} else {
db = db.Order("created_at DESC")
}
offset := (query.Page - 1) * query.PageSize
db = db.Offset(offset).Limit(query.PageSize)
// 执行查询
var certifications []*entities.Certification
if err := db.Find(&certifications).Error; err != nil {
return nil, 0, fmt.Errorf("查询认证列表失败: %w", err)
}
return certifications, total, nil
}
// ListByUserIDs 根据用户ID列表查询
func (r *GormCertificationQueryRepository) ListByUserIDs(ctx context.Context, userIDs []string) ([]*entities.Certification, error) {
if len(userIDs) == 0 {
return []*entities.Certification{}, nil
}
var certifications []*entities.Certification
if err := r.GetDB(ctx).Where("user_id IN ?", userIDs).Order("created_at DESC").Find(&certifications).Error; err != nil {
return nil, fmt.Errorf("根据用户ID列表查询认证失败: %w", err)
}
return certifications, nil
}
// ListByStatus 根据状态查询
func (r *GormCertificationQueryRepository) ListByStatus(ctx context.Context, status enums.CertificationStatus, limit int) ([]*entities.Certification, error) {
db := r.GetDB(ctx).Where("status = ?", status).Order("created_at DESC")
if limit > 0 {
db = db.Limit(limit)
}
var certifications []*entities.Certification
if err := db.Find(&certifications).Error; err != nil {
return nil, fmt.Errorf("根据状态查询认证失败: %w", err)
}
return certifications, nil
}
// ================ 业务查询 ================
// FindByAuthFlowID 根据认证流程ID查询
//
// 缓存策略:
// - 业务查询缓存15分钟
// - 回调查询频率较高
//
// 参数:
// - ctx: 上下文
// - authFlowID: 认证流程ID
//
// 返回:
// - *entities.Certification: 查询到的认证未找到时返回nil
// - error: 查询失败时的错误信息
func (r *GormCertificationQueryRepository) FindByAuthFlowID(ctx context.Context, authFlowID string) (*entities.Certification, error) {
var cert entities.Certification
err := r.SmartGetByField(ctx, &cert, "auth_flow_id", authFlowID, QueryCacheTTLBusinessQuery)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("认证流程不存在")
}
return nil, fmt.Errorf("根据认证流程ID查询失败: %w", err)
}
return &cert, nil
}
// FindByEsignFlowID 根据e签宝流程ID查询
//
// 缓存策略:
// - 业务查询缓存15分钟
// - 回调查询频率较高
//
// 参数:
// - ctx: 上下文
// - esignFlowID: e签宝流程ID
//
// 返回:
// - *entities.Certification: 查询到的认证未找到时返回nil
// - error: 查询失败时的错误信息
func (r *GormCertificationQueryRepository) FindByEsignFlowID(ctx context.Context, esignFlowID string) (*entities.Certification, error) {
var cert entities.Certification
err := r.SmartGetByField(ctx, &cert, "esign_flow_id", esignFlowID, QueryCacheTTLBusinessQuery)
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("e签宝流程不存在")
}
return nil, fmt.Errorf("根据e签宝流程ID查询失败: %w", err)
}
return &cert, nil
}
// ListPendingRetry 查询待重试的认证
//
// 缓存策略:
// - 管理查询,不缓存保证数据实时性
//
// 参数:
// - ctx: 上下文
// - maxRetryCount: 最大重试次数
//
// 返回:
// - []*entities.Certification: 待重试的认证列表
// - error: 查询失败时的错误信息
func (r *GormCertificationQueryRepository) ListPendingRetry(ctx context.Context, maxRetryCount int) ([]*entities.Certification, error) {
var certifications []*entities.Certification
err := r.WithoutCache().GetDB(ctx).
Where("status IN ? AND retry_count < ?",
[]enums.CertificationStatus{
enums.StatusInfoRejected,
enums.StatusContractRejected,
enums.StatusContractExpired,
},
maxRetryCount).
Order("created_at ASC").
Find(&certifications).Error
if err != nil {
return nil, fmt.Errorf("查询待重试认证失败: %w", err)
}
return certifications, nil
}
// GetPendingCertifications 获取待处理认证
func (r *GormCertificationQueryRepository) GetPendingCertifications(ctx context.Context) ([]*entities.Certification, error) {
var certifications []*entities.Certification
err := r.WithoutCache().GetDB(ctx).
Where("status IN ?", []enums.CertificationStatus{
enums.StatusPending,
enums.StatusInfoSubmitted,
}).
Order("created_at ASC").
Find(&certifications).Error
if err != nil {
return nil, fmt.Errorf("查询待处理认证失败: %w", err)
}
return certifications, nil
}
// GetExpiredContracts 获取过期合同
func (r *GormCertificationQueryRepository) GetExpiredContracts(ctx context.Context) ([]*entities.Certification, error) {
var certifications []*entities.Certification
err := r.WithoutCache().GetDB(ctx).
Where("status = ?", enums.StatusContractExpired).
Order("updated_at DESC").
Find(&certifications).Error
if err != nil {
return nil, fmt.Errorf("查询过期合同失败: %w", err)
}
return certifications, nil
}
// GetCertificationsByDateRange 根据日期范围获取认证
func (r *GormCertificationQueryRepository) GetCertificationsByDateRange(ctx context.Context, startDate, endDate time.Time) ([]*entities.Certification, error) {
var certifications []*entities.Certification
err := r.GetDB(ctx).
Where("created_at BETWEEN ? AND ?", startDate, endDate).
Order("created_at DESC").
Find(&certifications).Error
if err != nil {
return nil, fmt.Errorf("根据日期范围查询认证失败: %w", err)
}
return certifications, nil
}
// GetUserActiveCertification 获取用户当前活跃认证
func (r *GormCertificationQueryRepository) GetUserActiveCertification(ctx context.Context, userID string) (*entities.Certification, error) {
var cert entities.Certification
err := r.GetDB(ctx).
Where("user_id = ? AND status NOT IN ?", userID, []enums.CertificationStatus{
enums.StatusContractSigned,
enums.StatusInfoRejected,
enums.StatusContractRejected,
enums.StatusContractExpired,
}).
First(&cert).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("用户没有活跃的认证申请")
}
return nil, fmt.Errorf("查询用户活跃认证失败: %w", err)
}
return &cert, nil
}
// ================ 统计查询 ================
// CountByFailureReason 按失败原因统计
func (r *GormCertificationQueryRepository) CountByFailureReason(ctx context.Context, reason enums.FailureReason) (int64, error) {
var count int64
if err := r.WithShortCache().GetDB(ctx).Model(&entities.Certification{}).Where("failure_reason = ?", reason).Count(&count).Error; err != nil {
return 0, fmt.Errorf("按失败原因统计认证失败: %w", err)
}
return count, nil
}
// GetProgressStatistics 获取进度统计
func (r *GormCertificationQueryRepository) GetProgressStatistics(ctx context.Context) (*repositories.CertificationProgressStats, error) {
// 简化实现
return &repositories.CertificationProgressStats{
StatusProgress: make(map[enums.CertificationStatus]int64),
ProgressDistribution: make(map[int]int64),
StageTimeStats: make(map[string]*repositories.CertificationStageTimeInfo),
}, nil
}
// SearchByCompanyName 按公司名搜索
func (r *GormCertificationQueryRepository) SearchByCompanyName(ctx context.Context, companyName string, limit int) ([]*entities.Certification, error) {
// 简化实现,暂时返回空结果
r.GetLogger().Warn("按公司名搜索功能待实现,需要企业信息服务支持")
return []*entities.Certification{}, nil
}
// SearchByLegalPerson 按法人搜索
func (r *GormCertificationQueryRepository) SearchByLegalPerson(ctx context.Context, legalPersonName string, limit int) ([]*entities.Certification, error) {
// 简化实现,暂时返回空结果
r.GetLogger().Warn("按法人搜索功能待实现,需要企业信息服务支持")
return []*entities.Certification{}, nil
}
// InvalidateCache 清除缓存
func (r *GormCertificationQueryRepository) InvalidateCache(ctx context.Context, keys ...string) error {
// 简化实现,暂不处理缓存
return nil
}
// RefreshCache 刷新缓存
func (r *GormCertificationQueryRepository) RefreshCache(ctx context.Context, certificationID string) error {
// 简化实现,暂不处理缓存
return nil
}
// ================ 缓存管理方法 ================
// WarmupCache 预热认证查询缓存
//
// 业务说明:
// - 系统启动时预热常用查询的缓存
// - 提升首次访问的响应速度
//
// 预热策略:
// - 活跃认证30分钟长期缓存
// - 待处理认证15分钟中期缓存
func (r *GormCertificationQueryRepository) WarmupCache(ctx context.Context) error {
r.GetLogger().Info("开始预热认证查询缓存")
queries := []database.WarmupQuery{
{
Name: "active_certifications",
TTL: QueryCacheTTLWarmupLong,
Dest: &[]entities.Certification{},
},
{
Name: "pending_certifications",
TTL: QueryCacheTTLWarmupMedium,
Dest: &[]entities.Certification{},
},
}
return r.WarmupCommonQueries(ctx, queries)
}
// GetCacheStats 获取缓存统计信息
//
// 返回当前Repository的缓存使用统计包括
// - 基础缓存信息(命中率、键数量等)
// - 特定的缓存模式列表
// - 性能指标
func (r *GormCertificationQueryRepository) GetCacheStats() map[string]interface{} {
stats := r.GetCacheInfo()
stats["specific_patterns"] = []string{
QueryCachePatternTable,
QueryCachePatternUser,
}
return stats
}

View File

@@ -0,0 +1,139 @@
package certification
import (
"context"
"hyapi-server/internal/domains/certification/entities"
"hyapi-server/internal/domains/certification/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
EnterpriseInfoSubmitRecordsTable = "enterprise_info_submit_records"
)
type GormEnterpriseInfoSubmitRecordRepository struct {
*database.CachedBaseRepositoryImpl
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.EnterpriseInfoSubmitRecord{})
}
func NewGormEnterpriseInfoSubmitRecordRepository(db *gorm.DB, logger *zap.Logger) *GormEnterpriseInfoSubmitRecordRepository {
return &GormEnterpriseInfoSubmitRecordRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, EnterpriseInfoSubmitRecordsTable),
}
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Create(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
return r.CreateEntity(ctx, record)
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Update(ctx context.Context, record *entities.EnterpriseInfoSubmitRecord) error {
return r.UpdateEntity(ctx, record)
}
func (r *GormEnterpriseInfoSubmitRecordRepository) Exists(ctx context.Context, ID string) (bool, error) {
return r.ExistsEntity(ctx, ID, &entities.EnterpriseInfoSubmitRecord{})
}
func (r *GormEnterpriseInfoSubmitRecordRepository) FindByID(ctx context.Context, id string) (*entities.EnterpriseInfoSubmitRecord, error) {
var record entities.EnterpriseInfoSubmitRecord
err := r.GetDB(ctx).Where("id = ?", id).First(&record).Error
if err != nil {
return nil, err
}
return &record, nil
}
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
var record entities.EnterpriseInfoSubmitRecord
err := r.GetDB(ctx).
Where("user_id = ?", userID).
Order("submit_at DESC").
First(&record).Error
if err != nil {
return nil, err
}
return &record, nil
}
func (r *GormEnterpriseInfoSubmitRecordRepository) FindLatestVerifiedByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfoSubmitRecord, error) {
var record entities.EnterpriseInfoSubmitRecord
err := r.GetDB(ctx).
Where("user_id = ? AND status = ?", userID, "verified").
Order("verified_at DESC").
First(&record).Error
if err != nil {
return nil, err
}
return &record, nil
}
// ExistsByUnifiedSocialCodeExcludeUser 检查该统一社会信用代码是否已被其他用户占用(已提交或已通过验证的记录)
func (r *GormEnterpriseInfoSubmitRecordRepository) ExistsByUnifiedSocialCodeExcludeUser(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
if unifiedSocialCode == "" {
return false, nil
}
var count int64
query := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{}).
Where("unified_social_code = ? AND status IN (?, ?)", unifiedSocialCode, "submitted", "verified")
if excludeUserID != "" {
query = query.Where("user_id != ?", excludeUserID)
}
if err := query.Count(&count).Error; err != nil {
return false, err
}
return count > 0, nil
}
func (r *GormEnterpriseInfoSubmitRecordRepository) List(ctx context.Context, filter repositories.ListSubmitRecordsFilter) (*repositories.ListSubmitRecordsResult, error) {
base := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{})
if filter.CertificationStatus != "" {
base = base.Joins("JOIN certifications ON certifications.user_id = enterprise_info_submit_records.user_id AND certifications.deleted_at IS NULL").
Where("certifications.status = ?", filter.CertificationStatus)
}
if filter.CompanyName != "" {
base = base.Where("enterprise_info_submit_records.company_name LIKE ?", "%"+filter.CompanyName+"%")
}
if filter.LegalPersonPhone != "" {
base = base.Where("enterprise_info_submit_records.legal_person_phone = ?", filter.LegalPersonPhone)
}
if filter.LegalPersonName != "" {
base = base.Where("enterprise_info_submit_records.legal_person_name LIKE ?", "%"+filter.LegalPersonName+"%")
}
var total int64
if err := base.Count(&total).Error; err != nil {
return nil, err
}
if filter.PageSize <= 0 {
filter.PageSize = 10
}
if filter.Page <= 0 {
filter.Page = 1
}
offset := (filter.Page - 1) * filter.PageSize
var records []*entities.EnterpriseInfoSubmitRecord
q := r.GetDB(ctx).Model(&entities.EnterpriseInfoSubmitRecord{})
if filter.CertificationStatus != "" {
q = q.Joins("JOIN certifications ON certifications.user_id = enterprise_info_submit_records.user_id AND certifications.deleted_at IS NULL").
Where("certifications.status = ?", filter.CertificationStatus)
}
if filter.CompanyName != "" {
q = q.Where("enterprise_info_submit_records.company_name LIKE ?", "%"+filter.CompanyName+"%")
}
if filter.LegalPersonPhone != "" {
q = q.Where("enterprise_info_submit_records.legal_person_phone = ?", filter.LegalPersonPhone)
}
if filter.LegalPersonName != "" {
q = q.Where("enterprise_info_submit_records.legal_person_name LIKE ?", "%"+filter.LegalPersonName+"%")
}
err := q.Order("enterprise_info_submit_records.submit_at DESC").Offset(offset).Limit(filter.PageSize).Find(&records).Error
if err != nil {
return nil, err
}
return &repositories.ListSubmitRecordsResult{Records: records, Total: total}, nil
}

View File

@@ -0,0 +1,98 @@
package repositories
import (
"context"
"errors"
"hyapi-server/internal/domains/finance/entities"
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
AlipayOrdersTable = "typay_orders"
)
type GormAlipayOrderRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.AlipayOrderRepository = (*GormAlipayOrderRepository)(nil)
func NewGormAlipayOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.AlipayOrderRepository {
return &GormAlipayOrderRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, AlipayOrdersTable),
}
}
func (r *GormAlipayOrderRepository) Create(ctx context.Context, order entities.AlipayOrder) (entities.AlipayOrder, error) {
err := r.CreateEntity(ctx, &order)
return order, err
}
func (r *GormAlipayOrderRepository) GetByID(ctx context.Context, id string) (entities.AlipayOrder, error) {
var order entities.AlipayOrder
err := r.SmartGetByID(ctx, id, &order)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.AlipayOrder{}, gorm.ErrRecordNotFound
}
return entities.AlipayOrder{}, err
}
return order, nil
}
func (r *GormAlipayOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.AlipayOrder, error) {
var order entities.AlipayOrder
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &order, nil
}
func (r *GormAlipayOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.AlipayOrder, error) {
var order entities.AlipayOrder
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &order, nil
}
func (r *GormAlipayOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.AlipayOrder, error) {
var orders []entities.AlipayOrder
err := r.GetDB(ctx).
Joins("JOIN recharge_records ON typay_orders.recharge_id = recharge_records.id").
Where("recharge_records.user_id = ?", userID).
Order("typay_orders.created_at DESC").
Find(&orders).Error
return orders, err
}
func (r *GormAlipayOrderRepository) Update(ctx context.Context, order entities.AlipayOrder) error {
return r.UpdateEntity(ctx, &order)
}
func (r *GormAlipayOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.AlipayOrderStatus) error {
return r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Update("status", status).Error
}
func (r *GormAlipayOrderRepository) Delete(ctx context.Context, id string) error {
return r.GetDB(ctx).Delete(&entities.AlipayOrder{}, "id = ?", id).Error
}
func (r *GormAlipayOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.AlipayOrder{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}

View File

@@ -0,0 +1,352 @@
package repositories
import (
"context"
"errors"
"time"
"hyapi-server/internal/domains/finance/entities"
"hyapi-server/internal/domains/finance/repositories"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
PurchaseOrdersTable = "ty_purchase_orders"
)
type GormPurchaseOrderRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.PurchaseOrderRepository = (*GormPurchaseOrderRepository)(nil)
func NewGormPurchaseOrderRepository(db *gorm.DB, logger *zap.Logger) repositories.PurchaseOrderRepository {
return &GormPurchaseOrderRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, PurchaseOrdersTable),
}
}
func (r *GormPurchaseOrderRepository) Create(ctx context.Context, order *entities.PurchaseOrder) (*entities.PurchaseOrder, error) {
err := r.CreateEntity(ctx, order)
if err != nil {
return nil, err
}
return order, nil
}
func (r *GormPurchaseOrderRepository) Update(ctx context.Context, order *entities.PurchaseOrder) error {
return r.UpdateEntity(ctx, order)
}
func (r *GormPurchaseOrderRepository) GetByID(ctx context.Context, id string) (*entities.PurchaseOrder, error) {
var order entities.PurchaseOrder
err := r.SmartGetByID(ctx, id, &order)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &order, nil
}
func (r *GormPurchaseOrderRepository) GetByOrderNo(ctx context.Context, orderNo string) (*entities.PurchaseOrder, error) {
var order entities.PurchaseOrder
err := r.GetDB(ctx).Where("order_no = ?", orderNo).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &order, nil
}
func (r *GormPurchaseOrderRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.PurchaseOrder, int64, error) {
var orders []entities.PurchaseOrder
var count int64
db := r.GetDB(ctx).Where("user_id = ?", userID)
// 获取总数
err := db.Model(&entities.PurchaseOrder{}).Count(&count).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
err = db.Order("created_at DESC").
Limit(limit).
Offset(offset).
Find(&orders).Error
if err != nil {
return nil, 0, err
}
result := make([]*entities.PurchaseOrder, len(orders))
for i := range orders {
result[i] = &orders[i]
}
return result, count, nil
}
func (r *GormPurchaseOrderRepository) GetByUserIDAndProductID(ctx context.Context, userID, productID string) (*entities.PurchaseOrder, error) {
var order entities.PurchaseOrder
err := r.GetDB(ctx).
Where("user_id = ? AND product_id = ? AND status = ?", userID, productID, entities.PurchaseOrderStatusPaid).
First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &order, nil
}
func (r *GormPurchaseOrderRepository) GetByPaymentTypeAndTransactionID(ctx context.Context, paymentType, transactionID string) (*entities.PurchaseOrder, error) {
var order entities.PurchaseOrder
err := r.GetDB(ctx).
Where("payment_type = ? AND trade_no = ?", paymentType, transactionID).
First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &order, nil
}
func (r *GormPurchaseOrderRepository) GetByTradeNo(ctx context.Context, tradeNo string) (*entities.PurchaseOrder, error) {
var order entities.PurchaseOrder
err := r.GetDB(ctx).Where("trade_no = ?", tradeNo).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &order, nil
}
func (r *GormPurchaseOrderRepository) UpdatePaymentStatus(ctx context.Context, orderID string, status entities.PurchaseOrderStatus, tradeNo *string, payAmount, receiptAmount *string, paymentTime *time.Time) error {
updates := map[string]interface{}{
"status": status,
}
if tradeNo != nil {
updates["trade_no"] = *tradeNo
}
if payAmount != nil {
updates["pay_amount"] = *payAmount
}
if receiptAmount != nil {
updates["receipt_amount"] = *receiptAmount
}
if paymentTime != nil {
updates["pay_time"] = *paymentTime
updates["notify_time"] = *paymentTime
}
err := r.GetDB(ctx).
Model(&entities.PurchaseOrder{}).
Where("id = ?", orderID).
Updates(updates).Error
return err
}
func (r *GormPurchaseOrderRepository) GetUserPurchasedProductCodes(ctx context.Context, userID string) ([]string, error) {
var orders []entities.PurchaseOrder
err := r.GetDB(ctx).
Select("product_code").
Where("user_id = ? AND status = ?", userID, entities.PurchaseOrderStatusPaid).
Find(&orders).Error
if err != nil {
return nil, err
}
codesMap := make(map[string]bool)
for _, order := range orders {
// 添加主产品编号
if order.ProductCode != "" {
codesMap[order.ProductCode] = true
}
}
codes := make([]string, 0, len(codesMap))
for code := range codesMap {
codes = append(codes, code)
}
return codes, nil
}
func (r *GormPurchaseOrderRepository) GetUserPaidProductIDs(ctx context.Context, userID string) ([]string, error) {
var orders []entities.PurchaseOrder
err := r.GetDB(ctx).
Select("product_id").
Where("user_id = ? AND status = ?", userID, entities.PurchaseOrderStatusPaid).
Find(&orders).Error
if err != nil {
return nil, err
}
idsMap := make(map[string]bool)
for _, order := range orders {
// 添加主产品ID
if order.ProductID != "" {
idsMap[order.ProductID] = true
}
}
ids := make([]string, 0, len(idsMap))
for id := range idsMap {
ids = append(ids, id)
}
return ids, nil
}
func (r *GormPurchaseOrderRepository) HasUserPurchased(ctx context.Context, userID string, productCode string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.PurchaseOrder{}).
Where("user_id = ? AND product_code = ? AND status = ?", userID, productCode, entities.PurchaseOrderStatusPaid).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
func (r *GormPurchaseOrderRepository) GetExpiringOrders(ctx context.Context, before time.Time, limit int) ([]*entities.PurchaseOrder, error) {
// 购买订单实体没有过期时间字段,此方法返回空结果
return []*entities.PurchaseOrder{}, nil
}
func (r *GormPurchaseOrderRepository) GetExpiredOrders(ctx context.Context, limit int) ([]*entities.PurchaseOrder, error) {
// 购买订单实体没有过期时间字段,此方法返回空结果
return []*entities.PurchaseOrder{}, nil
}
func (r *GormPurchaseOrderRepository) GetByStatus(ctx context.Context, status entities.PurchaseOrderStatus, limit, offset int) ([]*entities.PurchaseOrder, int64, error) {
var orders []entities.PurchaseOrder
var count int64
db := r.GetDB(ctx).Where("status = ?", status)
// 获取总数
err := db.Model(&entities.PurchaseOrder{}).Count(&count).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
err = db.Order("created_at DESC").
Limit(limit).
Offset(offset).
Find(&orders).Error
if err != nil {
return nil, 0, err
}
result := make([]*entities.PurchaseOrder, len(orders))
for i := range orders {
result[i] = &orders[i]
}
return result, count, nil
}
func (r *GormPurchaseOrderRepository) GetByFilters(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.PurchaseOrder, error) {
var orders []entities.PurchaseOrder
db := r.GetDB(ctx)
// 应用筛选条件
if filters != nil {
if userID, ok := filters["user_id"]; ok {
db = db.Where("user_id = ?", userID)
}
if status, ok := filters["status"]; ok && status != "" {
db = db.Where("status = ?", status)
}
if paymentType, ok := filters["payment_type"]; ok && paymentType != "" {
db = db.Where("payment_type = ?", paymentType)
}
if payChannel, ok := filters["pay_channel"]; ok && payChannel != "" {
db = db.Where("pay_channel = ?", payChannel)
}
if startTime, ok := filters["start_time"]; ok && startTime != "" {
db = db.Where("created_at >= ?", startTime)
}
if endTime, ok := filters["end_time"]; ok && endTime != "" {
db = db.Where("created_at <= ?", endTime)
}
}
// 应用排序和分页
// 默认按创建时间倒序
db = db.Order("created_at DESC")
// 应用分页
if options.PageSize > 0 {
db = db.Limit(options.PageSize)
}
if options.Page > 0 {
db = db.Offset((options.Page - 1) * options.PageSize)
}
// 执行查询
err := db.Find(&orders).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.PurchaseOrder, len(orders))
for i := range orders {
result[i] = &orders[i]
}
return result, nil
}
func (r *GormPurchaseOrderRepository) CountByFilters(ctx context.Context, filters map[string]interface{}) (int64, error) {
var count int64
db := r.GetDB(ctx).Model(&entities.PurchaseOrder{})
// 应用筛选条件
if filters != nil {
if userID, ok := filters["user_id"]; ok {
db = db.Where("user_id = ?", userID)
}
if status, ok := filters["status"]; ok && status != "" {
db = db.Where("status = ?", status)
}
if paymentType, ok := filters["payment_type"]; ok && paymentType != "" {
db = db.Where("payment_type = ?", paymentType)
}
if payChannel, ok := filters["pay_channel"]; ok && payChannel != "" {
db = db.Where("pay_channel = ?", payChannel)
}
if startTime, ok := filters["start_time"]; ok && startTime != "" {
db = db.Where("created_at >= ?", startTime)
}
if endTime, ok := filters["end_time"]; ok && endTime != "" {
db = db.Where("created_at <= ?", endTime)
}
}
// 执行计数
err := db.Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,509 @@
package repositories
import (
"context"
"errors"
"fmt"
"strings"
"time"
"hyapi-server/internal/domains/finance/entities"
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
RechargeRecordsTable = "recharge_records"
)
type GormRechargeRecordRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.RechargeRecordRepository = (*GormRechargeRecordRepository)(nil)
func NewGormRechargeRecordRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.RechargeRecordRepository {
return &GormRechargeRecordRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, RechargeRecordsTable),
}
}
func (r *GormRechargeRecordRepository) Create(ctx context.Context, record entities.RechargeRecord) (entities.RechargeRecord, error) {
err := r.CreateEntity(ctx, &record)
return record, err
}
func (r *GormRechargeRecordRepository) GetByID(ctx context.Context, id string) (entities.RechargeRecord, error) {
var record entities.RechargeRecord
err := r.SmartGetByID(ctx, id, &record)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.RechargeRecord{}, gorm.ErrRecordNotFound
}
return entities.RechargeRecord{}, err
}
return record, nil
}
func (r *GormRechargeRecordRepository) GetByUserID(ctx context.Context, userID string) ([]entities.RechargeRecord, error) {
var records []entities.RechargeRecord
err := r.GetDB(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&records).Error
return records, err
}
func (r *GormRechargeRecordRepository) GetByAlipayOrderID(ctx context.Context, alipayOrderID string) (*entities.RechargeRecord, error) {
var record entities.RechargeRecord
err := r.GetDB(ctx).Where("alipay_order_id = ?", alipayOrderID).First(&record).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &record, nil
}
func (r *GormRechargeRecordRepository) GetByTransferOrderID(ctx context.Context, transferOrderID string) (*entities.RechargeRecord, error) {
var record entities.RechargeRecord
err := r.GetDB(ctx).Where("transfer_order_id = ?", transferOrderID).First(&record).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &record, nil
}
func (r *GormRechargeRecordRepository) Update(ctx context.Context, record entities.RechargeRecord) error {
return r.UpdateEntity(ctx, &record)
}
func (r *GormRechargeRecordRepository) UpdateStatus(ctx context.Context, id string, status entities.RechargeStatus) error {
return r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Update("status", status).Error
}
func (r *GormRechargeRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
// 检查是否有 company_name 筛选,如果有则需要 JOIN 表
hasCompanyNameFilter := false
if options.Filters != nil {
if companyName, ok := options.Filters["company_name"].(string); ok && companyName != "" {
hasCompanyNameFilter = true
}
}
var query *gorm.DB
if hasCompanyNameFilter {
// 使用 JOIN 查询以支持企业名称筛选
query = r.GetDB(ctx).Table("recharge_records rr").
Joins("LEFT JOIN users u ON rr.user_id = u.id").
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id")
} else {
// 普通查询
query = r.GetDB(ctx).Model(&entities.RechargeRecord{})
}
if options.Filters != nil {
for key, value := range options.Filters {
// 特殊处理时间范围过滤器
if key == "start_time" {
if startTime, ok := value.(time.Time); ok {
if hasCompanyNameFilter {
query = query.Where("rr.created_at >= ?", startTime)
} else {
query = query.Where("created_at >= ?", startTime)
}
}
} else if key == "end_time" {
if endTime, ok := value.(time.Time); ok {
if hasCompanyNameFilter {
query = query.Where("rr.created_at <= ?", endTime)
} else {
query = query.Where("created_at <= ?", endTime)
}
}
} else if key == "company_name" {
// 处理企业名称筛选
if companyName, ok := value.(string); ok && companyName != "" {
query = query.Where("ei.company_name LIKE ?", "%"+companyName+"%")
}
} else if key == "min_amount" {
// 处理最小金额支持string、int、int64类型
if amount, err := r.parseAmount(value); err == nil {
if hasCompanyNameFilter {
query = query.Where("rr.amount >= ?", amount)
} else {
query = query.Where("amount >= ?", amount)
}
}
} else if key == "max_amount" {
// 处理最大金额支持string、int、int64类型
if amount, err := r.parseAmount(value); err == nil {
if hasCompanyNameFilter {
query = query.Where("rr.amount <= ?", amount)
} else {
query = query.Where("amount <= ?", amount)
}
}
} else {
// 其他过滤器使用等值查询
if hasCompanyNameFilter {
query = query.Where("rr."+key+" = ?", value)
} else {
query = query.Where(key+" = ?", value)
}
}
}
}
if options.Search != "" {
if hasCompanyNameFilter {
query = query.Where("rr.user_id LIKE ? OR rr.transfer_order_id LIKE ? OR rr.alipay_order_id LIKE ? OR rr.wechat_order_id LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
} else {
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ? OR wechat_order_id LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
}
return count, query.Count(&count).Error
}
func (r *GormRechargeRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
func (r *GormRechargeRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.RechargeRecord, error) {
var records []entities.RechargeRecord
// 检查是否有 company_name 筛选,如果有则需要 JOIN 表
hasCompanyNameFilter := false
if options.Filters != nil {
if companyName, ok := options.Filters["company_name"].(string); ok && companyName != "" {
hasCompanyNameFilter = true
}
}
var query *gorm.DB
if hasCompanyNameFilter {
// 使用 JOIN 查询以支持企业名称筛选
query = r.GetDB(ctx).Table("recharge_records rr").
Select("rr.*").
Joins("LEFT JOIN users u ON rr.user_id = u.id").
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id")
} else {
// 普通查询
query = r.GetDB(ctx).Model(&entities.RechargeRecord{})
}
if options.Filters != nil {
for key, value := range options.Filters {
// 特殊处理 user_ids 过滤器
if key == "user_ids" {
if userIds, ok := value.(string); ok && userIds != "" {
if hasCompanyNameFilter {
query = query.Where("rr.user_id IN ?", strings.Split(userIds, ","))
} else {
query = query.Where("user_id IN ?", strings.Split(userIds, ","))
}
}
} else if key == "company_name" {
// 处理企业名称筛选
if companyName, ok := value.(string); ok && companyName != "" {
query = query.Where("ei.company_name LIKE ?", "%"+companyName+"%")
}
} else if key == "start_time" {
// 处理开始时间范围
if startTime, ok := value.(time.Time); ok {
if hasCompanyNameFilter {
query = query.Where("rr.created_at >= ?", startTime)
} else {
query = query.Where("created_at >= ?", startTime)
}
}
} else if key == "end_time" {
// 处理结束时间范围
if endTime, ok := value.(time.Time); ok {
if hasCompanyNameFilter {
query = query.Where("rr.created_at <= ?", endTime)
} else {
query = query.Where("created_at <= ?", endTime)
}
}
} else if key == "min_amount" {
// 处理最小金额支持string、int、int64类型
if amount, err := r.parseAmount(value); err == nil {
if hasCompanyNameFilter {
query = query.Where("rr.amount >= ?", amount)
} else {
query = query.Where("amount >= ?", amount)
}
}
} else if key == "max_amount" {
// 处理最大金额支持string、int、int64类型
if amount, err := r.parseAmount(value); err == nil {
if hasCompanyNameFilter {
query = query.Where("rr.amount <= ?", amount)
} else {
query = query.Where("amount <= ?", amount)
}
}
} else {
// 其他过滤器使用等值查询
if hasCompanyNameFilter {
query = query.Where("rr."+key+" = ?", value)
} else {
query = query.Where(key+" = ?", value)
}
}
}
}
if options.Search != "" {
if hasCompanyNameFilter {
query = query.Where("rr.user_id LIKE ? OR rr.transfer_order_id LIKE ? OR rr.alipay_order_id LIKE ? OR rr.wechat_order_id LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
} else {
query = query.Where("user_id LIKE ? OR transfer_order_id LIKE ? OR alipay_order_id LIKE ? OR wechat_order_id LIKE ?",
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
}
}
if options.Sort != "" {
order := "ASC"
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
if hasCompanyNameFilter {
query = query.Order("rr." + options.Sort + " " + order)
} else {
query = query.Order(options.Sort + " " + order)
}
} else {
if hasCompanyNameFilter {
query = query.Order("rr.created_at DESC")
} else {
query = query.Order("created_at DESC")
}
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&records).Error
return records, err
}
func (r *GormRechargeRecordRepository) CreateBatch(ctx context.Context, records []entities.RechargeRecord) error {
return r.GetDB(ctx).Create(&records).Error
}
func (r *GormRechargeRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.RechargeRecord, error) {
var records []entities.RechargeRecord
err := r.GetDB(ctx).Where("id IN ?", ids).Find(&records).Error
return records, err
}
func (r *GormRechargeRecordRepository) UpdateBatch(ctx context.Context, records []entities.RechargeRecord) error {
return r.GetDB(ctx).Save(&records).Error
}
func (r *GormRechargeRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).Delete(&entities.RechargeRecord{}, "id IN ?", ids).Error
}
func (r *GormRechargeRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.RechargeRecord] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormRechargeRecordRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), RechargeRecordsTable),
}
}
return r
}
func (r *GormRechargeRecordRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.RechargeRecord{})
}
func (r *GormRechargeRecordRepository) SoftDelete(ctx context.Context, id string) error {
return r.SoftDeleteEntity(ctx, id, &entities.RechargeRecord{})
}
func (r *GormRechargeRecordRepository) Restore(ctx context.Context, id string) error {
return r.RestoreEntity(ctx, id, &entities.RechargeRecord{})
}
// GetTotalAmountByUserId 获取用户总充值金额(排除赠送)
func (r *GormRechargeRecordRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ? AND status = ? AND recharge_type != ?", userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift).
Scan(&total).Error
return total, err
}
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总充值金额排除赠送
func (r *GormRechargeRecordRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ? AND status = ? AND recharge_type != ? AND created_at >= ? AND created_at < ?", userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).
Scan(&total).Error
return total, err
}
// GetDailyStatsByUserId 获取用户每日充值统计(排除赠送)
func (r *GormRechargeRecordRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE user_id = $1
AND status = $2
AND recharge_type != $3
AND DATE(created_at) >= $4
AND DATE(created_at) <= $5
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetMonthlyStatsByUserId 获取用户每月充值统计(排除赠送)
func (r *GormRechargeRecordRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE user_id = $1
AND status = $2
AND recharge_type != $3
AND created_at >= $4
AND created_at <= $5
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, userId, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemTotalAmount 获取系统总充值金额(排除赠送)
func (r *GormRechargeRecordRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Where("status = ? AND recharge_type != ?", entities.RechargeStatusSuccess, entities.RechargeTypeGift).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemAmountByDateRange 获取系统指定时间范围内的充值金额(排除赠送)
// endDate 应该是结束日期当天的次日00:00:00日统计或下个月1号00:00:00月统计使用 < 而不是 <=
func (r *GormRechargeRecordRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.RechargeRecord{}).
Where("status = ? AND recharge_type != ? AND created_at >= ? AND created_at < ?", entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemDailyStats 获取系统每日充值统计(排除赠送)
// startDate 和 endDate 应该是时间对象endDate 应该是结束日期当天的次日00:00:00使用 < 而不是 <=
func (r *GormRechargeRecordRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE status = ?
AND recharge_type != ?
AND created_at >= ?
AND created_at < ?
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyStats 获取系统每月充值统计(排除赠送)
func (r *GormRechargeRecordRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM recharge_records
WHERE status = ?
AND recharge_type != ?
AND created_at >= ?
AND created_at < ?
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, entities.RechargeStatusSuccess, entities.RechargeTypeGift, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// parseAmount 解析金额值支持string、int、int64类型转换为decimal.Decimal
func (r *GormRechargeRecordRepository) parseAmount(value interface{}) (decimal.Decimal, error) {
switch v := value.(type) {
case string:
if v == "" {
return decimal.Zero, fmt.Errorf("empty string")
}
return decimal.NewFromString(v)
case int:
return decimal.NewFromInt(int64(v)), nil
case int64:
return decimal.NewFromInt(v), nil
case float64:
return decimal.NewFromFloat(v), nil
case decimal.Decimal:
return v, nil
default:
return decimal.Zero, fmt.Errorf("unsupported type: %T", value)
}
}

View File

@@ -0,0 +1,348 @@
package repositories
import (
"context"
"errors"
"fmt"
"time"
"hyapi-server/internal/domains/finance/entities"
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
WalletsTable = "wallets"
)
type GormWalletRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
return &GormWalletRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletsTable),
}
}
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
err := r.CreateEntity(ctx, &wallet)
return wallet, err
}
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
var wallet entities.Wallet
err := r.SmartGetByID(ctx, id, &wallet)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.Wallet{}, gorm.ErrRecordNotFound
}
return entities.Wallet{}, err
}
return wallet, nil
}
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
return r.UpdateEntity(ctx, &wallet)
}
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.Wallet{})
}
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
return r.SoftDeleteEntity(ctx, id, &entities.Wallet{})
}
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
return r.RestoreEntity(ctx, id, &entities.Wallet{})
}
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
return count, query.Count(&count).Error
}
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
return r.GetDB(ctx).Create(&wallets).Error
}
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
var wallets []entities.Wallet
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&wallets).Error
return wallets, err
}
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
return r.GetDB(ctx).Save(&wallets).Error
}
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
}
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
var wallets []entities.Wallet
query := r.GetDB(ctx).Model(&entities.Wallet{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
}
if options.Sort != "" {
order := options.Sort
if options.Order == "desc" {
order += " DESC"
} else {
order += " ASC"
}
query = query.Order(order)
} else {
// 默认按创建时间倒序
query = query.Order("created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return wallets, query.Find(&wallets).Error
}
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormWalletRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletsTable),
}
}
return r
}
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &wallet, nil
}
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
return count > 0, err
}
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
var total decimal.Decimal
err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
return total, err
}
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
return count, err
}
// ================ 接口要求的方法 ================
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
var wallet entities.Wallet
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &wallet, nil
}
// UpdateBalanceWithVersion 乐观锁自动重试最大重试maxRetry次
func (r *GormWalletRepository) UpdateBalanceWithVersion(ctx context.Context, walletID string, amount decimal.Decimal, operation string) (bool, error) {
maxRetry := 10
for i := 0; i < maxRetry; i++ {
// 每次重试都重新获取最新的钱包信息
var wallet entities.Wallet
err := r.GetDB(ctx).Where("id = ?", walletID).First(&wallet).Error
if err != nil {
return false, fmt.Errorf("获取钱包信息失败: %w", err)
}
// 重新计算新余额
var newBalance decimal.Decimal
switch operation {
case "add":
newBalance = wallet.Balance.Add(amount)
case "subtract":
newBalance = wallet.Balance.Sub(amount)
default:
return false, fmt.Errorf("不支持的操作类型: %s", operation)
}
// 乐观锁更新
result := r.GetDB(ctx).Model(&entities.Wallet{}).
Where("id = ? AND version = ?", walletID, wallet.Version).
Updates(map[string]interface{}{
"balance": newBalance.String(),
"version": wallet.Version + 1,
})
if result.Error != nil {
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
}
if result.RowsAffected == 1 {
return true, nil
}
// 乐观锁冲突,继续重试
// 注意这里可以添加日志记录但需要确保logger可用
}
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
}
// UpdateBalanceByUserID 乐观锁更新通过用户ID直接更新使用原生SQL
func (r *GormWalletRepository) UpdateBalanceByUserID(ctx context.Context, userID string, amount decimal.Decimal, operation string) (bool, error) {
maxRetry := 20 // 增加重试次数
baseDelay := 1 // 基础延迟毫秒
for i := 0; i < maxRetry; i++ {
// 每次重试都重新获取最新的钱包信息
var wallet entities.Wallet
err := r.GetDB(ctx).Where("user_id = ?", userID).First(&wallet).Error
if err != nil {
return false, fmt.Errorf("获取钱包信息失败: %w", err)
}
// 重新计算新余额
var newBalance decimal.Decimal
switch operation {
case "add":
newBalance = wallet.Balance.Add(amount)
case "subtract":
newBalance = wallet.Balance.Sub(amount)
default:
return false, fmt.Errorf("不支持的操作类型: %s", operation)
}
// 使用原生SQL进行乐观锁更新
newVersion := wallet.Version + 1
result := r.GetDB(ctx).Exec(`
UPDATE wallets
SET balance = ?, version = ?, updated_at = NOW()
WHERE user_id = ? AND version = ?
`, newBalance.String(), newVersion, userID, wallet.Version)
if result.Error != nil {
return false, fmt.Errorf("更新钱包余额失败: %w", result.Error)
}
if result.RowsAffected == 1 {
return true, nil
}
// 乐观锁冲突,添加指数退避延迟
if i < maxRetry-1 {
delay := baseDelay * (1 << i) // 指数退避: 1ms, 2ms, 4ms, 8ms...
if delay > 50 {
delay = 50 // 最大延迟50ms
}
time.Sleep(time.Duration(delay) * time.Millisecond)
}
}
return false, fmt.Errorf("高并发下余额变动失败,已达到最大重试次数 %d", maxRetry)
}
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
}
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
}
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
return r.GetDB(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
}
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 总钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 激活钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 总余额
var totalBalance decimal.Decimal
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
var stats domain_finance_repo.FinanceStats
// 用户钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
return nil, err
}
// 用户激活钱包数
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
return nil, err
}
// 用户总余额
var totalBalance decimal.Decimal
if err := r.GetDB(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
return nil, err
}
stats.TotalBalance = totalBalance.String()
// 用户今日交易数(这里需要根据实际业务逻辑实现)
stats.TodayTransactions = 0
return &stats, nil
}

View File

@@ -0,0 +1,643 @@
package repositories
import (
"context"
"strings"
"time"
"hyapi-server/internal/domains/finance/entities"
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
// WalletTransactionWithProduct 包含产品名称的钱包交易记录
type WalletTransactionWithProduct struct {
entities.WalletTransaction
ProductName string `json:"product_name" gorm:"column:product_name"`
}
const (
WalletTransactionsTable = "wallet_transactions"
)
type GormWalletTransactionRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.WalletTransactionRepository = (*GormWalletTransactionRepository)(nil)
func NewGormWalletTransactionRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletTransactionRepository {
return &GormWalletTransactionRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WalletTransactionsTable),
}
}
func (r *GormWalletTransactionRepository) Create(ctx context.Context, transaction entities.WalletTransaction) (entities.WalletTransaction, error) {
err := r.CreateEntity(ctx, &transaction)
return transaction, err
}
func (r *GormWalletTransactionRepository) Update(ctx context.Context, transaction entities.WalletTransaction) error {
return r.UpdateEntity(ctx, &transaction)
}
func (r *GormWalletTransactionRepository) GetByID(ctx context.Context, id string) (entities.WalletTransaction, error) {
var transaction entities.WalletTransaction
err := r.SmartGetByID(ctx, id, &transaction)
return transaction, err
}
func (r *GormWalletTransactionRepository) GetByUserID(ctx context.Context, userID string, limit, offset int) ([]*entities.WalletTransaction, error) {
var transactions []*entities.WalletTransaction
options := database.CacheListOptions{
Where: "user_id = ?",
Args: []interface{}{userID},
Order: "created_at DESC",
Limit: limit,
Offset: offset,
}
err := r.ListWithCache(ctx, &transactions, 10*time.Minute, options)
return transactions, err
}
func (r *GormWalletTransactionRepository) GetByApiCallID(ctx context.Context, apiCallID string) (*entities.WalletTransaction, error) {
var transaction entities.WalletTransaction
err := r.FindOne(ctx, &transaction, "api_call_id = ?", apiCallID)
if err != nil {
return nil, err
}
return &transaction, nil
}
func (r *GormWalletTransactionRepository) ListByUserId(ctx context.Context, userId string, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
var transactions []*entities.WalletTransaction
var total int64
// 构建查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 获取总数
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
return transactions, total, err
}
func (r *GormWalletTransactionRepository) ListByUserIdWithFilters(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) ([]*entities.WalletTransaction, int64, error) {
var transactions []*entities.WalletTransaction
var total int64
// 构建基础查询条件
whereCondition := "user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// 关键词筛选支持transaction_id和product_name
if keyword, ok := filters["keyword"].(string); ok && keyword != "" {
whereCondition += " AND (transaction_id LIKE ? OR product_id IN (SELECT id FROM product WHERE name LIKE ?))"
whereArgs = append(whereArgs, "%"+keyword+"%", "%"+keyword+"%")
}
// API调用ID筛选
if apiCallId, ok := filters["api_call_id"].(string); ok && apiCallId != "" {
whereCondition += " AND api_call_id LIKE ?"
whereArgs = append(whereArgs, "%"+apiCallId+"%")
}
// 金额范围筛选
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
whereCondition += " AND amount >= ?"
whereArgs = append(whereArgs, minAmount)
}
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
whereCondition += " AND amount <= ?"
whereArgs = append(whereArgs, maxAmount)
}
}
// 获取总数
count, err := r.CountWhere(ctx, &entities.WalletTransaction{}, whereCondition, whereArgs...)
if err != nil {
return nil, 0, err
}
total = count
// 使用基础仓储的分页查询方法
err = r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
return transactions, total, err
}
func (r *GormWalletTransactionRepository) CountByUserId(ctx context.Context, userId string) (int64, error) {
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ?", userId)
}
// CountByUserIdAndDateRange 按用户ID和日期范围统计钱包交易次数
func (r *GormWalletTransactionRepository) CountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (int64, error) {
return r.CountWhere(ctx, &entities.WalletTransaction{}, "user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate)
}
// GetTotalAmountByUserId 获取用户总消费金额
func (r *GormWalletTransactionRepository) GetTotalAmountByUserId(ctx context.Context, userId string) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ?", userId).
Scan(&total).Error
return total, err
}
// GetTotalAmountByUserIdAndDateRange 按用户ID和日期范围获取总消费金额
func (r *GormWalletTransactionRepository) GetTotalAmountByUserIdAndDateRange(ctx context.Context, userId string, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Select("COALESCE(SUM(amount), 0)").
Where("user_id = ? AND created_at >= ? AND created_at < ?", userId, startDate, endDate).
Scan(&total).Error
return total, err
}
// GetDailyStatsByUserId 获取用户每日消费统计
func (r *GormWalletTransactionRepository) GetDailyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE user_id = $1
AND DATE(created_at) >= $2
AND DATE(created_at) <= $3
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetMonthlyStatsByUserId 获取用户每月消费统计
func (r *GormWalletTransactionRepository) GetMonthlyStatsByUserId(ctx context.Context, userId string, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
// 构建SQL查询 - 使用PostgreSQL语法使用具体的日期范围
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE user_id = $1
AND created_at >= $2
AND created_at <= $3
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, userId, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// 实现interfaces.Repository接口的其他方法
func (r *GormWalletTransactionRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) Exists(ctx context.Context, id string) (bool, error) {
return r.ExistsEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.WalletTransaction, error) {
var transactions []entities.WalletTransaction
err := r.ListWithOptions(ctx, &entities.WalletTransaction{}, &transactions, options)
return transactions, err
}
func (r *GormWalletTransactionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
return r.CountWithOptions(ctx, &entities.WalletTransaction{}, options)
}
func (r *GormWalletTransactionRepository) CreateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
return r.CreateBatchEntity(ctx, &transactions)
}
func (r *GormWalletTransactionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.WalletTransaction, error) {
var transactions []entities.WalletTransaction
err := r.GetEntitiesByIDs(ctx, ids, &transactions)
return transactions, err
}
func (r *GormWalletTransactionRepository) UpdateBatch(ctx context.Context, transactions []entities.WalletTransaction) error {
return r.UpdateBatchEntity(ctx, &transactions)
}
func (r *GormWalletTransactionRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.DeleteBatchEntity(ctx, ids, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) WithTx(tx interface{}) interfaces.Repository[entities.WalletTransaction] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormWalletTransactionRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), WalletTransactionsTable),
}
}
return r
}
func (r *GormWalletTransactionRepository) SoftDelete(ctx context.Context, id string) error {
return r.SoftDeleteEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) Restore(ctx context.Context, id string) error {
return r.RestoreEntity(ctx, id, &entities.WalletTransaction{})
}
func (r *GormWalletTransactionRepository) ListByUserIdWithFiltersAndProductName(ctx context.Context, userId string, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
var transactionsWithProduct []*WalletTransactionWithProduct
var total int64
// 构建基础查询条件
whereCondition := "wt.user_id = ?"
whereArgs := []interface{}{userId}
// 应用筛选条件
if filters != nil {
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND wt.created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND wt.created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// 交易ID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND wt.transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereCondition += " AND p.name LIKE ?"
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 金额范围筛选
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
whereCondition += " AND wt.amount >= ?"
whereArgs = append(whereArgs, minAmount)
}
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
whereCondition += " AND wt.amount <= ?"
whereArgs = append(whereArgs, maxAmount)
}
}
// 构建JOIN查询
query := r.GetDB(ctx).Table("wallet_transactions wt").
Select("wt.*, p.name as product_name").
Joins("LEFT JOIN product p ON wt.product_id = p.id").
Where(whereCondition, whereArgs...)
// 获取总数
var count int64
err := query.Count(&count).Error
if err != nil {
return nil, nil, 0, err
}
total = count
// 应用排序和分页
if options.Sort != "" {
query = query.Order("wt." + options.Sort + " " + options.Order)
} else {
query = query.Order("wt.created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
// 执行查询
err = query.Find(&transactionsWithProduct).Error
if err != nil {
return nil, nil, 0, err
}
// 转换为entities.WalletTransaction并构建产品名称映射
var transactions []*entities.WalletTransaction
productNameMap := make(map[string]string)
for _, t := range transactionsWithProduct {
transaction := t.WalletTransaction
transactions = append(transactions, &transaction)
// 构建产品ID到产品名称的映射
if t.ProductName != "" {
productNameMap[transaction.ProductID] = t.ProductName
}
}
return productNameMap, transactions, total, nil
}
// ListWithFiltersAndProductName 管理端:根据条件筛选所有钱包交易记录(包含产品名称)
func (r *GormWalletTransactionRepository) ListWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}, options interfaces.ListOptions) (map[string]string, []*entities.WalletTransaction, int64, error) {
var transactionsWithProduct []*WalletTransactionWithProduct
var total int64
// 构建基础查询条件
whereCondition := "1=1"
whereArgs := []interface{}{}
// 应用筛选条件
if filters != nil {
// 用户ID筛选支持单个和多个
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
// 多个用户ID逗号分隔
userIdsList := strings.Split(userIds, ",")
whereCondition += " AND wt.user_id IN ?"
whereArgs = append(whereArgs, userIdsList)
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
// 单个用户ID
whereCondition += " AND wt.user_id = ?"
whereArgs = append(whereArgs, userId)
}
// 产品ID筛选支持多个
if productIds, ok := filters["product_ids"].(string); ok && productIds != "" {
// 多个产品ID逗号分隔
productIdsList := strings.Split(productIds, ",")
whereCondition += " AND wt.product_id IN ?"
whereArgs = append(whereArgs, productIdsList)
}
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereCondition += " AND wt.created_at >= ?"
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereCondition += " AND wt.created_at <= ?"
whereArgs = append(whereArgs, endTime)
}
// 交易ID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereCondition += " AND wt.transaction_id LIKE ?"
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereCondition += " AND p.name LIKE ?"
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 企业名称筛选
if companyName, ok := filters["company_name"].(string); ok && companyName != "" {
whereCondition += " AND ei.company_name LIKE ?"
whereArgs = append(whereArgs, "%"+companyName+"%")
}
// 金额范围筛选
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
whereCondition += " AND wt.amount >= ?"
whereArgs = append(whereArgs, minAmount)
}
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
whereCondition += " AND wt.amount <= ?"
whereArgs = append(whereArgs, maxAmount)
}
}
// 构建JOIN查询
// 需要JOIN product表获取产品名称JOIN users和enterprise_infos表获取企业名称
query := r.GetDB(ctx).Table("wallet_transactions wt").
Select("wt.*, p.name as product_name").
Joins("LEFT JOIN product p ON wt.product_id = p.id").
Joins("LEFT JOIN users u ON wt.user_id = u.id").
Joins("LEFT JOIN enterprise_infos ei ON u.id = ei.user_id").
Where(whereCondition, whereArgs...)
// 获取总数
var count int64
err := query.Count(&count).Error
if err != nil {
return nil, nil, 0, err
}
total = count
// 应用排序和分页
if options.Sort != "" {
query = query.Order("wt." + options.Sort + " " + options.Order)
} else {
query = query.Order("wt.created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
// 执行查询
err = query.Find(&transactionsWithProduct).Error
if err != nil {
return nil, nil, 0, err
}
// 转换为entities.WalletTransaction并构建产品名称映射
var transactions []*entities.WalletTransaction
productNameMap := make(map[string]string)
for _, t := range transactionsWithProduct {
transaction := t.WalletTransaction
transactions = append(transactions, &transaction)
// 构建产品ID到产品名称的映射
if t.ProductName != "" {
productNameMap[transaction.ProductID] = t.ProductName
}
}
return productNameMap, transactions, total, nil
}
// ExportWithFiltersAndProductName 导出钱包交易记录(包含产品名称和企业信息)
func (r *GormWalletTransactionRepository) ExportWithFiltersAndProductName(ctx context.Context, filters map[string]interface{}) ([]*entities.WalletTransaction, error) {
var transactionsWithProduct []WalletTransactionWithProduct
// 构建查询
query := r.GetDB(ctx).Table("wallet_transactions wt").
Select("wt.*, p.name as product_name").
Joins("LEFT JOIN product p ON wt.product_id = p.id")
// 构建WHERE条件
var whereConditions []string
var whereArgs []interface{}
// 用户ID筛选
if userIds, ok := filters["user_ids"].(string); ok && userIds != "" {
whereConditions = append(whereConditions, "wt.user_id IN (?)")
whereArgs = append(whereArgs, strings.Split(userIds, ","))
} else if userId, ok := filters["user_id"].(string); ok && userId != "" {
whereConditions = append(whereConditions, "wt.user_id = ?")
whereArgs = append(whereArgs, userId)
}
// 时间范围筛选
if startTime, ok := filters["start_time"].(time.Time); ok {
whereConditions = append(whereConditions, "wt.created_at >= ?")
whereArgs = append(whereArgs, startTime)
}
if endTime, ok := filters["end_time"].(time.Time); ok {
whereConditions = append(whereConditions, "wt.created_at <= ?")
whereArgs = append(whereArgs, endTime)
}
// 交易ID筛选
if transactionId, ok := filters["transaction_id"].(string); ok && transactionId != "" {
whereConditions = append(whereConditions, "wt.transaction_id LIKE ?")
whereArgs = append(whereArgs, "%"+transactionId+"%")
}
// 产品名称筛选
if productName, ok := filters["product_name"].(string); ok && productName != "" {
whereConditions = append(whereConditions, "p.name LIKE ?")
whereArgs = append(whereArgs, "%"+productName+"%")
}
// 产品ID列表筛选
if productIds, ok := filters["product_ids"].(string); ok && productIds != "" {
whereConditions = append(whereConditions, "wt.product_id IN (?)")
whereArgs = append(whereArgs, strings.Split(productIds, ","))
}
// 金额范围筛选
if minAmount, ok := filters["min_amount"].(string); ok && minAmount != "" {
whereConditions = append(whereConditions, "wt.amount >= ?")
whereArgs = append(whereArgs, minAmount)
}
if maxAmount, ok := filters["max_amount"].(string); ok && maxAmount != "" {
whereConditions = append(whereConditions, "wt.amount <= ?")
whereArgs = append(whereArgs, maxAmount)
}
// 应用WHERE条件
if len(whereConditions) > 0 {
query = query.Where(strings.Join(whereConditions, " AND "), whereArgs...)
}
// 排序
query = query.Order("wt.created_at DESC")
// 执行查询
err := query.Find(&transactionsWithProduct).Error
if err != nil {
return nil, err
}
// 转换为entities.WalletTransaction
var transactions []*entities.WalletTransaction
for _, t := range transactionsWithProduct {
transaction := t.WalletTransaction
transactions = append(transactions, &transaction)
}
return transactions, nil
}
// GetSystemTotalAmount 获取系统总消费金额
func (r *GormWalletTransactionRepository) GetSystemTotalAmount(ctx context.Context) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemAmountByDateRange 获取系统指定时间范围内的消费金额
// endDate 应该是结束日期当天的次日00:00:00日统计或下个月1号00:00:00月统计使用 < 而不是 <=
func (r *GormWalletTransactionRepository) GetSystemAmountByDateRange(ctx context.Context, startDate, endDate time.Time) (float64, error) {
var total float64
err := r.GetDB(ctx).Model(&entities.WalletTransaction{}).
Where("created_at >= ? AND created_at < ?", startDate, endDate).
Select("COALESCE(SUM(amount), 0)").
Scan(&total).Error
return total, err
}
// GetSystemDailyStats 获取系统每日消费统计
func (r *GormWalletTransactionRepository) GetSystemDailyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE DATE(created_at) >= ?
AND DATE(created_at) <= ?
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyStats 获取系统每月消费统计
func (r *GormWalletTransactionRepository) GetSystemMonthlyStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COALESCE(SUM(amount), 0) as amount
FROM wallet_transactions
WHERE created_at >= ?
AND created_at < ?
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@@ -0,0 +1,93 @@
package repositories
import (
"context"
"errors"
"hyapi-server/internal/domains/finance/entities"
domain_finance_repo "hyapi-server/internal/domains/finance/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
WechatOrdersTable = "typay_orders"
)
type GormWechatOrderRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ domain_finance_repo.WechatOrderRepository = (*GormWechatOrderRepository)(nil)
func NewGormWechatOrderRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WechatOrderRepository {
return &GormWechatOrderRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, WechatOrdersTable),
}
}
func (r *GormWechatOrderRepository) Create(ctx context.Context, order entities.WechatOrder) (entities.WechatOrder, error) {
err := r.CreateEntity(ctx, &order)
return order, err
}
func (r *GormWechatOrderRepository) GetByID(ctx context.Context, id string) (entities.WechatOrder, error) {
var order entities.WechatOrder
err := r.SmartGetByID(ctx, id, &order)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.WechatOrder{}, gorm.ErrRecordNotFound
}
return entities.WechatOrder{}, err
}
return order, nil
}
func (r *GormWechatOrderRepository) GetByOutTradeNo(ctx context.Context, outTradeNo string) (*entities.WechatOrder, error) {
var order entities.WechatOrder
err := r.GetDB(ctx).Where("out_trade_no = ?", outTradeNo).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &order, nil
}
func (r *GormWechatOrderRepository) GetByRechargeID(ctx context.Context, rechargeID string) (*entities.WechatOrder, error) {
var order entities.WechatOrder
err := r.GetDB(ctx).Where("recharge_id = ?", rechargeID).First(&order).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &order, nil
}
func (r *GormWechatOrderRepository) GetByUserID(ctx context.Context, userID string) ([]entities.WechatOrder, error) {
var orders []entities.WechatOrder
// 需要通过充值记录关联查询,这里简化处理
err := r.GetDB(ctx).Find(&orders).Error
return orders, err
}
func (r *GormWechatOrderRepository) Update(ctx context.Context, order entities.WechatOrder) error {
return r.UpdateEntity(ctx, &order)
}
func (r *GormWechatOrderRepository) UpdateStatus(ctx context.Context, id string, status entities.WechatOrderStatus) error {
return r.GetDB(ctx).Model(&entities.WechatOrder{}).Where("id = ?", id).Update("status", status).Error
}
func (r *GormWechatOrderRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.WechatOrder{})
}
func (r *GormWechatOrderRepository) Exists(ctx context.Context, id string) (bool, error) {
return r.ExistsEntity(ctx, id, &entities.WechatOrder{})
}

View File

@@ -0,0 +1,342 @@
package repositories
import (
"context"
"fmt"
"time"
"hyapi-server/internal/domains/finance/entities"
"hyapi-server/internal/domains/finance/repositories"
"hyapi-server/internal/domains/finance/value_objects"
"gorm.io/gorm"
)
// GormInvoiceApplicationRepository 发票申请仓储的GORM实现
type GormInvoiceApplicationRepository struct {
db *gorm.DB
}
// NewGormInvoiceApplicationRepository 创建发票申请仓储
func NewGormInvoiceApplicationRepository(db *gorm.DB) repositories.InvoiceApplicationRepository {
return &GormInvoiceApplicationRepository{
db: db,
}
}
// Create 创建发票申请
func (r *GormInvoiceApplicationRepository) Create(ctx context.Context, application *entities.InvoiceApplication) error {
return r.db.WithContext(ctx).Create(application).Error
}
// Update 更新发票申请
func (r *GormInvoiceApplicationRepository) Update(ctx context.Context, application *entities.InvoiceApplication) error {
return r.db.WithContext(ctx).Save(application).Error
}
// Save 保存发票申请
func (r *GormInvoiceApplicationRepository) Save(ctx context.Context, application *entities.InvoiceApplication) error {
return r.db.WithContext(ctx).Save(application).Error
}
// FindByID 根据ID查找发票申请
func (r *GormInvoiceApplicationRepository) FindByID(ctx context.Context, id string) (*entities.InvoiceApplication, error) {
var application entities.InvoiceApplication
err := r.db.WithContext(ctx).Where("id = ?", id).First(&application).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &application, nil
}
// FindByUserID 根据用户ID查找发票申请列表
func (r *GormInvoiceApplicationRepository) FindByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
// 获取总数
err := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID).Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = r.db.WithContext(ctx).Where("user_id = ?", userID).
Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// FindPendingApplications 查找待处理的发票申请
func (r *GormInvoiceApplicationRepository) FindPendingApplications(ctx context.Context, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
// 获取总数
err := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).
Where("status = ?", entities.ApplicationStatusPending).
Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = r.db.WithContext(ctx).
Where("status = ?", entities.ApplicationStatusPending).
Order("created_at ASC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// FindByUserIDAndStatus 根据用户ID和状态查找发票申请
func (r *GormInvoiceApplicationRepository) FindByUserIDAndStatus(ctx context.Context, userID string, status entities.ApplicationStatus, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID)
if status != "" {
query = query.Where("status = ?", status)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = query.Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// FindByUserIDAndStatusWithTimeRange 根据用户ID、状态和时间范围查找发票申请列表
func (r *GormInvoiceApplicationRepository) FindByUserIDAndStatusWithTimeRange(ctx context.Context, userID string, status entities.ApplicationStatus, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ?", userID)
// 添加状态筛选
if status != "" {
query = query.Where("status = ?", status)
}
// 添加时间范围筛选
if startTime != nil {
query = query.Where("created_at >= ?", startTime)
}
if endTime != nil {
query = query.Where("created_at <= ?", endTime)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = query.Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// FindByStatus 根据状态查找发票申请
func (r *GormInvoiceApplicationRepository) FindByStatus(ctx context.Context, status entities.ApplicationStatus) ([]*entities.InvoiceApplication, error) {
var applications []*entities.InvoiceApplication
err := r.db.WithContext(ctx).
Where("status = ?", status).
Order("created_at DESC").
Find(&applications).Error
return applications, err
}
// GetUserInvoiceInfo 获取用户发票信息
// GetUserTotalInvoicedAmount 获取用户已开票总金额
func (r *GormInvoiceApplicationRepository) GetUserTotalInvoicedAmount(ctx context.Context, userID string) (string, error) {
var total string
err := r.db.WithContext(ctx).
Model(&entities.InvoiceApplication{}).
Select("COALESCE(SUM(CAST(amount AS DECIMAL(10,2))), '0')").
Where("user_id = ? AND status = ?", userID, entities.ApplicationStatusCompleted).
Scan(&total).Error
return total, err
}
// GetUserTotalAppliedAmount 获取用户申请开票总金额
func (r *GormInvoiceApplicationRepository) GetUserTotalAppliedAmount(ctx context.Context, userID string) (string, error) {
var total string
err := r.db.WithContext(ctx).
Model(&entities.InvoiceApplication{}).
Select("COALESCE(SUM(CAST(amount AS DECIMAL(10,2))), '0')").
Where("user_id = ?", userID).
Scan(&total).Error
return total, err
}
// FindByUserIDAndInvoiceType 根据用户ID和发票类型查找申请
func (r *GormInvoiceApplicationRepository) FindByUserIDAndInvoiceType(ctx context.Context, userID string, invoiceType value_objects.InvoiceType, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("user_id = ? AND invoice_type = ?", userID, invoiceType)
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = query.Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// FindByDateRange 根据日期范围查找申请
func (r *GormInvoiceApplicationRepository) FindByDateRange(ctx context.Context, startDate, endDate string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{})
if startDate != "" {
query = query.Where("DATE(created_at) >= ?", startDate)
}
if endDate != "" {
query = query.Where("DATE(created_at) <= ?", endDate)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = query.Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// SearchApplications 搜索发票申请
func (r *GormInvoiceApplicationRepository) SearchApplications(ctx context.Context, keyword string, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).
Where("company_name LIKE ? OR email LIKE ? OR tax_number LIKE ?",
fmt.Sprintf("%%%s%%", keyword),
fmt.Sprintf("%%%s%%", keyword),
fmt.Sprintf("%%%s%%", keyword))
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = query.Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// FindByStatusWithTimeRange 根据状态和时间范围查找发票申请
func (r *GormInvoiceApplicationRepository) FindByStatusWithTimeRange(ctx context.Context, status entities.ApplicationStatus, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{}).Where("status = ?", status)
// 添加时间范围筛选
if startTime != nil {
query = query.Where("created_at >= ?", startTime)
}
if endTime != nil {
query = query.Where("created_at <= ?", endTime)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = query.Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}
// FindAllWithTimeRange 根据时间范围查找所有发票申请
func (r *GormInvoiceApplicationRepository) FindAllWithTimeRange(ctx context.Context, startTime, endTime *time.Time, page, pageSize int) ([]*entities.InvoiceApplication, int64, error) {
var applications []*entities.InvoiceApplication
var total int64
query := r.db.WithContext(ctx).Model(&entities.InvoiceApplication{})
// 添加时间范围筛选
if startTime != nil {
query = query.Where("created_at >= ?", startTime)
}
if endTime != nil {
query = query.Where("created_at <= ?", endTime)
}
// 获取总数
err := query.Count(&total).Error
if err != nil {
return nil, 0, err
}
// 获取分页数据
offset := (page - 1) * pageSize
err = query.Order("created_at DESC").
Offset(offset).
Limit(pageSize).
Find(&applications).Error
return applications, total, err
}

View File

@@ -0,0 +1,74 @@
package repositories
import (
"context"
"hyapi-server/internal/domains/finance/entities"
"hyapi-server/internal/domains/finance/repositories"
"gorm.io/gorm"
)
// GormUserInvoiceInfoRepository 用户开票信息仓储的GORM实现
type GormUserInvoiceInfoRepository struct {
db *gorm.DB
}
// NewGormUserInvoiceInfoRepository 创建用户开票信息仓储
func NewGormUserInvoiceInfoRepository(db *gorm.DB) repositories.UserInvoiceInfoRepository {
return &GormUserInvoiceInfoRepository{
db: db,
}
}
// Create 创建用户开票信息
func (r *GormUserInvoiceInfoRepository) Create(ctx context.Context, info *entities.UserInvoiceInfo) error {
return r.db.WithContext(ctx).Create(info).Error
}
// Update 更新用户开票信息
func (r *GormUserInvoiceInfoRepository) Update(ctx context.Context, info *entities.UserInvoiceInfo) error {
return r.db.WithContext(ctx).Save(info).Error
}
// Save 保存用户开票信息(创建或更新)
func (r *GormUserInvoiceInfoRepository) Save(ctx context.Context, info *entities.UserInvoiceInfo) error {
return r.db.WithContext(ctx).Save(info).Error
}
// FindByUserID 根据用户ID查找开票信息
func (r *GormUserInvoiceInfoRepository) FindByUserID(ctx context.Context, userID string) (*entities.UserInvoiceInfo, error) {
var info entities.UserInvoiceInfo
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&info).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &info, nil
}
// FindByID 根据ID查找开票信息
func (r *GormUserInvoiceInfoRepository) FindByID(ctx context.Context, id string) (*entities.UserInvoiceInfo, error) {
var info entities.UserInvoiceInfo
err := r.db.WithContext(ctx).Where("id = ?", id).First(&info).Error
if err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, err
}
return &info, nil
}
// Delete 删除用户开票信息
func (r *GormUserInvoiceInfoRepository) Delete(ctx context.Context, userID string) error {
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&entities.UserInvoiceInfo{}).Error
}
// Exists 检查用户开票信息是否存在
func (r *GormUserInvoiceInfoRepository) Exists(ctx context.Context, userID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.UserInvoiceInfo{}).Where("user_id = ?", userID).Count(&count).Error
return count > 0, err
}

View File

@@ -0,0 +1,189 @@
package repositories
import (
"context"
"encoding/json"
"errors"
"time"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ComponentReportDownloadsTable = "component_report_downloads"
)
type GormComponentReportRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ComponentReportRepository = (*GormComponentReportRepository)(nil)
func NewGormComponentReportRepository(db *gorm.DB, logger *zap.Logger) repositories.ComponentReportRepository {
return &GormComponentReportRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ComponentReportDownloadsTable),
}
}
func (r *GormComponentReportRepository) Create(ctx context.Context, download *entities.ComponentReportDownload) error {
return r.CreateEntity(ctx, download)
}
func (r *GormComponentReportRepository) UpdateDownload(ctx context.Context, download *entities.ComponentReportDownload) error {
return r.UpdateEntity(ctx, download)
}
func (r *GormComponentReportRepository) GetDownloadByID(ctx context.Context, id string) (*entities.ComponentReportDownload, error) {
var download entities.ComponentReportDownload
err := r.SmartGetByID(ctx, id, &download)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &download, nil
}
func (r *GormComponentReportRepository) GetUserDownloads(ctx context.Context, userID string, productID *string) ([]*entities.ComponentReportDownload, error) {
var downloads []entities.ComponentReportDownload
query := r.GetDB(ctx).Where("user_id = ?", userID)
if productID != nil && *productID != "" {
query = query.Where("product_id = ?", *productID)
}
err := query.Order("created_at DESC").Find(&downloads).Error
if err != nil {
return nil, err
}
result := make([]*entities.ComponentReportDownload, len(downloads))
for i := range downloads {
result[i] = &downloads[i]
}
return result, nil
}
func (r *GormComponentReportRepository) HasUserDownloaded(ctx context.Context, userID string, productCode string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).
Where("user_id = ? AND product_code = ?", userID, productCode).
Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}
func (r *GormComponentReportRepository) GetUserDownloadedProductCodes(ctx context.Context, userID string) ([]string, error) {
var downloads []entities.ComponentReportDownload
err := r.GetDB(ctx).
Select("DISTINCT sub_product_codes").
Where("user_id = ?", userID).
Find(&downloads).Error
if err != nil {
return nil, err
}
codesMap := make(map[string]bool)
for _, download := range downloads {
if download.SubProductCodes != "" {
var codes []string
if err := json.Unmarshal([]byte(download.SubProductCodes), &codes); err == nil {
for _, code := range codes {
codesMap[code] = true
}
}
}
// 也添加主产品编号
if download.ProductCode != "" {
codesMap[download.ProductCode] = true
}
}
codes := make([]string, 0, len(codesMap))
for code := range codesMap {
codes = append(codes, code)
}
return codes, nil
}
func (r *GormComponentReportRepository) GetDownloadByPaymentOrderID(ctx context.Context, orderID string) (*entities.ComponentReportDownload, error) {
var download entities.ComponentReportDownload
err := r.GetDB(ctx).Where("order_id = ?", orderID).First(&download).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &download, nil
}
// GetActiveDownload 获取用户有效的下载记录
func (r *GormComponentReportRepository) GetActiveDownload(ctx context.Context, userID, productID string) (*entities.ComponentReportDownload, error) {
var download entities.ComponentReportDownload
// 先尝试查找有支付订单号的下载记录(已支付)
err := r.GetDB(ctx).
Where("user_id = ? AND product_id = ? AND order_number IS NOT NULL AND deleted_at IS NULL", userID, productID).
Order("created_at DESC").
First(&download).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 如果没有找到有支付订单号的记录,尝试查找任何有效的下载记录
err = r.GetDB(ctx).
Where("user_id = ? AND product_id = ? AND deleted_at IS NULL", userID, productID).
Order("created_at DESC").
First(&download).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
} else {
return nil, err
}
}
// 如果找到了下载记录,检查关联的购买订单状态
if download.OrderID != nil {
// 这里需要查询购买订单状态,但当前仓库没有依赖购买订单仓库
// 所以只检查是否有过期时间设置,如果有则认为已支付
if download.ExpiresAt == nil {
return nil, nil // 没有过期时间,表示未支付
}
}
// 检查是否已过期
if download.IsExpired() {
return nil, nil
}
return &download, nil
}
// UpdateFilePath 更新下载记录文件路径
func (r *GormComponentReportRepository) UpdateFilePath(ctx context.Context, downloadID, filePath string) error {
return r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).Where("id = ?", downloadID).Update("file_path", filePath).Error
}
// IncrementDownloadCount 增加下载次数
func (r *GormComponentReportRepository) IncrementDownloadCount(ctx context.Context, downloadID string) error {
now := time.Now()
return r.GetDB(ctx).Model(&entities.ComponentReportDownload{}).
Where("id = ?", downloadID).
Updates(map[string]interface{}{
"download_count": gorm.Expr("download_count + 1"),
"last_download_at": &now,
}).Error
}

View File

@@ -0,0 +1,92 @@
package repositories
import (
"context"
"errors"
"time"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ProductApiConfigsTable = "product_api_configs"
ProductApiConfigCacheTTL = 30 * time.Minute
)
type GormProductApiConfigRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ProductApiConfigRepository = (*GormProductApiConfigRepository)(nil)
func NewGormProductApiConfigRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductApiConfigRepository {
return &GormProductApiConfigRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductApiConfigsTable),
}
}
func (r *GormProductApiConfigRepository) Create(ctx context.Context, config entities.ProductApiConfig) error {
return r.CreateEntity(ctx, &config)
}
func (r *GormProductApiConfigRepository) Update(ctx context.Context, config entities.ProductApiConfig) error {
return r.UpdateEntity(ctx, &config)
}
func (r *GormProductApiConfigRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.ProductApiConfig{})
}
func (r *GormProductApiConfigRepository) GetByID(ctx context.Context, id string) (*entities.ProductApiConfig, error) {
var config entities.ProductApiConfig
err := r.SmartGetByID(ctx, id, &config)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &config, nil
}
func (r *GormProductApiConfigRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductApiConfig, error) {
var config entities.ProductApiConfig
err := r.SmartGetByField(ctx, &config, "product_id", productID, ProductApiConfigCacheTTL)
if err != nil {
return nil, err
}
return &config, nil
}
func (r *GormProductApiConfigRepository) FindByProductCode(ctx context.Context, productCode string) (*entities.ProductApiConfig, error) {
var config entities.ProductApiConfig
err := r.GetDB(ctx).Joins("JOIN products ON products.id = product_api_configs.product_id").
Where("products.code = ?", productCode).
First(&config).Error
if err != nil {
return nil, err
}
return &config, nil
}
func (r *GormProductApiConfigRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductApiConfig, error) {
var configs []*entities.ProductApiConfig
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&configs).Error
if err != nil {
return nil, err
}
return configs, nil
}
func (r *GormProductApiConfigRepository) ExistsByProductID(ctx context.Context, productID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ProductApiConfig{}).Where("product_id = ?", productID).Count(&count).Error
if err != nil {
return false, err
}
return count > 0, nil
}

View File

@@ -0,0 +1,281 @@
package repositories
import (
"context"
"errors"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"hyapi-server/internal/domains/product/repositories/queries"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ProductCategoriesTable = "product_categories"
)
type GormProductCategoryRepository struct {
*database.CachedBaseRepositoryImpl
}
func (r *GormProductCategoryRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.ProductCategory{})
}
var _ repositories.ProductCategoryRepository = (*GormProductCategoryRepository)(nil)
func NewGormProductCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductCategoryRepository {
return &GormProductCategoryRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductCategoriesTable),
}
}
func (r *GormProductCategoryRepository) Create(ctx context.Context, entity entities.ProductCategory) (entities.ProductCategory, error) {
err := r.CreateEntity(ctx, &entity)
return entity, err
}
func (r *GormProductCategoryRepository) GetByID(ctx context.Context, id string) (entities.ProductCategory, error) {
var entity entities.ProductCategory
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.ProductCategory{}, gorm.ErrRecordNotFound
}
return entities.ProductCategory{}, err
}
return entity, nil
}
func (r *GormProductCategoryRepository) Update(ctx context.Context, entity entities.ProductCategory) error {
return r.UpdateEntity(ctx, &entity)
}
// FindByCode 根据编号查找产品分类
func (r *GormProductCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductCategory, error) {
var entity entities.ProductCategory
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
}
// FindVisible 查找可见分类
func (r *GormProductCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductCategory, error) {
var categories []entities.ProductCategory
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Order("sort ASC, created_at DESC").Find(&categories).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductCategory, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}
// FindEnabled 查找启用分类
func (r *GormProductCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductCategory, error) {
var categories []entities.ProductCategory
err := r.GetDB(ctx).Where("is_enabled = ?", true).Order("sort ASC, created_at DESC").Find(&categories).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductCategory, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}
// ListCategories 获取分类列表
func (r *GormProductCategoryRepository) ListCategories(ctx context.Context, query *queries.ListCategoriesQuery) ([]*entities.ProductCategory, int64, error) {
var categories []entities.ProductCategory
var total int64
dbQuery := r.GetDB(ctx).Model(&entities.ProductCategory{})
// 应用筛选条件
if query.IsEnabled != nil {
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
}
if query.IsVisible != nil {
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序
if query.SortBy != "" {
order := query.SortBy
if query.SortOrder == "desc" {
order += " DESC"
} else {
order += " ASC"
}
dbQuery = dbQuery.Order(order)
} else {
// 默认按排序字段和创建时间排序
dbQuery = dbQuery.Order("sort ASC, created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 获取数据
if err := dbQuery.Find(&categories).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.ProductCategory, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, total, nil
}
// CountEnabled 统计启用分类数量
func (r *GormProductCategoryRepository) CountEnabled(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_enabled = ?", true).Count(&count).Error
return count, err
}
// CountVisible 统计可见分类数量
func (r *GormProductCategoryRepository) CountVisible(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
return count, err
}
// 基础Repository接口方法
// Count 返回分类总数
func (r *GormProductCategoryRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// GetByIDs 根据ID列表获取分类
func (r *GormProductCategoryRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.ProductCategory, error) {
var categories []entities.ProductCategory
err := r.GetDB(ctx).Where("id IN ?", ids).Order("sort ASC, created_at DESC").Find(&categories).Error
return categories, err
}
// CreateBatch 批量创建分类
func (r *GormProductCategoryRepository) CreateBatch(ctx context.Context, categories []entities.ProductCategory) error {
return r.GetDB(ctx).Create(&categories).Error
}
// UpdateBatch 批量更新分类
func (r *GormProductCategoryRepository) UpdateBatch(ctx context.Context, categories []entities.ProductCategory) error {
return r.GetDB(ctx).Save(&categories).Error
}
// DeleteBatch 批量删除分类
func (r *GormProductCategoryRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id IN ?", ids).Error
}
// List 获取分类列表(基础方法)
func (r *GormProductCategoryRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.ProductCategory, error) {
var categories []entities.ProductCategory
query := r.GetDB(ctx).Model(&entities.ProductCategory{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用排序
if options.Sort != "" {
order := options.Sort
if options.Order == "desc" {
order += " DESC"
} else {
order += " ASC"
}
query = query.Order(order)
} else {
// 默认按排序字段和创建时间倒序
query = query.Order("sort ASC, created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&categories).Error
return categories, err
}
// Exists 检查分类是否存在
func (r *GormProductCategoryRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ProductCategory{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// SoftDelete 软删除分类
func (r *GormProductCategoryRepository) SoftDelete(ctx context.Context, id string) error {
return r.GetDB(ctx).Delete(&entities.ProductCategory{}, "id = ?", id).Error
}
// Restore 恢复软删除的分类
func (r *GormProductCategoryRepository) Restore(ctx context.Context, id string) error {
return r.GetDB(ctx).Unscoped().Model(&entities.ProductCategory{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// WithTx 使用事务
func (r *GormProductCategoryRepository) WithTx(tx interface{}) interfaces.Repository[entities.ProductCategory] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormProductCategoryRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductCategoriesTable),
}
}
return r
}

View File

@@ -0,0 +1,108 @@
package repositories
import (
"context"
"errors"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ProductDocumentationsTable = "product_documentations"
)
type GormProductDocumentationRepository struct {
*database.CachedBaseRepositoryImpl
}
func (r *GormProductDocumentationRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.ProductDocumentation{})
}
var _ repositories.ProductDocumentationRepository = (*GormProductDocumentationRepository)(nil)
func NewGormProductDocumentationRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductDocumentationRepository {
return &GormProductDocumentationRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductDocumentationsTable),
}
}
// Create 创建文档
func (r *GormProductDocumentationRepository) Create(ctx context.Context, documentation *entities.ProductDocumentation) error {
return r.CreateEntity(ctx, documentation)
}
// Update 更新文档
func (r *GormProductDocumentationRepository) Update(ctx context.Context, documentation *entities.ProductDocumentation) error {
return r.UpdateEntity(ctx, documentation)
}
// FindByID 根据ID查找文档
func (r *GormProductDocumentationRepository) FindByID(ctx context.Context, id string) (*entities.ProductDocumentation, error) {
var entity entities.ProductDocumentation
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
}
// FindByProductID 根据产品ID查找文档
func (r *GormProductDocumentationRepository) FindByProductID(ctx context.Context, productID string) (*entities.ProductDocumentation, error) {
var entity entities.ProductDocumentation
err := r.GetDB(ctx).Where("product_id = ?", productID).First(&entity).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
}
// FindByProductIDs 根据产品ID列表批量查找文档
func (r *GormProductDocumentationRepository) FindByProductIDs(ctx context.Context, productIDs []string) ([]*entities.ProductDocumentation, error) {
var documentations []entities.ProductDocumentation
err := r.GetDB(ctx).Where("product_id IN ?", productIDs).Find(&documentations).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductDocumentation, len(documentations))
for i := range documentations {
result[i] = &documentations[i]
}
return result, nil
}
// UpdateBatch 批量更新文档
func (r *GormProductDocumentationRepository) UpdateBatch(ctx context.Context, documentations []*entities.ProductDocumentation) error {
if len(documentations) == 0 {
return nil
}
// 使用事务进行批量更新
return r.GetDB(ctx).Transaction(func(tx *gorm.DB) error {
for _, doc := range documentations {
if err := tx.Save(doc).Error; err != nil {
return err
}
}
return nil
})
}
// CountByProductID 统计指定产品的文档数量
func (r *GormProductDocumentationRepository) CountByProductID(ctx context.Context, productID string) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ProductDocumentation{}).Where("product_id = ?", productID).Count(&count).Error
return count, err
}

View File

@@ -0,0 +1,521 @@
package repositories
import (
"context"
"errors"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"hyapi-server/internal/domains/product/repositories/queries"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ProductsTable = "products"
)
type GormProductRepository struct {
*database.CachedBaseRepositoryImpl
}
func (r *GormProductRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.Product{})
}
var _ repositories.ProductRepository = (*GormProductRepository)(nil)
func NewGormProductRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductRepository {
return &GormProductRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductsTable),
}
}
func (r *GormProductRepository) Create(ctx context.Context, entity entities.Product) (entities.Product, error) {
err := r.CreateEntity(ctx, &entity)
return entity, err
}
func (r *GormProductRepository) GetByID(ctx context.Context, id string) (entities.Product, error) {
var entity entities.Product
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.Product{}, gorm.ErrRecordNotFound
}
return entities.Product{}, err
}
return entity, nil
}
func (r *GormProductRepository) Update(ctx context.Context, entity entities.Product) error {
return r.UpdateEntity(ctx, &entity)
}
// 其它方法同理迁移全部用r.GetDB(ctx)
// FindByCode 根据编号查找产品
func (r *GormProductRepository) FindByCode(ctx context.Context, code string) (*entities.Product, error) {
var entity entities.Product
err := r.SmartGetByField(ctx, &entity, "code", code) // 自动缓存
if err != nil {
return nil, err
}
return &entity, nil
}
// FindByOldID 根据旧ID查找产品
func (r *GormProductRepository) FindByOldID(ctx context.Context, oldID string) (*entities.Product, error) {
var entity entities.Product
err := r.GetDB(ctx).Where("old_id = ?", oldID).First(&entity).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
}
// FindByCategoryID 根据分类ID查找产品
func (r *GormProductRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.GetDB(ctx).Preload("Category").Where("category_id = ?", categoryID).Order("created_at DESC").Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
return result, nil
}
// FindVisible 查找可见产品
func (r *GormProductRepository) FindVisible(ctx context.Context) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.GetDB(ctx).Preload("Category").Where("is_visible = ? AND is_enabled = ?", true, true).Order("created_at DESC").Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
return result, nil
}
// FindEnabled 查找启用产品
func (r *GormProductRepository) FindEnabled(ctx context.Context) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.GetDB(ctx).Preload("Category").Where("is_enabled = ?", true).Order("created_at DESC").Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
return result, nil
}
// ListProducts 获取产品列表
func (r *GormProductRepository) ListProducts(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, int64, error) {
var productEntities []entities.Product
var total int64
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
// 应用筛选条件
if query.Keyword != "" {
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
}
if query.CategoryID != "" {
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
}
if query.MinPrice != nil {
dbQuery = dbQuery.Where("price >= ?", *query.MinPrice)
}
if query.MaxPrice != nil {
dbQuery = dbQuery.Where("price <= ?", *query.MaxPrice)
}
if query.IsEnabled != nil {
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
}
if query.IsVisible != nil {
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
}
if query.IsPackage != nil {
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序
if query.SortBy != "" {
order := query.SortBy
if query.SortOrder == "desc" {
order += " DESC"
} else {
order += " ASC"
}
dbQuery = dbQuery.Order(order)
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 预加载分类信息并获取数据
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
return result, total, nil
}
// ListProductsWithSubscriptionStatus 获取产品列表(包含订阅状态)
func (r *GormProductRepository) ListProductsWithSubscriptionStatus(ctx context.Context, query *queries.ListProductsQuery) ([]*entities.Product, map[string]bool, int64, error) {
var productEntities []entities.Product
var total int64
dbQuery := r.GetDB(ctx).Model(&entities.Product{})
// 应用筛选条件
if query.Keyword != "" {
dbQuery = dbQuery.Where("name LIKE ? OR description LIKE ? OR code LIKE ?",
"%"+query.Keyword+"%", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
}
if query.CategoryID != "" {
dbQuery = dbQuery.Where("category_id = ?", query.CategoryID)
}
if query.MinPrice != nil {
dbQuery = dbQuery.Where("price >= ?", *query.MinPrice)
}
if query.MaxPrice != nil {
dbQuery = dbQuery.Where("price <= ?", *query.MaxPrice)
}
if query.IsEnabled != nil {
dbQuery = dbQuery.Where("is_enabled = ?", *query.IsEnabled)
}
if query.IsVisible != nil {
dbQuery = dbQuery.Where("is_visible = ?", *query.IsVisible)
}
if query.IsPackage != nil {
dbQuery = dbQuery.Where("is_package = ?", *query.IsPackage)
}
// 如果指定了用户ID添加订阅状态筛选
if query.UserID != "" && query.IsSubscribed != nil {
if *query.IsSubscribed {
// 筛选已订阅的产品
dbQuery = dbQuery.Where("EXISTS (SELECT 1 FROM subscription WHERE subscription.product_id = product.id AND subscription.user_id = ?)", query.UserID)
} else {
// 筛选未订阅的产品
dbQuery = dbQuery.Where("NOT EXISTS (SELECT 1 FROM subscription WHERE subscription.product_id = product.id AND subscription.user_id = ?)", query.UserID)
}
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, nil, 0, err
}
// 应用排序
if query.SortBy != "" {
order := query.SortBy
if query.SortOrder == "desc" {
order += " DESC"
} else {
order += " ASC"
}
dbQuery = dbQuery.Order(order)
} else {
dbQuery = dbQuery.Order("created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 预加载分类信息并获取数据
if err := dbQuery.Preload("Category").Find(&productEntities).Error; err != nil {
return nil, nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
// 获取订阅状态映射
subscriptionStatusMap := make(map[string]bool)
if query.UserID != "" && len(result) > 0 {
productIDs := make([]string, len(result))
for i, product := range result {
productIDs[i] = product.ID
}
// 查询用户的订阅状态
var subscriptions []struct {
ProductID string `gorm:"column:product_id"`
}
err := r.GetDB(ctx).Table("subscription").
Select("product_id").
Where("user_id = ? AND product_id IN ?", query.UserID, productIDs).
Find(&subscriptions).Error
if err == nil {
for _, sub := range subscriptions {
subscriptionStatusMap[sub.ProductID] = true
}
}
}
return result, subscriptionStatusMap, total, nil
}
// FindSubscribableProducts 查找可订阅产品
func (r *GormProductRepository) FindSubscribableProducts(ctx context.Context, userID string) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.GetDB(ctx).Where("is_enabled = ? AND is_visible = ?", true, true).Order("created_at DESC").Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
return result, nil
}
// FindProductsByIDs 根据ID列表查找产品
func (r *GormProductRepository) FindProductsByIDs(ctx context.Context, ids []string) ([]*entities.Product, error) {
var productEntities []entities.Product
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&productEntities).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Product, len(productEntities))
for i := range productEntities {
result[i] = &productEntities[i]
}
return result, nil
}
// CountByCategory 统计分类下的产品数量
func (r *GormProductRepository) CountByCategory(ctx context.Context, categoryID string) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.Product{})
if categoryID != "" {
query = query.Where("category_id = ?", categoryID)
}
err := query.Count(&count).Error
return count, err
}
// CountEnabled 统计启用产品数量
func (r *GormProductRepository) CountEnabled(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_enabled = ?", true).Count(&count).Error
return count, err
}
// CountVisible 统计可见产品数量
func (r *GormProductRepository) CountVisible(ctx context.Context) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Product{}).Where("is_visible = ? AND is_enabled = ?", true, true).Count(&count).Error
return count, err
}
// Count 返回产品总数
func (r *GormProductRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.Product{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// GetByIDs 根据ID列表获取产品
func (r *GormProductRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Product, error) {
var products []entities.Product
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&products).Error
return products, err
}
// CreateBatch 批量创建产品
func (r *GormProductRepository) CreateBatch(ctx context.Context, products []entities.Product) error {
return r.GetDB(ctx).Create(&products).Error
}
// UpdateBatch 批量更新产品
func (r *GormProductRepository) UpdateBatch(ctx context.Context, products []entities.Product) error {
return r.GetDB(ctx).Save(&products).Error
}
// DeleteBatch 批量删除产品
func (r *GormProductRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).Delete(&entities.Product{}, "id IN ?", ids).Error
}
// List 获取产品列表(基础方法)
func (r *GormProductRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Product, error) {
var products []entities.Product
query := r.GetDB(ctx).Model(&entities.Product{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("name LIKE ? OR description LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用排序
if options.Sort != "" {
order := options.Sort
if options.Order == "desc" {
order += " DESC"
} else {
order += " ASC"
}
query = query.Order(order)
} else {
// 默认按创建时间倒序
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&products).Error
return products, err
}
// Exists 检查产品是否存在
func (r *GormProductRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.Product{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// SoftDelete 软删除产品
func (r *GormProductRepository) SoftDelete(ctx context.Context, id string) error {
return r.GetDB(ctx).Delete(&entities.Product{}, "id = ?", id).Error
}
// Restore 恢复软删除的产品
func (r *GormProductRepository) Restore(ctx context.Context, id string) error {
return r.GetDB(ctx).Unscoped().Model(&entities.Product{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// GetPackageItems 获取组合包项目
func (r *GormProductRepository) GetPackageItems(ctx context.Context, packageID string) ([]*entities.ProductPackageItem, error) {
var packageItems []entities.ProductPackageItem
err := r.GetDB(ctx).
Preload("Product").
Where("package_id = ?", packageID).
Order("sort_order ASC").
Find(&packageItems).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductPackageItem, len(packageItems))
for i := range packageItems {
result[i] = &packageItems[i]
}
return result, nil
}
// CreatePackageItem 创建组合包项目
func (r *GormProductRepository) CreatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
return r.GetDB(ctx).Create(packageItem).Error
}
// GetPackageItemByID 根据ID获取组合包项目
func (r *GormProductRepository) GetPackageItemByID(ctx context.Context, itemID string) (*entities.ProductPackageItem, error) {
var packageItem entities.ProductPackageItem
err := r.GetDB(ctx).
Preload("Product").
Preload("Package").
Where("id = ?", itemID).
First(&packageItem).Error
if err != nil {
return nil, err
}
return &packageItem, nil
}
// UpdatePackageItem 更新组合包项目
func (r *GormProductRepository) UpdatePackageItem(ctx context.Context, packageItem *entities.ProductPackageItem) error {
return r.GetDB(ctx).Save(packageItem).Error
}
// DeletePackageItem 删除组合包项目(硬删除)
func (r *GormProductRepository) DeletePackageItem(ctx context.Context, itemID string) error {
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "id = ?", itemID).Error
}
// DeletePackageItemsByPackageID 根据组合包ID删除所有子产品硬删除
func (r *GormProductRepository) DeletePackageItemsByPackageID(ctx context.Context, packageID string) error {
return r.GetDB(ctx).Unscoped().Delete(&entities.ProductPackageItem{}, "package_id = ?", packageID).Error
}
// WithTx 使用事务
func (r *GormProductRepository) WithTx(tx interface{}) interfaces.Repository[entities.Product] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormProductRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), ProductsTable),
}
}
return r
}

View File

@@ -0,0 +1,137 @@
package repositories
import (
"context"
"errors"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ProductSubCategoriesTable = "product_sub_categories"
)
type GormProductSubCategoryRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.ProductSubCategoryRepository = (*GormProductSubCategoryRepository)(nil)
func NewGormProductSubCategoryRepository(db *gorm.DB, logger *zap.Logger) repositories.ProductSubCategoryRepository {
return &GormProductSubCategoryRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ProductSubCategoriesTable),
}
}
// Create 创建二级分类
func (r *GormProductSubCategoryRepository) Create(ctx context.Context, category entities.ProductSubCategory) (*entities.ProductSubCategory, error) {
err := r.CreateEntity(ctx, &category)
if err != nil {
return nil, err
}
return &category, nil
}
// GetByID 根据ID获取二级分类
func (r *GormProductSubCategoryRepository) GetByID(ctx context.Context, id string) (*entities.ProductSubCategory, error) {
var entity entities.ProductSubCategory
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
}
// Update 更新二级分类
func (r *GormProductSubCategoryRepository) Update(ctx context.Context, category entities.ProductSubCategory) error {
return r.UpdateEntity(ctx, &category)
}
// Delete 删除二级分类
func (r *GormProductSubCategoryRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.ProductSubCategory{})
}
// List 获取所有二级分类
func (r *GormProductSubCategoryRepository) List(ctx context.Context) ([]*entities.ProductSubCategory, error) {
var categories []entities.ProductSubCategory
err := r.GetDB(ctx).Order("sort ASC, created_at DESC").Find(&categories).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductSubCategory, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}
// FindByCode 根据编号查找二级分类
func (r *GormProductSubCategoryRepository) FindByCode(ctx context.Context, code string) (*entities.ProductSubCategory, error) {
var entity entities.ProductSubCategory
err := r.GetDB(ctx).Where("code = ?", code).First(&entity).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
}
// FindByCategoryID 根据一级分类ID查找二级分类
func (r *GormProductSubCategoryRepository) FindByCategoryID(ctx context.Context, categoryID string) ([]*entities.ProductSubCategory, error) {
var categories []entities.ProductSubCategory
err := r.GetDB(ctx).Where("category_id = ?", categoryID).Order("sort ASC, created_at DESC").Find(&categories).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductSubCategory, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}
// FindVisible 查找可见的二级分类
func (r *GormProductSubCategoryRepository) FindVisible(ctx context.Context) ([]*entities.ProductSubCategory, error) {
var categories []entities.ProductSubCategory
err := r.GetDB(ctx).Where("is_visible = ? AND is_enabled = ?", true, true).Order("sort ASC, created_at DESC").Find(&categories).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductSubCategory, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}
// FindEnabled 查找启用的二级分类
func (r *GormProductSubCategoryRepository) FindEnabled(ctx context.Context) ([]*entities.ProductSubCategory, error) {
var categories []entities.ProductSubCategory
err := r.GetDB(ctx).Where("is_enabled = ?", true).Order("sort ASC, created_at DESC").Find(&categories).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.ProductSubCategory, len(categories))
for i := range categories {
result[i] = &categories[i]
}
return result, nil
}

View File

@@ -0,0 +1,80 @@
package repositories
import (
"context"
"fmt"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"gorm.io/gorm"
)
// GormProductUIComponentRepository 产品UI组件关联仓储实现
type GormProductUIComponentRepository struct {
db *gorm.DB
}
// NewGormProductUIComponentRepository 创建产品UI组件关联仓储实例
func NewGormProductUIComponentRepository(db *gorm.DB) repositories.ProductUIComponentRepository {
return &GormProductUIComponentRepository{db: db}
}
// Create 创建产品UI组件关联
func (r *GormProductUIComponentRepository) Create(ctx context.Context, relation entities.ProductUIComponent) (entities.ProductUIComponent, error) {
if err := r.db.WithContext(ctx).Create(&relation).Error; err != nil {
return entities.ProductUIComponent{}, fmt.Errorf("创建产品UI组件关联失败: %w", err)
}
return relation, nil
}
// GetByProductID 根据产品ID获取UI组件关联列表
func (r *GormProductUIComponentRepository) GetByProductID(ctx context.Context, productID string) ([]entities.ProductUIComponent, error) {
var relations []entities.ProductUIComponent
if err := r.db.WithContext(ctx).
Preload("UIComponent").
Where("product_id = ?", productID).
Find(&relations).Error; err != nil {
return nil, fmt.Errorf("获取产品UI组件关联列表失败: %w", err)
}
return relations, nil
}
// GetByUIComponentID 根据UI组件ID获取产品关联列表
func (r *GormProductUIComponentRepository) GetByUIComponentID(ctx context.Context, componentID string) ([]entities.ProductUIComponent, error) {
var relations []entities.ProductUIComponent
if err := r.db.WithContext(ctx).
Preload("Product").
Where("ui_component_id = ?", componentID).
Find(&relations).Error; err != nil {
return nil, fmt.Errorf("获取UI组件产品关联列表失败: %w", err)
}
return relations, nil
}
// Delete 删除产品UI组件关联
func (r *GormProductUIComponentRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.ProductUIComponent{}, id).Error; err != nil {
return fmt.Errorf("删除产品UI组件关联失败: %w", err)
}
return nil
}
// DeleteByProductID 根据产品ID删除所有关联
func (r *GormProductUIComponentRepository) DeleteByProductID(ctx context.Context, productID string) error {
if err := r.db.WithContext(ctx).Where("product_id = ?", productID).Delete(&entities.ProductUIComponent{}).Error; err != nil {
return fmt.Errorf("根据产品ID删除UI组件关联失败: %w", err)
}
return nil
}
// BatchCreate 批量创建产品UI组件关联
func (r *GormProductUIComponentRepository) BatchCreate(ctx context.Context, relations []entities.ProductUIComponent) error {
if len(relations) == 0 {
return nil
}
if err := r.db.WithContext(ctx).CreateInBatches(relations, 100).Error; err != nil {
return fmt.Errorf("批量创建产品UI组件关联失败: %w", err)
}
return nil
}

View File

@@ -0,0 +1,354 @@
package repositories
import (
"context"
"errors"
"time"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"hyapi-server/internal/domains/product/repositories/queries"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
SubscriptionsTable = "subscription"
SubscriptionCacheTTL = 60 * time.Minute
)
type GormSubscriptionRepository struct {
*database.CachedBaseRepositoryImpl
}
func (r *GormSubscriptionRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.Subscription{})
}
var _ repositories.SubscriptionRepository = (*GormSubscriptionRepository)(nil)
func NewGormSubscriptionRepository(db *gorm.DB, logger *zap.Logger) repositories.SubscriptionRepository {
return &GormSubscriptionRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SubscriptionsTable),
}
}
func (r *GormSubscriptionRepository) Create(ctx context.Context, entity entities.Subscription) (entities.Subscription, error) {
err := r.CreateEntity(ctx, &entity)
return entity, err
}
func (r *GormSubscriptionRepository) GetByID(ctx context.Context, id string) (entities.Subscription, error) {
var entity entities.Subscription
err := r.SmartGetByID(ctx, id, &entity)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.Subscription{}, gorm.ErrRecordNotFound
}
return entities.Subscription{}, err
}
return entity, nil
}
func (r *GormSubscriptionRepository) Update(ctx context.Context, entity entities.Subscription) error {
return r.UpdateEntity(ctx, &entity)
}
// FindByUserID 根据用户ID查找订阅
func (r *GormSubscriptionRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.Subscription, error) {
var subscriptions []entities.Subscription
err := r.GetDB(ctx).WithContext(ctx).Where("user_id = ?", userID).Order("created_at DESC").Find(&subscriptions).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Subscription, len(subscriptions))
for i := range subscriptions {
result[i] = &subscriptions[i]
}
return result, nil
}
// FindByProductID 根据产品ID查找订阅
func (r *GormSubscriptionRepository) FindByProductID(ctx context.Context, productID string) ([]*entities.Subscription, error) {
var subscriptions []entities.Subscription
err := r.GetDB(ctx).WithContext(ctx).Where("product_id = ?", productID).Order("created_at DESC").Find(&subscriptions).Error
if err != nil {
return nil, err
}
// 转换为指针切片
result := make([]*entities.Subscription, len(subscriptions))
for i := range subscriptions {
result[i] = &subscriptions[i]
}
return result, nil
}
// FindByUserAndProduct 根据用户和产品查找订阅
func (r *GormSubscriptionRepository) FindByUserAndProduct(ctx context.Context, userID, productID string) (*entities.Subscription, error) {
var entity entities.Subscription
// 组合缓存key的条件
where := "user_id = ? AND product_id = ?"
ttl := SubscriptionCacheTTL // 缓存10分钟可根据业务调整
err := r.GetWithCache(ctx, &entity, ttl, where, userID, productID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, gorm.ErrRecordNotFound
}
return nil, err
}
return &entity, nil
}
// ListSubscriptions 获取订阅列表
func (r *GormSubscriptionRepository) ListSubscriptions(ctx context.Context, query *queries.ListSubscriptionsQuery) ([]*entities.Subscription, int64, error) {
var subscriptions []entities.Subscription
var total int64
dbQuery := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
// 应用筛选条件
if query.UserID != "" {
dbQuery = dbQuery.Where("subscription.user_id = ?", query.UserID)
}
// 关键词搜索(产品名称或编码)
if query.Keyword != "" {
dbQuery = dbQuery.Joins("LEFT JOIN product ON product.id = subscription.product_id").
Where("product.name LIKE ? OR product.code LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
}
// 产品名称筛选
if query.ProductName != "" {
dbQuery = dbQuery.Joins("LEFT JOIN product ON product.id = subscription.product_id").
Where("product.name LIKE ?", "%"+query.ProductName+"%")
}
// 企业名称筛选(需要关联用户和企业信息)
if query.CompanyName != "" {
dbQuery = dbQuery.Joins("LEFT JOIN users ON users.id = subscription.user_id").
Joins("LEFT JOIN enterprise_infos ON enterprise_infos.user_id = users.id").
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
}
// 时间范围筛选
if query.StartTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", query.StartTime); err == nil {
dbQuery = dbQuery.Where("subscription.created_at >= ?", t)
}
}
if query.EndTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", query.EndTime); err == nil {
dbQuery = dbQuery.Where("subscription.created_at <= ?", t)
}
}
// 获取总数
if err := dbQuery.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序
if query.SortBy != "" {
order := query.SortBy
if query.SortOrder == "desc" {
order += " DESC"
} else {
order += " ASC"
}
dbQuery = dbQuery.Order(order)
} else {
dbQuery = dbQuery.Order("subscription.created_at DESC")
}
// 应用分页
if query.Page > 0 && query.PageSize > 0 {
offset := (query.Page - 1) * query.PageSize
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
}
// 预加载Product的id、name、code、price、cost_price、is_package字段并同时预加载ProductCategory的id、name、code字段
if err := dbQuery.
Preload("Product", func(db *gorm.DB) *gorm.DB {
return db.Select("id", "name", "code", "price", "cost_price", "is_package", "category_id").
Preload("Category", func(db2 *gorm.DB) *gorm.DB {
return db2.Select("id", "name", "code")
})
}).
Find(&subscriptions).Error; err != nil {
return nil, 0, err
}
// 转换为指针切片
result := make([]*entities.Subscription, len(subscriptions))
for i := range subscriptions {
result[i] = &subscriptions[i]
}
return result, total, nil
}
// CountByUser 统计用户订阅数量
func (r *GormSubscriptionRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
var count int64
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("user_id = ?", userID).Count(&count).Error
return count, err
}
// CountByProduct 统计产品的订阅数量
func (r *GormSubscriptionRepository) CountByProduct(ctx context.Context, productID string) (int64, error) {
var count int64
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("product_id = ?", productID).Count(&count).Error
return count, err
}
// GetTotalRevenue 获取总收入
func (r *GormSubscriptionRepository) GetTotalRevenue(ctx context.Context) (float64, error) {
var total decimal.Decimal
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Select("COALESCE(SUM(price), 0)").Scan(&total).Error
if err != nil {
return 0, err
}
return total.InexactFloat64(), nil
}
// 基础Repository接口方法
// Count 返回订阅总数
func (r *GormSubscriptionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("user_id LIKE ? OR product_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// GetByIDs 根据ID列表获取订阅
func (r *GormSubscriptionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Subscription, error) {
var subscriptions []entities.Subscription
err := r.GetDB(ctx).WithContext(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&subscriptions).Error
return subscriptions, err
}
// CreateBatch 批量创建订阅
func (r *GormSubscriptionRepository) CreateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
return r.GetDB(ctx).WithContext(ctx).Create(&subscriptions).Error
}
// UpdateBatch 批量更新订阅
func (r *GormSubscriptionRepository) UpdateBatch(ctx context.Context, subscriptions []entities.Subscription) error {
return r.GetDB(ctx).WithContext(ctx).Save(&subscriptions).Error
}
// DeleteBatch 批量删除订阅
func (r *GormSubscriptionRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id IN ?", ids).Error
}
// List 获取订阅列表(基础方法)
func (r *GormSubscriptionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Subscription, error) {
var subscriptions []entities.Subscription
query := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("user_id LIKE ? OR product_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
// 应用排序
if options.Sort != "" {
order := options.Sort
if options.Order == "desc" {
order += " DESC"
} else {
order += " ASC"
}
query = query.Order(order)
} else {
// 默认按创建时间倒序
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&subscriptions).Error
return subscriptions, err
}
// Exists 检查订阅是否存在
func (r *GormSubscriptionRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).WithContext(ctx).Model(&entities.Subscription{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// SoftDelete 软删除订阅
func (r *GormSubscriptionRepository) SoftDelete(ctx context.Context, id string) error {
return r.GetDB(ctx).WithContext(ctx).Delete(&entities.Subscription{}, "id = ?", id).Error
}
// Restore 恢复软删除的订阅
func (r *GormSubscriptionRepository) Restore(ctx context.Context, id string) error {
return r.GetDB(ctx).WithContext(ctx).Unscoped().Model(&entities.Subscription{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// WithTx 使用事务
func (r *GormSubscriptionRepository) WithTx(tx interface{}) interfaces.Repository[entities.Subscription] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormSubscriptionRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(gormTx, r.GetLogger(), SubscriptionsTable),
}
}
return r
}
// IncrementAPIUsageWithOptimisticLock 使用乐观锁增加API使用次数
func (r *GormSubscriptionRepository) IncrementAPIUsageWithOptimisticLock(ctx context.Context, subscriptionID string, increment int64) error {
// 使用原生SQL进行乐观锁更新
result := r.GetDB(ctx).WithContext(ctx).Exec(`
UPDATE subscription
SET api_used = api_used + ?, version = version + 1, updated_at = NOW()
WHERE id = ? AND version = (
SELECT version FROM subscription WHERE id = ?
)
`, increment, subscriptionID, subscriptionID)
if result.Error != nil {
return result.Error
}
// 检查是否有行被更新
if result.RowsAffected == 0 {
return gorm.ErrRecordNotFound
}
return nil
}

View File

@@ -0,0 +1,130 @@
package repositories
import (
"context"
"fmt"
"hyapi-server/internal/domains/product/entities"
"hyapi-server/internal/domains/product/repositories"
"gorm.io/gorm"
)
// GormUIComponentRepository UI组件仓储实现
type GormUIComponentRepository struct {
db *gorm.DB
}
// NewGormUIComponentRepository 创建UI组件仓储实例
func NewGormUIComponentRepository(db *gorm.DB) repositories.UIComponentRepository {
return &GormUIComponentRepository{db: db}
}
// Create 创建UI组件
func (r *GormUIComponentRepository) Create(ctx context.Context, component entities.UIComponent) (entities.UIComponent, error) {
if err := r.db.WithContext(ctx).Create(&component).Error; err != nil {
return entities.UIComponent{}, fmt.Errorf("创建UI组件失败: %w", err)
}
return component, nil
}
// GetByID 根据ID获取UI组件
func (r *GormUIComponentRepository) GetByID(ctx context.Context, id string) (*entities.UIComponent, error) {
var component entities.UIComponent
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&component).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("获取UI组件失败: %w", err)
}
return &component, nil
}
// GetByCode 根据编码获取UI组件
func (r *GormUIComponentRepository) GetByCode(ctx context.Context, code string) (*entities.UIComponent, error) {
var component entities.UIComponent
if err := r.db.WithContext(ctx).Where("component_code = ?", code).First(&component).Error; err != nil {
if err == gorm.ErrRecordNotFound {
return nil, nil
}
return nil, fmt.Errorf("获取UI组件失败: %w", err)
}
return &component, nil
}
// List 获取UI组件列表
func (r *GormUIComponentRepository) List(ctx context.Context, filters map[string]interface{}) ([]entities.UIComponent, int64, error) {
var components []entities.UIComponent
var total int64
query := r.db.WithContext(ctx).Model(&entities.UIComponent{})
// 应用过滤条件
if isActive, ok := filters["is_active"]; ok {
query = query.Where("is_active = ?", isActive)
}
if keyword, ok := filters["keyword"]; ok && keyword != "" {
query = query.Where("component_name LIKE ? OR component_code LIKE ? OR description LIKE ?",
"%"+keyword.(string)+"%", "%"+keyword.(string)+"%", "%"+keyword.(string)+"%")
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, fmt.Errorf("获取UI组件总数失败: %w", err)
}
// 分页
if page, ok := filters["page"]; ok {
if pageSize, ok := filters["page_size"]; ok {
offset := (page.(int) - 1) * pageSize.(int)
query = query.Offset(offset).Limit(pageSize.(int))
}
}
// 排序
if sortBy, ok := filters["sort_by"]; ok {
if sortOrder, ok := filters["sort_order"]; ok {
query = query.Order(fmt.Sprintf("%s %s", sortBy, sortOrder))
}
} else {
query = query.Order("sort_order ASC, created_at DESC")
}
// 获取数据
if err := query.Find(&components).Error; err != nil {
return nil, 0, fmt.Errorf("获取UI组件列表失败: %w", err)
}
return components, total, nil
}
// Update 更新UI组件
func (r *GormUIComponentRepository) Update(ctx context.Context, component entities.UIComponent) error {
if err := r.db.WithContext(ctx).Save(&component).Error; err != nil {
return fmt.Errorf("更新UI组件失败: %w", err)
}
return nil
}
// Delete 删除UI组件
func (r *GormUIComponentRepository) Delete(ctx context.Context, id string) error {
// 记录删除操作的详细信息
if err := r.db.WithContext(ctx).Where("id = ?", id).Delete(&entities.UIComponent{}).Error; err != nil {
return fmt.Errorf("删除UI组件失败: %w", err)
}
return nil
}
// GetByCodes 根据编码列表获取UI组件
func (r *GormUIComponentRepository) GetByCodes(ctx context.Context, codes []string) ([]entities.UIComponent, error) {
var components []entities.UIComponent
if len(codes) == 0 {
return components, nil
}
if err := r.db.WithContext(ctx).Where("component_code IN ?", codes).Find(&components).Error; err != nil {
return nil, fmt.Errorf("根据编码列表获取UI组件失败: %w", err)
}
return components, nil
}

View File

@@ -0,0 +1,461 @@
package statistics
import (
"context"
"fmt"
"gorm.io/gorm"
"hyapi-server/internal/domains/statistics/entities"
"hyapi-server/internal/domains/statistics/repositories"
)
// GormStatisticsDashboardRepository GORM统计仪表板仓储实现
type GormStatisticsDashboardRepository struct {
db *gorm.DB
}
// NewGormStatisticsDashboardRepository 创建GORM统计仪表板仓储
func NewGormStatisticsDashboardRepository(db *gorm.DB) repositories.StatisticsDashboardRepository {
return &GormStatisticsDashboardRepository{
db: db,
}
}
// Save 保存统计仪表板
func (r *GormStatisticsDashboardRepository) Save(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
if dashboard == nil {
return fmt.Errorf("统计仪表板不能为空")
}
// 验证仪表板
if err := dashboard.Validate(); err != nil {
return fmt.Errorf("统计仪表板验证失败: %w", err)
}
// 保存到数据库
result := r.db.WithContext(ctx).Save(dashboard)
if result.Error != nil {
return fmt.Errorf("保存统计仪表板失败: %w", result.Error)
}
return nil
}
// FindByID 根据ID查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsDashboard, error) {
if id == "" {
return nil, fmt.Errorf("仪表板ID不能为空")
}
var dashboard entities.StatisticsDashboard
result := r.db.WithContext(ctx).Where("id = ?", id).First(&dashboard)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("统计仪表板不存在")
}
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return &dashboard, nil
}
// FindByUser 根据用户查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userID == "" {
return nil, fmt.Errorf("用户ID不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("created_by = ?", userID)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindByUserRole 根据用户角色查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByUserRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// Update 更新统计仪表板
func (r *GormStatisticsDashboardRepository) Update(ctx context.Context, dashboard *entities.StatisticsDashboard) error {
if dashboard == nil {
return fmt.Errorf("统计仪表板不能为空")
}
if dashboard.ID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
// 验证仪表板
if err := dashboard.Validate(); err != nil {
return fmt.Errorf("统计仪表板验证失败: %w", err)
}
// 更新数据库
result := r.db.WithContext(ctx).Save(dashboard)
if result.Error != nil {
return fmt.Errorf("更新统计仪表板失败: %w", result.Error)
}
return nil
}
// Delete 删除统计仪表板
func (r *GormStatisticsDashboardRepository) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("仪表板ID不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("删除统计仪表板失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("统计仪表板不存在")
}
return nil
}
// FindByRole 根据角色查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("user_role = ?", userRole)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindDefaultByRole 根据角色查找默认统计仪表板
func (r *GormStatisticsDashboardRepository) FindDefaultByRole(ctx context.Context, userRole string) (*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboard entities.StatisticsDashboard
result := r.db.WithContext(ctx).
Where("user_role = ? AND is_default = ? AND is_active = ?", userRole, true, true).
First(&dashboard)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("默认统计仪表板不存在")
}
return nil, fmt.Errorf("查询默认统计仪表板失败: %w", result.Error)
}
return &dashboard, nil
}
// FindActiveByRole 根据角色查找激活的统计仪表板
func (r *GormStatisticsDashboardRepository) FindActiveByRole(ctx context.Context, userRole string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if userRole == "" {
return nil, fmt.Errorf("用户角色不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).
Where("user_role = ? AND is_active = ?", userRole, true)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询激活统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindByStatus 根据状态查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByStatus(ctx context.Context, isActive bool, limit, offset int) ([]*entities.StatisticsDashboard, error) {
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("is_active = ?", isActive)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// FindByAccessLevel 根据访问级别查找统计仪表板
func (r *GormStatisticsDashboardRepository) FindByAccessLevel(ctx context.Context, accessLevel string, limit, offset int) ([]*entities.StatisticsDashboard, error) {
if accessLevel == "" {
return nil, fmt.Errorf("访问级别不能为空")
}
var dashboards []*entities.StatisticsDashboard
query := r.db.WithContext(ctx).Where("access_level = ?", accessLevel)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&dashboards)
if result.Error != nil {
return nil, fmt.Errorf("查询统计仪表板失败: %w", result.Error)
}
return dashboards, nil
}
// CountByUser 根据用户统计数量
func (r *GormStatisticsDashboardRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
if userID == "" {
return 0, fmt.Errorf("用户ID不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("created_by = ?", userID).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
}
return count, nil
}
// CountByRole 根据角色统计数量
func (r *GormStatisticsDashboardRepository) CountByRole(ctx context.Context, userRole string) (int64, error) {
if userRole == "" {
return 0, fmt.Errorf("用户角色不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("user_role = ?", userRole).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
}
return count, nil
}
// CountByStatus 根据状态统计数量
func (r *GormStatisticsDashboardRepository) CountByStatus(ctx context.Context, isActive bool) (int64, error) {
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("is_active = ?", isActive).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计仪表板数量失败: %w", result.Error)
}
return count, nil
}
// BatchSave 批量保存统计仪表板
func (r *GormStatisticsDashboardRepository) BatchSave(ctx context.Context, dashboards []*entities.StatisticsDashboard) error {
if len(dashboards) == 0 {
return fmt.Errorf("统计仪表板列表不能为空")
}
// 验证所有仪表板
for _, dashboard := range dashboards {
if err := dashboard.Validate(); err != nil {
return fmt.Errorf("统计仪表板验证失败: %w", err)
}
}
// 批量保存
result := r.db.WithContext(ctx).CreateInBatches(dashboards, 100)
if result.Error != nil {
return fmt.Errorf("批量保存统计仪表板失败: %w", result.Error)
}
return nil
}
// BatchDelete 批量删除统计仪表板
func (r *GormStatisticsDashboardRepository) BatchDelete(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return fmt.Errorf("仪表板ID列表不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsDashboard{}, "id IN ?", ids)
if result.Error != nil {
return fmt.Errorf("批量删除统计仪表板失败: %w", result.Error)
}
return nil
}
// SetDefaultDashboard 设置默认仪表板
func (r *GormStatisticsDashboardRepository) SetDefaultDashboard(ctx context.Context, dashboardID string) error {
if dashboardID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
// 开始事务
tx := r.db.WithContext(ctx).Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 先取消同角色的所有默认状态
var dashboard entities.StatisticsDashboard
if err := tx.Where("id = ?", dashboardID).First(&dashboard).Error; err != nil {
tx.Rollback()
return fmt.Errorf("查询仪表板失败: %w", err)
}
// 取消同角色的所有默认状态
if err := tx.Model(&entities.StatisticsDashboard{}).
Where("user_role = ? AND is_default = ?", dashboard.UserRole, true).
Update("is_default", false).Error; err != nil {
tx.Rollback()
return fmt.Errorf("取消默认状态失败: %w", err)
}
// 设置新的默认状态
if err := tx.Model(&entities.StatisticsDashboard{}).
Where("id = ?", dashboardID).
Update("is_default", true).Error; err != nil {
tx.Rollback()
return fmt.Errorf("设置默认状态失败: %w", err)
}
// 提交事务
if err := tx.Commit().Error; err != nil {
return fmt.Errorf("提交事务失败: %w", err)
}
return nil
}
// RemoveDefaultDashboard 移除默认仪表板
func (r *GormStatisticsDashboardRepository) RemoveDefaultDashboard(ctx context.Context, userRole string) error {
if userRole == "" {
return fmt.Errorf("用户角色不能为空")
}
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("user_role = ? AND is_default = ?", userRole, true).
Update("is_default", false)
if result.Error != nil {
return fmt.Errorf("移除默认仪表板失败: %w", result.Error)
}
return nil
}
// ActivateDashboard 激活仪表板
func (r *GormStatisticsDashboardRepository) ActivateDashboard(ctx context.Context, dashboardID string) error {
if dashboardID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("id = ?", dashboardID).
Update("is_active", true)
if result.Error != nil {
return fmt.Errorf("激活仪表板失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("仪表板不存在")
}
return nil
}
// DeactivateDashboard 停用仪表板
func (r *GormStatisticsDashboardRepository) DeactivateDashboard(ctx context.Context, dashboardID string) error {
if dashboardID == "" {
return fmt.Errorf("仪表板ID不能为空")
}
result := r.db.WithContext(ctx).
Model(&entities.StatisticsDashboard{}).
Where("id = ?", dashboardID).
Update("is_active", false)
if result.Error != nil {
return fmt.Errorf("停用仪表板失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("仪表板不存在")
}
return nil
}

View File

@@ -0,0 +1,377 @@
package statistics
import (
"context"
"fmt"
"time"
"gorm.io/gorm"
"hyapi-server/internal/domains/statistics/entities"
"hyapi-server/internal/domains/statistics/repositories"
)
// GormStatisticsReportRepository GORM统计报告仓储实现
type GormStatisticsReportRepository struct {
db *gorm.DB
}
// NewGormStatisticsReportRepository 创建GORM统计报告仓储
func NewGormStatisticsReportRepository(db *gorm.DB) repositories.StatisticsReportRepository {
return &GormStatisticsReportRepository{
db: db,
}
}
// Save 保存统计报告
func (r *GormStatisticsReportRepository) Save(ctx context.Context, report *entities.StatisticsReport) error {
if report == nil {
return fmt.Errorf("统计报告不能为空")
}
// 验证报告
if err := report.Validate(); err != nil {
return fmt.Errorf("统计报告验证失败: %w", err)
}
// 保存到数据库
result := r.db.WithContext(ctx).Save(report)
if result.Error != nil {
return fmt.Errorf("保存统计报告失败: %w", result.Error)
}
return nil
}
// FindByID 根据ID查找统计报告
func (r *GormStatisticsReportRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsReport, error) {
if id == "" {
return nil, fmt.Errorf("报告ID不能为空")
}
var report entities.StatisticsReport
result := r.db.WithContext(ctx).Where("id = ?", id).First(&report)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("统计报告不存在")
}
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return &report, nil
}
// FindByUser 根据用户查找统计报告
func (r *GormStatisticsReportRepository) FindByUser(ctx context.Context, userID string, limit, offset int) ([]*entities.StatisticsReport, error) {
if userID == "" {
return nil, fmt.Errorf("用户ID不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).Where("generated_by = ?", userID)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByStatus 根据状态查找统计报告
func (r *GormStatisticsReportRepository) FindByStatus(ctx context.Context, status string) ([]*entities.StatisticsReport, error) {
if status == "" {
return nil, fmt.Errorf("报告状态不能为空")
}
var reports []*entities.StatisticsReport
result := r.db.WithContext(ctx).
Where("status = ?", status).
Order("created_at DESC").
Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// Update 更新统计报告
func (r *GormStatisticsReportRepository) Update(ctx context.Context, report *entities.StatisticsReport) error {
if report == nil {
return fmt.Errorf("统计报告不能为空")
}
if report.ID == "" {
return fmt.Errorf("报告ID不能为空")
}
// 验证报告
if err := report.Validate(); err != nil {
return fmt.Errorf("统计报告验证失败: %w", err)
}
// 更新数据库
result := r.db.WithContext(ctx).Save(report)
if result.Error != nil {
return fmt.Errorf("更新统计报告失败: %w", result.Error)
}
return nil
}
// Delete 删除统计报告
func (r *GormStatisticsReportRepository) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("报告ID不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("删除统计报告失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("统计报告不存在")
}
return nil
}
// FindByType 根据类型查找统计报告
func (r *GormStatisticsReportRepository) FindByType(ctx context.Context, reportType string, limit, offset int) ([]*entities.StatisticsReport, error) {
if reportType == "" {
return nil, fmt.Errorf("报告类型不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).Where("report_type = ?", reportType)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByTypeAndPeriod 根据类型和周期查找统计报告
func (r *GormStatisticsReportRepository) FindByTypeAndPeriod(ctx context.Context, reportType, period string, limit, offset int) ([]*entities.StatisticsReport, error) {
if reportType == "" {
return nil, fmt.Errorf("报告类型不能为空")
}
if period == "" {
return nil, fmt.Errorf("统计周期不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).
Where("report_type = ? AND period = ?", reportType, period)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByDateRange 根据日期范围查找统计报告
func (r *GormStatisticsReportRepository) FindByDateRange(ctx context.Context, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).
Where("created_at >= ? AND created_at < ?", startDate, endDate)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// FindByUserAndDateRange 根据用户和日期范围查找统计报告
func (r *GormStatisticsReportRepository) FindByUserAndDateRange(ctx context.Context, userID string, startDate, endDate time.Time, limit, offset int) ([]*entities.StatisticsReport, error) {
if userID == "" {
return nil, fmt.Errorf("用户ID不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var reports []*entities.StatisticsReport
query := r.db.WithContext(ctx).
Where("generated_by = ? AND created_at >= ? AND created_at < ?", userID, startDate, endDate)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&reports)
if result.Error != nil {
return nil, fmt.Errorf("查询统计报告失败: %w", result.Error)
}
return reports, nil
}
// CountByUser 根据用户统计数量
func (r *GormStatisticsReportRepository) CountByUser(ctx context.Context, userID string) (int64, error) {
if userID == "" {
return 0, fmt.Errorf("用户ID不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsReport{}).
Where("generated_by = ?", userID).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
}
return count, nil
}
// CountByType 根据类型统计数量
func (r *GormStatisticsReportRepository) CountByType(ctx context.Context, reportType string) (int64, error) {
if reportType == "" {
return 0, fmt.Errorf("报告类型不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsReport{}).
Where("report_type = ?", reportType).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
}
return count, nil
}
// CountByStatus 根据状态统计数量
func (r *GormStatisticsReportRepository) CountByStatus(ctx context.Context, status string) (int64, error) {
if status == "" {
return 0, fmt.Errorf("报告状态不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsReport{}).
Where("status = ?", status).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计报告数量失败: %w", result.Error)
}
return count, nil
}
// BatchSave 批量保存统计报告
func (r *GormStatisticsReportRepository) BatchSave(ctx context.Context, reports []*entities.StatisticsReport) error {
if len(reports) == 0 {
return fmt.Errorf("统计报告列表不能为空")
}
// 验证所有报告
for _, report := range reports {
if err := report.Validate(); err != nil {
return fmt.Errorf("统计报告验证失败: %w", err)
}
}
// 批量保存
result := r.db.WithContext(ctx).CreateInBatches(reports, 100)
if result.Error != nil {
return fmt.Errorf("批量保存统计报告失败: %w", result.Error)
}
return nil
}
// BatchDelete 批量删除统计报告
func (r *GormStatisticsReportRepository) BatchDelete(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return fmt.Errorf("报告ID列表不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsReport{}, "id IN ?", ids)
if result.Error != nil {
return fmt.Errorf("批量删除统计报告失败: %w", result.Error)
}
return nil
}
// DeleteExpiredReports 删除过期报告
func (r *GormStatisticsReportRepository) DeleteExpiredReports(ctx context.Context, expiredBefore time.Time) error {
if expiredBefore.IsZero() {
return fmt.Errorf("过期时间不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsReport{}, "expires_at IS NOT NULL AND expires_at < ?", expiredBefore)
if result.Error != nil {
return fmt.Errorf("删除过期报告失败: %w", result.Error)
}
return nil
}
// DeleteByStatus 根据状态删除统计报告
func (r *GormStatisticsReportRepository) DeleteByStatus(ctx context.Context, status string) error {
if status == "" {
return fmt.Errorf("报告状态不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsReport{}, "status = ?", status)
if result.Error != nil {
return fmt.Errorf("根据状态删除统计报告失败: %w", result.Error)
}
return nil
}

View File

@@ -0,0 +1,381 @@
package statistics
import (
"context"
"fmt"
"time"
"gorm.io/gorm"
"hyapi-server/internal/domains/statistics/entities"
"hyapi-server/internal/domains/statistics/repositories"
)
// GormStatisticsRepository GORM统计指标仓储实现
type GormStatisticsRepository struct {
db *gorm.DB
}
// NewGormStatisticsRepository 创建GORM统计指标仓储
func NewGormStatisticsRepository(db *gorm.DB) repositories.StatisticsRepository {
return &GormStatisticsRepository{
db: db,
}
}
// Save 保存统计指标
func (r *GormStatisticsRepository) Save(ctx context.Context, metric *entities.StatisticsMetric) error {
if metric == nil {
return fmt.Errorf("统计指标不能为空")
}
// 验证指标
if err := metric.Validate(); err != nil {
return fmt.Errorf("统计指标验证失败: %w", err)
}
// 保存到数据库
result := r.db.WithContext(ctx).Create(metric)
if result.Error != nil {
return fmt.Errorf("保存统计指标失败: %w", result.Error)
}
return nil
}
// FindByID 根据ID查找统计指标
func (r *GormStatisticsRepository) FindByID(ctx context.Context, id string) (*entities.StatisticsMetric, error) {
if id == "" {
return nil, fmt.Errorf("指标ID不能为空")
}
var metric entities.StatisticsMetric
result := r.db.WithContext(ctx).Where("id = ?", id).First(&metric)
if result.Error != nil {
if result.Error == gorm.ErrRecordNotFound {
return nil, fmt.Errorf("统计指标不存在")
}
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return &metric, nil
}
// FindByType 根据类型查找统计指标
func (r *GormStatisticsRepository) FindByType(ctx context.Context, metricType string, limit, offset int) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
var metrics []*entities.StatisticsMetric
query := r.db.WithContext(ctx).Where("metric_type = ?", metricType)
if limit > 0 {
query = query.Limit(limit)
}
if offset > 0 {
query = query.Offset(offset)
}
result := query.Order("created_at DESC").Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// Update 更新统计指标
func (r *GormStatisticsRepository) Update(ctx context.Context, metric *entities.StatisticsMetric) error {
if metric == nil {
return fmt.Errorf("统计指标不能为空")
}
if metric.ID == "" {
return fmt.Errorf("指标ID不能为空")
}
// 验证指标
if err := metric.Validate(); err != nil {
return fmt.Errorf("统计指标验证失败: %w", err)
}
// 更新数据库
result := r.db.WithContext(ctx).Save(metric)
if result.Error != nil {
return fmt.Errorf("更新统计指标失败: %w", result.Error)
}
return nil
}
// Delete 删除统计指标
func (r *GormStatisticsRepository) Delete(ctx context.Context, id string) error {
if id == "" {
return fmt.Errorf("指标ID不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id = ?", id)
if result.Error != nil {
return fmt.Errorf("删除统计指标失败: %w", result.Error)
}
if result.RowsAffected == 0 {
return fmt.Errorf("统计指标不存在")
}
return nil
}
// FindByTypeAndDateRange 根据类型和日期范围查找统计指标
func (r *GormStatisticsRepository) FindByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
result := r.db.WithContext(ctx).
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
Order("date ASC").
Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// FindByTypeDimensionAndDateRange 根据类型、维度和日期范围查找统计指标
func (r *GormStatisticsRepository) FindByTypeDimensionAndDateRange(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
query := r.db.WithContext(ctx).
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate)
if dimension != "" {
query = query.Where("dimension = ?", dimension)
}
result := query.Order("date ASC").Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// FindByTypeNameAndDateRange 根据类型、名称和日期范围查找统计指标
func (r *GormStatisticsRepository) FindByTypeNameAndDateRange(ctx context.Context, metricType, metricName string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if metricName == "" {
return nil, fmt.Errorf("指标名称不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
result := r.db.WithContext(ctx).
Where("metric_type = ? AND metric_name = ? AND date >= ? AND date < ?",
metricType, metricName, startDate, endDate).
Order("date ASC").
Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// GetAggregatedMetrics 获取聚合指标
func (r *GormStatisticsRepository) GetAggregatedMetrics(ctx context.Context, metricType, dimension string, startDate, endDate time.Time) (map[string]float64, error) {
if metricType == "" {
return nil, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
type AggregatedResult struct {
MetricName string `json:"metric_name"`
TotalValue float64 `json:"total_value"`
}
var results []AggregatedResult
query := r.db.WithContext(ctx).
Model(&entities.StatisticsMetric{}).
Select("metric_name, SUM(value) as total_value").
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
Group("metric_name")
if dimension != "" {
query = query.Where("dimension = ?", dimension)
}
result := query.Find(&results)
if result.Error != nil {
return nil, fmt.Errorf("查询聚合指标失败: %w", result.Error)
}
// 转换为map
aggregated := make(map[string]float64)
for _, res := range results {
aggregated[res.MetricName] = res.TotalValue
}
return aggregated, nil
}
// GetMetricsByDimension 根据维度获取指标
func (r *GormStatisticsRepository) GetMetricsByDimension(ctx context.Context, dimension string, startDate, endDate time.Time) ([]*entities.StatisticsMetric, error) {
if dimension == "" {
return nil, fmt.Errorf("统计维度不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return nil, fmt.Errorf("开始日期和结束日期不能为空")
}
var metrics []*entities.StatisticsMetric
result := r.db.WithContext(ctx).
Where("dimension = ? AND date >= ? AND date < ?", dimension, startDate, endDate).
Order("date ASC").
Find(&metrics)
if result.Error != nil {
return nil, fmt.Errorf("查询统计指标失败: %w", result.Error)
}
return metrics, nil
}
// CountByType 根据类型统计数量
func (r *GormStatisticsRepository) CountByType(ctx context.Context, metricType string) (int64, error) {
if metricType == "" {
return 0, fmt.Errorf("指标类型不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsMetric{}).
Where("metric_type = ?", metricType).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
}
return count, nil
}
// CountByTypeAndDateRange 根据类型和日期范围统计数量
func (r *GormStatisticsRepository) CountByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) (int64, error) {
if metricType == "" {
return 0, fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return 0, fmt.Errorf("开始日期和结束日期不能为空")
}
var count int64
result := r.db.WithContext(ctx).
Model(&entities.StatisticsMetric{}).
Where("metric_type = ? AND date >= ? AND date < ?", metricType, startDate, endDate).
Count(&count)
if result.Error != nil {
return 0, fmt.Errorf("统计指标数量失败: %w", result.Error)
}
return count, nil
}
// BatchSave 批量保存统计指标
func (r *GormStatisticsRepository) BatchSave(ctx context.Context, metrics []*entities.StatisticsMetric) error {
if len(metrics) == 0 {
return fmt.Errorf("统计指标列表不能为空")
}
// 验证所有指标
for _, metric := range metrics {
if err := metric.Validate(); err != nil {
return fmt.Errorf("统计指标验证失败: %w", err)
}
}
// 批量保存
result := r.db.WithContext(ctx).CreateInBatches(metrics, 100)
if result.Error != nil {
return fmt.Errorf("批量保存统计指标失败: %w", result.Error)
}
return nil
}
// BatchDelete 批量删除统计指标
func (r *GormStatisticsRepository) BatchDelete(ctx context.Context, ids []string) error {
if len(ids) == 0 {
return fmt.Errorf("指标ID列表不能为空")
}
result := r.db.WithContext(ctx).Delete(&entities.StatisticsMetric{}, "id IN ?", ids)
if result.Error != nil {
return fmt.Errorf("批量删除统计指标失败: %w", result.Error)
}
return nil
}
// DeleteByDateRange 根据日期范围删除统计指标
func (r *GormStatisticsRepository) DeleteByDateRange(ctx context.Context, startDate, endDate time.Time) error {
if startDate.IsZero() || endDate.IsZero() {
return fmt.Errorf("开始日期和结束日期不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsMetric{}, "date >= ? AND date < ?", startDate, endDate)
if result.Error != nil {
return fmt.Errorf("根据日期范围删除统计指标失败: %w", result.Error)
}
return nil
}
// DeleteByTypeAndDateRange 根据类型和日期范围删除统计指标
func (r *GormStatisticsRepository) DeleteByTypeAndDateRange(ctx context.Context, metricType string, startDate, endDate time.Time) error {
if metricType == "" {
return fmt.Errorf("指标类型不能为空")
}
if startDate.IsZero() || endDate.IsZero() {
return fmt.Errorf("开始日期和结束日期不能为空")
}
result := r.db.WithContext(ctx).
Delete(&entities.StatisticsMetric{}, "metric_type = ? AND date >= ? AND date < ?",
metricType, startDate, endDate)
if result.Error != nil {
return fmt.Errorf("根据类型和日期范围删除统计指标失败: %w", result.Error)
}
return nil
}

View File

@@ -0,0 +1,101 @@
// internal/infrastructure/database/repositories/user/gorm_contract_info_repository.go
package repositories
import (
"context"
"errors"
"hyapi-server/internal/domains/user/entities"
"hyapi-server/internal/domains/user/repositories"
"hyapi-server/internal/shared/database"
"go.uber.org/zap"
"gorm.io/gorm"
)
const (
ContractInfosTable = "contract_infos"
)
type GormContractInfoRepository struct {
*database.CachedBaseRepositoryImpl
}
func NewGormContractInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractInfoRepository {
return &GormContractInfoRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, ContractInfosTable),
}
}
func (r *GormContractInfoRepository) Save(ctx context.Context, contract *entities.ContractInfo) error {
return r.CreateEntity(ctx, contract)
}
func (r *GormContractInfoRepository) FindByID(ctx context.Context, contractID string) (*entities.ContractInfo, error) {
var contract entities.ContractInfo
err := r.SmartGetByID(ctx, contractID, &contract)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, nil
}
return nil, err
}
return &contract, nil
}
func (r *GormContractInfoRepository) Delete(ctx context.Context, contractID string) error {
return r.DeleteEntity(ctx, contractID, &entities.ContractInfo{})
}
func (r *GormContractInfoRepository) FindByEnterpriseInfoID(ctx context.Context, enterpriseInfoID string) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("enterprise_info_id = ?", enterpriseInfoID).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) FindByUserID(ctx context.Context, userID string) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("user_id = ?", userID).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) FindByContractType(ctx context.Context, enterpriseInfoID string, contractType entities.ContractType) ([]*entities.ContractInfo, error) {
var contracts []entities.ContractInfo
err := r.GetDB(ctx).Where("enterprise_info_id = ? AND contract_type = ?", enterpriseInfoID, contractType).Find(&contracts).Error
if err != nil {
return nil, err
}
result := make([]*entities.ContractInfo, len(contracts))
for i := range contracts {
result[i] = &contracts[i]
}
return result, nil
}
func (r *GormContractInfoRepository) ExistsByContractFileID(ctx context.Context, contractFileID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ?", contractFileID).Count(&count).Error
return count > 0, err
}
func (r *GormContractInfoRepository) ExistsByContractFileIDExcludeID(ctx context.Context, contractFileID, excludeID string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.ContractInfo{}).Where("contract_file_id = ? AND id != ?", contractFileID, excludeID).Count(&count).Error
return count > 0, err
}

View File

@@ -0,0 +1,272 @@
package repositories
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"hyapi-server/internal/domains/user/entities"
"hyapi-server/internal/domains/user/repositories"
"hyapi-server/internal/shared/interfaces"
)
// GormEnterpriseInfoRepository 企业信息GORM仓储实现
type GormEnterpriseInfoRepository struct {
db *gorm.DB
logger *zap.Logger
}
// NewGormEnterpriseInfoRepository 创建企业信息GORM仓储
func NewGormEnterpriseInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.EnterpriseInfoRepository {
return &GormEnterpriseInfoRepository{
db: db,
logger: logger,
}
}
// Create 创建企业信息
func (r *GormEnterpriseInfoRepository) Create(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) (entities.EnterpriseInfo, error) {
if err := r.db.WithContext(ctx).Create(&enterpriseInfo).Error; err != nil {
r.logger.Error("创建企业信息失败", zap.Error(err))
return entities.EnterpriseInfo{}, fmt.Errorf("创建企业信息失败: %w", err)
}
return enterpriseInfo, nil
}
// GetByID 根据ID获取企业信息
func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&enterpriseInfo).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.EnterpriseInfo{}, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
return entities.EnterpriseInfo{}, fmt.Errorf("获取企业信息失败: %w", err)
}
return enterpriseInfo, nil
}
// Update 更新企业信息
func (r *GormEnterpriseInfoRepository) Update(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) error {
if err := r.db.WithContext(ctx).Save(&enterpriseInfo).Error; err != nil {
r.logger.Error("更新企业信息失败", zap.Error(err))
return fmt.Errorf("更新企业信息失败: %w", err)
}
return nil
}
// Delete 删除企业信息
func (r *GormEnterpriseInfoRepository) Delete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
r.logger.Error("删除企业信息失败", zap.Error(err))
return fmt.Errorf("删除企业信息失败: %w", err)
}
return nil
}
// SoftDelete 软删除企业信息
func (r *GormEnterpriseInfoRepository) SoftDelete(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
r.logger.Error("软删除企业信息失败", zap.Error(err))
return fmt.Errorf("软删除企业信息失败: %w", err)
}
return nil
}
// Restore 恢复软删除的企业信息
func (r *GormEnterpriseInfoRepository) Restore(ctx context.Context, id string) error {
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
r.logger.Error("恢复企业信息失败", zap.Error(err))
return fmt.Errorf("恢复企业信息失败: %w", err)
}
return nil
}
// GetByUserID 根据用户ID获取企业信息
func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&enterpriseInfo).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
return nil, fmt.Errorf("获取企业信息失败: %w", err)
}
return &enterpriseInfo, nil
}
// GetByUnifiedSocialCode 根据统一社会信用代码获取企业信息
func (r *GormEnterpriseInfoRepository) GetByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string) (*entities.EnterpriseInfo, error) {
var enterpriseInfo entities.EnterpriseInfo
if err := r.db.WithContext(ctx).Where("unified_social_code = ?", unifiedSocialCode).First(&enterpriseInfo).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("企业信息不存在")
}
r.logger.Error("获取企业信息失败", zap.Error(err))
return nil, fmt.Errorf("获取企业信息失败: %w", err)
}
return &enterpriseInfo, nil
}
// CheckUnifiedSocialCodeExists 检查统一社会信用代码是否已存在
func (r *GormEnterpriseInfoRepository) CheckUnifiedSocialCodeExists(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("unified_social_code = ?", unifiedSocialCode)
if excludeUserID != "" {
query = query.Where("user_id != ?", excludeUserID)
}
if err := query.Count(&count).Error; err != nil {
r.logger.Error("检查统一社会信用代码失败", zap.Error(err))
return false, fmt.Errorf("检查统一社会信用代码失败: %w", err)
}
return count > 0, nil
}
// UpdateVerificationStatus 更新验证状态
func (r *GormEnterpriseInfoRepository) UpdateVerificationStatus(ctx context.Context, userID string, isOCRVerified, isFaceVerified, isCertified bool) error {
updates := map[string]interface{}{
"is_ocr_verified": isOCRVerified,
"is_face_verified": isFaceVerified,
"is_certified": isCertified,
}
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
r.logger.Error("更新验证状态失败", zap.Error(err))
return fmt.Errorf("更新验证状态失败: %w", err)
}
return nil
}
// UpdateOCRData 更新OCR数据
func (r *GormEnterpriseInfoRepository) UpdateOCRData(ctx context.Context, userID string, rawData string, confidence float64) error {
updates := map[string]interface{}{
"ocr_raw_data": rawData,
"ocr_confidence": confidence,
"is_ocr_verified": true,
}
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
r.logger.Error("更新OCR数据失败", zap.Error(err))
return fmt.Errorf("更新OCR数据失败: %w", err)
}
return nil
}
// CompleteCertification 完成认证
func (r *GormEnterpriseInfoRepository) CompleteCertification(ctx context.Context, userID string) error {
now := time.Now()
updates := map[string]interface{}{
"is_certified": true,
"certified_at": &now,
}
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
r.logger.Error("完成认证失败", zap.Error(err))
return fmt.Errorf("完成认证失败: %w", err)
}
return nil
}
// Count 统计企业信息数量
func (r *GormEnterpriseInfoRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// Exists 检查企业信息是否存在
func (r *GormEnterpriseInfoRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// CreateBatch 批量创建企业信息
func (r *GormEnterpriseInfoRepository) CreateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
return r.db.WithContext(ctx).Create(&enterpriseInfos).Error
}
// GetByIDs 根据ID列表获取企业信息
func (r *GormEnterpriseInfoRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.EnterpriseInfo, error) {
var enterpriseInfos []entities.EnterpriseInfo
err := r.db.WithContext(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&enterpriseInfos).Error
return enterpriseInfos, err
}
// UpdateBatch 批量更新企业信息
func (r *GormEnterpriseInfoRepository) UpdateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
return r.db.WithContext(ctx).Save(&enterpriseInfos).Error
}
// DeleteBatch 批量删除企业信息
func (r *GormEnterpriseInfoRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id IN ?", ids).Error
}
// List 获取企业信息列表
func (r *GormEnterpriseInfoRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.EnterpriseInfo, error) {
var enterpriseInfos []entities.EnterpriseInfo
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
if options.Search != "" {
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
}
if options.Sort != "" {
order := "ASC"
if options.Order != "" {
order = options.Order
}
query = query.Order(options.Sort + " " + order)
} else {
// 默认按创建时间倒序
query = query.Order("created_at DESC")
}
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
err := query.Find(&enterpriseInfos).Error
return enterpriseInfos, err
}
// WithTx 使用事务
func (r *GormEnterpriseInfoRepository) WithTx(tx interface{}) interfaces.Repository[entities.EnterpriseInfo] {
if gormTx, ok := tx.(*gorm.DB); ok {
return &GormEnterpriseInfoRepository{
db: gormTx,
logger: r.logger,
}
}
return r
}

View File

@@ -0,0 +1,374 @@
//go:build !test
// +build !test
package repositories
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"hyapi-server/internal/domains/user/entities"
"hyapi-server/internal/domains/user/repositories"
"hyapi-server/internal/domains/user/repositories/queries"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
)
const (
SMSCodesTable = "sms_codes"
)
// GormSMSCodeRepository 短信验证码GORM仓储实现无缓存确保安全性
type GormSMSCodeRepository struct {
*database.CachedBaseRepositoryImpl
}
// NewGormSMSCodeRepository 创建短信验证码仓储
func NewGormSMSCodeRepository(db *gorm.DB, logger *zap.Logger) repositories.SMSCodeRepository {
return &GormSMSCodeRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, SMSCodesTable),
}
}
// 确保 GormSMSCodeRepository 实现了 SMSCodeRepository 接口
var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
// ================ Repository[T] 接口实现 ================
// Create 创建短信验证码记录(不缓存,确保安全性)
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
err := r.GetDB(ctx).Create(&smsCode).Error
return smsCode, err
}
// GetByID 根据ID获取短信验证码
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
var smsCode entities.SMSCode
err := r.GetDB(ctx).Where("id = ?", id).First(&smsCode).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
}
return entities.SMSCode{}, err
}
return smsCode, nil
}
// Update 更新验证码记录
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
return r.GetDB(ctx).Save(&smsCode).Error
}
// CreateBatch 批量创建短信验证码
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
return r.GetDB(ctx).Create(&smsCodes).Error
}
// GetByIDs 根据ID列表获取短信验证码
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&smsCodes).Error
return smsCodes, err
}
// UpdateBatch 批量更新短信验证码
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
return r.GetDB(ctx).Save(&smsCodes).Error
}
// DeleteBatch 批量删除短信验证码
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
}
// List 获取短信验证码列表
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
var smsCodes []entities.SMSCode
query := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
}
// 应用预加载
for _, include := range options.Include {
query = query.Preload(include)
}
// 应用排序
if options.Sort != "" {
order := "ASC"
if options.Order == "desc" || options.Order == "DESC" {
order = "DESC"
}
query = query.Order(options.Sort + " " + order)
} else {
query = query.Order("created_at DESC")
}
// 应用分页
if options.Page > 0 && options.PageSize > 0 {
offset := (options.Page - 1) * options.PageSize
query = query.Offset(offset).Limit(options.PageSize)
}
return smsCodes, query.Find(&smsCodes).Error
}
// ================ BaseRepository 接口实现 ================
// Delete 删除短信验证码
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// Exists 检查短信验证码是否存在
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
return count > 0, err
}
// Count 统计短信验证码数量
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if options.Filters != nil {
for key, value := range options.Filters {
query = query.Where(key+" = ?", value)
}
}
// 应用搜索条件
if options.Search != "" {
query = query.Where("phone LIKE ?", "%"+options.Search+"%")
}
err := query.Count(&count).Error
return count, err
}
// SoftDelete 软删除短信验证码
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
return r.GetDB(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error
}
// Restore 恢复短信验证码
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
return r.GetDB(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
// GetByPhone 根据手机号获取短信验证码
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
return nil, err
}
return &smsCode, nil
}
// GetLatestByPhone 根据手机号获取最新短信验证码
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.GetDB(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("短信验证码不存在")
}
return nil, err
}
return &smsCode, nil
}
// GetValidByPhone 根据手机号获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.GetDB(ctx).
Where("phone = ? AND expires_at > ? AND used_at IS NULL", phone, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("有效的短信验证码不存在")
}
return nil, err
}
return &smsCode, nil
}
// GetValidByPhoneAndScene 根据手机号和场景获取有效的短信验证码
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
var smsCode entities.SMSCode
if err := r.GetDB(ctx).
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, scene, time.Now()).
Order("created_at DESC").
First(&smsCode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, fmt.Errorf("有效的短信验证码不存在")
}
return nil, err
}
return &smsCode, nil
}
// ListSMSCodes 获取短信验证码列表(带分页和筛选)
func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries.ListSMSCodesQuery) ([]*entities.SMSCode, int64, error) {
var smsCodes []*entities.SMSCode
var total int64
// 构建查询条件
db := r.GetDB(ctx).Model(&entities.SMSCode{})
// 应用筛选条件
if query.Phone != "" {
db = db.Where("phone = ?", query.Phone)
}
if query.Purpose != "" {
db = db.Where("scene = ?", query.Purpose)
}
if query.Status != "" {
db = db.Where("used = ?", query.Status == "used")
}
if query.StartDate != "" {
db = db.Where("created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
db = db.Where("created_at <= ?", query.EndDate)
}
// 统计总数
if err := db.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用分页
offset := (query.Page - 1) * query.PageSize
if err := db.Offset(offset).Limit(query.PageSize).Order("created_at DESC").Find(&smsCodes).Error; err != nil {
return nil, 0, err
}
return smsCodes, total, nil
}
// CreateCode 创建验证码
func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, code string, purpose string) (entities.SMSCode, error) {
smsCode := entities.SMSCode{
Phone: phone,
Code: code,
Scene: entities.SMSScene(purpose), // 使用Scene字段
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟有效期
}
if err := r.GetDB(ctx).Create(&smsCode).Error; err != nil {
r.GetLogger().Error("创建短信验证码失败", zap.Error(err))
return entities.SMSCode{}, err
}
return smsCode, nil
}
// ValidateCode 验证验证码
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL", phone, code, purpose, time.Now()).
Count(&count).Error
return count > 0, err
}
// InvalidateCode 使验证码失效
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
now := time.Now()
return r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND used_at IS NULL", phone).
Update("used_at", &now).Error
}
// CheckSendFrequency 检查发送频率
func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone string, purpose string) (bool, error) {
// 检查1分钟内是否已发送
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
var count int64
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
Count(&count).Error
// 如果1分钟内已发送则返回false不允许发送
return count == 0, err
}
// GetTodaySendCount 获取今日发送数量
func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone string) (int64, error) {
today := time.Now().Truncate(24 * time.Hour)
var count int64
err := r.GetDB(ctx).Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&count).Error
return count, err
}
// GetCodeStats 获取验证码统计
func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string, days int) (*repositories.SMSCodeStats, error) {
var stats repositories.SMSCodeStats
// 计算指定天数前的日期
startDate := time.Now().AddDate(0, 0, -days)
// 总发送数
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, startDate).
Count(&stats.TotalSent).Error; err != nil {
return nil, err
}
// 总验证数
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
Count(&stats.TotalValidated).Error; err != nil {
return nil, err
}
// 成功率
if stats.TotalSent > 0 {
stats.SuccessRate = float64(stats.TotalValidated) / float64(stats.TotalSent) * 100
}
// 今日发送数
today := time.Now().Truncate(24 * time.Hour)
if err := r.GetDB(ctx).
Model(&entities.SMSCode{}).
Where("phone = ? AND created_at >= ?", phone, today).
Count(&stats.TodaySent).Error; err != nil {
return nil, err
}
return &stats, nil
}

View File

@@ -0,0 +1,720 @@
//go:build !test
// +build !test
package repositories
import (
"context"
"errors"
"fmt"
"time"
"go.uber.org/zap"
"gorm.io/gorm"
"hyapi-server/internal/domains/user/entities"
"hyapi-server/internal/domains/user/repositories"
"hyapi-server/internal/domains/user/repositories/queries"
"hyapi-server/internal/shared/database"
"hyapi-server/internal/shared/interfaces"
)
const (
UsersTable = "users"
UserCacheTTL = 30 * 60 // 30分钟
)
// 定义错误常量
var (
// ErrUserNotFound 用户不存在错误
ErrUserNotFound = errors.New("用户不存在")
)
type GormUserRepository struct {
*database.CachedBaseRepositoryImpl
}
var _ repositories.UserRepository = (*GormUserRepository)(nil)
func NewGormUserRepository(db *gorm.DB, logger *zap.Logger) repositories.UserRepository {
return &GormUserRepository{
CachedBaseRepositoryImpl: database.NewCachedBaseRepositoryImpl(db, logger, UsersTable),
}
}
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
err := r.CreateEntity(ctx, &user)
return user, err
}
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
var user entities.User
err := r.SmartGetByID(ctx, id, &user)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.User{}, errors.New("用户不存在")
}
return entities.User{}, err
}
return user, nil
}
func (r *GormUserRepository) GetByIDWithEnterpriseInfo(ctx context.Context, id string) (entities.User, error) {
var user entities.User
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id = ?", id).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return entities.User{}, ErrUserNotFound
}
r.GetLogger().Error("根据ID查询用户失败", zap.Error(err))
return entities.User{}, err
}
return user, nil
}
func (r *GormUserRepository) BatchGetByIDsWithEnterpriseInfo(ctx context.Context, ids []string) ([]*entities.User, error) {
if len(ids) == 0 {
return []*entities.User{}, nil
}
var users []*entities.User
if err := r.GetDB(ctx).Preload("EnterpriseInfo").Where("id IN ?", ids).Find(&users).Error; err != nil {
r.GetLogger().Error("批量查询用户失败", zap.Error(err), zap.Strings("ids", ids))
return nil, err
}
return users, nil
}
func (r *GormUserRepository) ExistsByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
var count int64
query := r.GetDB(ctx).Model(&entities.User{}).
Joins("JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
Where("enterprise_infos.unified_social_code = ?", unifiedSocialCode)
// 如果指定了排除的用户ID则排除该用户的记录
if excludeUserID != "" {
query = query.Where("users.id != ?", excludeUserID)
}
err := query.Count(&count).Error
if err != nil {
r.GetLogger().Error("检查统一社会信用代码是否存在失败", zap.Error(err))
return false, err
}
return count > 0, nil
}
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
return r.UpdateEntity(ctx, &user)
}
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
r.GetLogger().Info("批量创建用户", zap.Int("count", len(users)))
return r.GetDB(ctx).Create(&users).Error
}
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
var users []entities.User
err := r.GetDB(ctx).Where("id IN ?", ids).Order("created_at DESC").Find(&users).Error
return users, err
}
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
r.GetLogger().Info("批量更新用户", zap.Int("count", len(users)))
return r.GetDB(ctx).Save(&users).Error
}
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
r.GetLogger().Info("批量删除用户", zap.Strings("ids", ids))
return r.GetDB(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
}
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
var users []entities.User
err := r.SmartList(ctx, &users, options)
return users, err
}
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
return r.DeleteEntity(ctx, id, &entities.User{})
}
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
return r.ExistsEntity(ctx, id, &entities.User{})
}
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.User{}).Count(&count).Error
return count, err
}
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
return r.GetDB(ctx).Delete(&entities.User{}, "id = ?", id).Error
}
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
return r.GetDB(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error
}
// ================ 业务专用方法 ================
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
var user entities.User
if err := r.GetDB(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.GetLogger().Error("根据手机号查询用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
func (r *GormUserRepository) GetByUsername(ctx context.Context, username string) (*entities.User, error) {
var user entities.User
if err := r.GetDB(ctx).Where("username = ?", username).First(&user).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrUserNotFound
}
r.GetLogger().Error("根据用户名查询用户失败", zap.Error(err))
return nil, err
}
return &user, nil
}
func (r *GormUserRepository) GetByUserType(ctx context.Context, userType string) ([]*entities.User, error) {
var users []*entities.User
err := r.GetDB(ctx).Where("user_type = ?", userType).Order("created_at DESC").Find(&users).Error
return users, err
}
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
var users []*entities.User
var total int64
// 构建查询条件,预加载企业信息
db := r.GetDB(ctx).Model(&entities.User{}).Preload("EnterpriseInfo")
// 应用筛选条件
if query.Phone != "" {
db = db.Where("users.phone LIKE ?", "%"+query.Phone+"%")
}
if query.UserType != "" {
db = db.Where("users.user_type = ?", query.UserType)
}
if query.IsActive != nil {
db = db.Where("users.active = ?", *query.IsActive)
}
if query.IsCertified != nil {
db = db.Where("users.is_certified = ?", *query.IsCertified)
}
if query.CompanyName != "" {
db = db.Joins("LEFT JOIN enterprise_infos ON users.id = enterprise_infos.user_id").
Where("enterprise_infos.company_name LIKE ?", "%"+query.CompanyName+"%")
}
if query.StartDate != "" {
db = db.Where("users.created_at >= ?", query.StartDate)
}
if query.EndDate != "" {
db = db.Where("users.created_at <= ?", query.EndDate)
}
// 统计总数
if err := db.Count(&total).Error; err != nil {
return nil, 0, err
}
// 应用排序(默认按创建时间倒序)
db = db.Order("users.created_at DESC")
// 应用分页
offset := (query.Page - 1) * query.PageSize
if err := db.Offset(offset).Limit(query.PageSize).Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
var user entities.User
err := r.GetDB(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
now := time.Now()
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"last_login_at": &now,
"updated_at": now,
}).Error
}
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("password", newPassword).Error
}
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
var count int64
err := r.GetDB(ctx).Model(&entities.User{}).
Where("id = ? AND password = ?", userID, password).
Count(&count).Error
return count > 0, err
}
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("active", true).Error
}
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Update("active", false).Error
}
func (r *GormUserRepository) UpdateLoginStats(ctx context.Context, userID string) error {
return r.GetDB(ctx).Model(&entities.User{}).
Where("id = ?", userID).
Updates(map[string]interface{}{
"login_count": gorm.Expr("login_count + 1"),
"last_login_at": time.Now(),
}).Error
}
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.GetDB(ctx)
// 总用户数
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
return nil, err
}
// 活跃用户数
if err := db.Model(&entities.User{}).Where("active = ?", true).Count(&stats.ActiveUsers).Error; err != nil {
return nil, err
}
// 已认证用户数
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 今日登录数
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.GetDB(ctx)
// 指定时间范围内的注册数
if err := db.Model(&entities.User{}).
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 指定时间范围内的登录数
if err := db.Model(&entities.User{}).
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetSystemUserStats 获取系统用户统计信息
func (r *GormUserRepository) GetSystemUserStats(ctx context.Context) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.GetDB(ctx)
// 总用户数
if err := db.Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
return nil, err
}
// 活跃用户数最近30天有登录
thirtyDaysAgo := time.Now().AddDate(0, 0, -30)
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", thirtyDaysAgo).Count(&stats.ActiveUsers).Error; err != nil {
return nil, err
}
// 已认证用户数
if err := db.Model(&entities.User{}).Where("is_certified = ?", true).Count(&stats.CertifiedUsers).Error; err != nil {
return nil, err
}
// 今日注册数
today := time.Now().Truncate(24 * time.Hour)
if err := db.Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 今日登录数
if err := db.Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetSystemUserStatsByDateRange 获取系统指定时间范围内的用户统计信息
func (r *GormUserRepository) GetSystemUserStatsByDateRange(ctx context.Context, startDate, endDate time.Time) (*repositories.UserStats, error) {
var stats repositories.UserStats
db := r.GetDB(ctx)
// 指定时间范围内的注册数
if err := db.Model(&entities.User{}).
Where("created_at >= ? AND created_at <= ?", startDate, endDate).
Count(&stats.TodayRegistrations).Error; err != nil {
return nil, err
}
// 指定时间范围内的登录数
if err := db.Model(&entities.User{}).
Where("last_login_at >= ? AND last_login_at <= ?", startDate, endDate).
Count(&stats.TodayLogins).Error; err != nil {
return nil, err
}
return &stats, nil
}
// GetSystemDailyUserStats 获取系统每日用户统计
func (r *GormUserRepository) GetSystemDailyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(created_at) as date,
COUNT(*) as count
FROM users
WHERE DATE(created_at) >= $1
AND DATE(created_at) <= $2
GROUP BY DATE(created_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyUserStats 获取系统每月用户统计
func (r *GormUserRepository) GetSystemMonthlyUserStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(created_at, 'YYYY-MM') as month,
COUNT(*) as count
FROM users
WHERE created_at >= $1
AND created_at <= $2
GROUP BY TO_CHAR(created_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemDailyCertificationStats 获取系统每日认证用户统计基于is_certified字段
func (r *GormUserRepository) GetSystemDailyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
DATE(updated_at) as date,
COUNT(*) as count
FROM users
WHERE is_certified = true
AND DATE(updated_at) >= $1
AND DATE(updated_at) <= $2
GROUP BY DATE(updated_at)
ORDER BY date ASC
`
err := r.GetDB(ctx).Raw(sql, startDate.Format("2006-01-02"), endDate.Format("2006-01-02")).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetSystemMonthlyCertificationStats 获取系统每月认证用户统计基于is_certified字段
func (r *GormUserRepository) GetSystemMonthlyCertificationStats(ctx context.Context, startDate, endDate time.Time) ([]map[string]interface{}, error) {
var results []map[string]interface{}
sql := `
SELECT
TO_CHAR(updated_at, 'YYYY-MM') as month,
COUNT(*) as count
FROM users
WHERE is_certified = true
AND updated_at >= $1
AND updated_at <= $2
GROUP BY TO_CHAR(updated_at, 'YYYY-MM')
ORDER BY month ASC
`
err := r.GetDB(ctx).Raw(sql, startDate, endDate).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetUserCallRankingByCalls 按调用次数获取用户排行
func (r *GormUserRepository) GetUserCallRankingByCalls(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COUNT(ac.id) as calls
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN api_calls ac ON u.id = ac.user_id
AND DATE(ac.created_at) = CURRENT_DATE
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COUNT(ac.id) > 0
ORDER BY calls DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COUNT(ac.id) as calls
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN api_calls ac ON u.id = ac.user_id
AND DATE_TRUNC('month', ac.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COUNT(ac.id) > 0
ORDER BY calls DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COUNT(ac.id) as calls
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN api_calls ac ON u.id = ac.user_id
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COUNT(ac.id) > 0
ORDER BY calls DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetUserCallRankingByConsumption 按消费金额获取用户排行
func (r *GormUserRepository) GetUserCallRankingByConsumption(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(wt.amount), 0) as consumption
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
AND DATE(wt.created_at) = CURRENT_DATE
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(wt.amount), 0) > 0
ORDER BY consumption DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(wt.amount), 0) as consumption
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
AND DATE_TRUNC('month', wt.created_at) = DATE_TRUNC('month', CURRENT_DATE)
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(wt.amount), 0) > 0
ORDER BY consumption DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(wt.amount), 0) as consumption
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN wallet_transactions wt ON u.id = wt.user_id
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(wt.amount), 0) > 0
ORDER BY consumption DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}
// GetRechargeRanking 获取充值排行(排除赠送,只统计成功状态)
func (r *GormUserRepository) GetRechargeRanking(ctx context.Context, period string, limit int) ([]map[string]interface{}, error) {
var sql string
var args []interface{}
switch period {
case "today":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(rr.amount), 0) as amount
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN recharge_records rr ON u.id = rr.user_id
AND DATE(rr.created_at) = CURRENT_DATE
AND rr.status = 'success'
AND rr.recharge_type != 'gift'
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(rr.amount), 0) > 0
ORDER BY amount DESC
LIMIT $1
`
args = []interface{}{limit}
case "month":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(rr.amount), 0) as amount
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN recharge_records rr ON u.id = rr.user_id
AND DATE_TRUNC('month', rr.created_at) = DATE_TRUNC('month', CURRENT_DATE)
AND rr.status = 'success'
AND rr.recharge_type != 'gift'
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(rr.amount), 0) > 0
ORDER BY amount DESC
LIMIT $1
`
args = []interface{}{limit}
case "total":
sql = `
SELECT
u.id as user_id,
COALESCE(ei.company_name, u.username, u.phone) as username,
COALESCE(SUM(rr.amount), 0) as amount
FROM users u
LEFT JOIN enterprise_infos ei ON u.id = ei.user_id
LEFT JOIN recharge_records rr ON u.id = rr.user_id
AND rr.status = 'success'
AND rr.recharge_type != 'gift'
WHERE u.deleted_at IS NULL
GROUP BY u.id, ei.company_name, u.username, u.phone
HAVING COALESCE(SUM(rr.amount), 0) > 0
ORDER BY amount DESC
LIMIT $1
`
args = []interface{}{limit}
default:
return nil, fmt.Errorf("不支持的时间周期: %s", period)
}
var results []map[string]interface{}
err := r.GetDB(ctx).Raw(sql, args...).Scan(&results).Error
if err != nil {
return nil, err
}
return results, nil
}

View File

@@ -0,0 +1,291 @@
package events
import (
"context"
"time"
"go.uber.org/zap"
"hyapi-server/internal/shared/interfaces"
)
// ================ 常量定义 ================
const (
// 事件类型
EventTypeCertificationCreated = "certification.created"
EventTypeEnterpriseInfoSubmitted = "certification.enterprise_info_submitted"
EventTypeEnterpriseVerificationCompleted = "certification.enterprise_verification_completed"
EventTypeContractGenerated = "certification.contract_generated"
EventTypeContractSigned = "certification.contract_signed"
EventTypeCertificationCompleted = "certification.completed"
EventTypeCertificationFailed = "certification.failed"
EventTypeStatusTransitioned = "certification.status_transitioned"
// 重试配置
MaxRetries = 3
RetryDelay = 5 * time.Second
)
// ================ 事件结构 ================
// CertificationEventData 认证事件数据结构
type CertificationEventData struct {
EventType string `json:"event_type"`
CertificationID string `json:"certification_id"`
UserID string `json:"user_id"`
Data map[string]interface{} `json:"data"`
Timestamp time.Time `json:"timestamp"`
Version string `json:"version"`
}
// ================ 事件发布器实现 ================
// CertificationEventPublisher 认证事件发布器实现
//
// 职责:
// - 发布认证域相关的事件
// - 支持异步发布和重试机制
// - 提供事件持久化能力
// - 集成监控和日志
type CertificationEventPublisher struct {
eventBus interfaces.EventBus
logger *zap.Logger
}
// NewCertificationEventPublisher 创建认证事件发布器
func NewCertificationEventPublisher(
eventBus interfaces.EventBus,
logger *zap.Logger,
) *CertificationEventPublisher {
return &CertificationEventPublisher{
eventBus: eventBus,
logger: logger,
}
}
// ================ 事件发布方法 ================
// PublishCertificationCreated 发布认证创建事件
func (p *CertificationEventPublisher) PublishCertificationCreated(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeCertificationCreated,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// PublishEnterpriseInfoSubmitted 发布企业信息提交事件
func (p *CertificationEventPublisher) PublishEnterpriseInfoSubmitted(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeEnterpriseInfoSubmitted,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// PublishEnterpriseVerificationCompleted 发布企业认证完成事件
func (p *CertificationEventPublisher) PublishEnterpriseVerificationCompleted(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeEnterpriseVerificationCompleted,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// PublishContractGenerated 发布合同生成事件
func (p *CertificationEventPublisher) PublishContractGenerated(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeContractGenerated,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// PublishContractSigned 发布合同签署事件
func (p *CertificationEventPublisher) PublishContractSigned(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeContractSigned,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// PublishCertificationCompleted 发布认证完成事件
func (p *CertificationEventPublisher) PublishCertificationCompleted(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeCertificationCompleted,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// PublishCertificationFailed 发布认证失败事件
func (p *CertificationEventPublisher) PublishCertificationFailed(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeCertificationFailed,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// PublishStatusTransitioned 发布状态转换事件
func (p *CertificationEventPublisher) PublishStatusTransitioned(
ctx context.Context,
certificationID, userID string,
data map[string]interface{},
) error {
eventData := &CertificationEventData{
EventType: EventTypeStatusTransitioned,
CertificationID: certificationID,
UserID: userID,
Data: data,
Timestamp: time.Now(),
Version: "1.0",
}
return p.publishEventData(ctx, eventData)
}
// ================ 内部实现 ================
// publishEventData 发布事件数据(带重试机制)
func (p *CertificationEventPublisher) publishEventData(ctx context.Context, eventData *CertificationEventData) error {
p.logger.Info("发布认证事件",
zap.String("event_type", eventData.EventType),
zap.String("certification_id", eventData.CertificationID),
zap.Time("timestamp", eventData.Timestamp))
// 尝试发布事件,带重试机制
var lastErr error
for attempt := 0; attempt <= MaxRetries; attempt++ {
if attempt > 0 {
// 指数退避重试
delay := time.Duration(attempt) * RetryDelay
p.logger.Warn("事件发布重试",
zap.String("event_type", eventData.EventType),
zap.Int("attempt", attempt),
zap.Duration("delay", delay))
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(delay):
// 继续重试
}
}
// 简化的事件发布:直接记录日志
p.logger.Info("模拟事件发布",
zap.String("event_type", eventData.EventType),
zap.String("certification_id", eventData.CertificationID),
zap.Any("data", eventData.Data))
// TODO: 这里可以集成真正的事件总线
// if err := p.eventBus.Publish(ctx, eventData); err != nil {
// lastErr = err
// continue
// }
// 发布成功
p.logger.Info("事件发布成功",
zap.String("event_type", eventData.EventType),
zap.String("certification_id", eventData.CertificationID))
return nil
}
// 理论上不会到达这里,因为简化实现总是成功
return lastErr
}
// ================ 事件处理器注册 ================
// RegisterEventHandlers 注册事件处理器
func (p *CertificationEventPublisher) RegisterEventHandlers() error {
// TODO: 注册具体的事件处理器
// 例如:发送通知、更新统计数据、触发后续流程等
p.logger.Info("认证事件处理器已注册")
return nil
}
// ================ 工具方法 ================
// CreateEventData 创建事件数据
func CreateEventData(eventType, certificationID, userID string, data map[string]interface{}) map[string]interface{} {
if data == nil {
data = make(map[string]interface{})
}
return map[string]interface{}{
"event_type": eventType,
"certification_id": certificationID,
"user_id": userID,
"data": data,
"timestamp": time.Now(),
"version": "1.0",
}
}

View File

@@ -0,0 +1,228 @@
package events
import (
"context"
"encoding/json"
"fmt"
"time"
"go.uber.org/zap"
"hyapi-server/internal/domains/finance/events"
"hyapi-server/internal/infrastructure/external/email"
"hyapi-server/internal/shared/interfaces"
)
// InvoiceEventHandler 发票事件处理器
type InvoiceEventHandler struct {
logger *zap.Logger
emailService *email.QQEmailService
name string
eventTypes []string
isAsync bool
}
// NewInvoiceEventHandler 创建发票事件处理器
func NewInvoiceEventHandler(logger *zap.Logger, emailService *email.QQEmailService) *InvoiceEventHandler {
return &InvoiceEventHandler{
logger: logger,
emailService: emailService,
name: "invoice-event-handler",
eventTypes: []string{
"InvoiceApplicationCreated",
"InvoiceApplicationApproved",
"InvoiceApplicationRejected",
"InvoiceFileUploaded",
},
isAsync: true,
}
}
// GetName 获取处理器名称
func (h *InvoiceEventHandler) GetName() string {
return h.name
}
// GetEventTypes 获取支持的事件类型
func (h *InvoiceEventHandler) GetEventTypes() []string {
return h.eventTypes
}
// IsAsync 是否为异步处理器
func (h *InvoiceEventHandler) IsAsync() bool {
return h.isAsync
}
// GetRetryConfig 获取重试配置
func (h *InvoiceEventHandler) GetRetryConfig() interfaces.RetryConfig {
return interfaces.RetryConfig{
MaxRetries: 3,
RetryDelay: 5 * time.Second,
BackoffFactor: 2.0,
MaxDelay: 30 * time.Second,
}
}
// Handle 处理事件
func (h *InvoiceEventHandler) Handle(ctx context.Context, event interfaces.Event) error {
h.logger.Info("🔄 开始处理发票事件",
zap.String("event_type", event.GetType()),
zap.String("event_id", event.GetID()),
zap.String("aggregate_id", event.GetAggregateID()),
zap.String("handler_name", h.GetName()),
zap.Time("event_timestamp", event.GetTimestamp()),
)
switch event.GetType() {
case "InvoiceApplicationCreated":
h.logger.Info("📝 处理发票申请创建事件")
return h.handleInvoiceApplicationCreated(ctx, event)
case "InvoiceApplicationApproved":
h.logger.Info("✅ 处理发票申请通过事件")
return h.handleInvoiceApplicationApproved(ctx, event)
case "InvoiceApplicationRejected":
h.logger.Info("❌ 处理发票申请拒绝事件")
return h.handleInvoiceApplicationRejected(ctx, event)
case "InvoiceFileUploaded":
h.logger.Info("📎 处理发票文件上传事件")
return h.handleInvoiceFileUploaded(ctx, event)
default:
h.logger.Warn("⚠️ 未知的发票事件类型", zap.String("event_type", event.GetType()))
return nil
}
}
// handleInvoiceApplicationCreated 处理发票申请创建事件
func (h *InvoiceEventHandler) handleInvoiceApplicationCreated(ctx context.Context, event interfaces.Event) error {
h.logger.Info("发票申请已创建",
zap.String("application_id", event.GetAggregateID()),
)
// 这里可以发送通知给管理员,告知有新的发票申请
// 暂时只记录日志
return nil
}
// handleInvoiceApplicationApproved 处理发票申请通过事件
func (h *InvoiceEventHandler) handleInvoiceApplicationApproved(ctx context.Context, event interfaces.Event) error {
h.logger.Info("发票申请已通过",
zap.String("application_id", event.GetAggregateID()),
)
// 这里可以发送通知给用户,告知发票申请已通过
// 暂时只记录日志
return nil
}
// handleInvoiceApplicationRejected 处理发票申请拒绝事件
func (h *InvoiceEventHandler) handleInvoiceApplicationRejected(ctx context.Context, event interfaces.Event) error {
h.logger.Info("发票申请被拒绝",
zap.String("application_id", event.GetAggregateID()),
)
// 这里可以发送邮件通知用户,告知发票申请被拒绝
// 暂时只记录日志
return nil
}
// handleInvoiceFileUploaded 处理发票文件上传事件
func (h *InvoiceEventHandler) handleInvoiceFileUploaded(ctx context.Context, event interfaces.Event) error {
h.logger.Info("📎 发票文件已上传事件开始处理",
zap.String("invoice_id", event.GetAggregateID()),
zap.String("event_id", event.GetID()),
)
// 解析事件数据
payload := event.GetPayload()
if payload == nil {
h.logger.Error("❌ 事件数据为空")
return fmt.Errorf("事件数据为空")
}
h.logger.Info("📋 事件数据解析开始",
zap.Any("payload_type", fmt.Sprintf("%T", payload)),
)
// 将payload转换为JSON然后解析为InvoiceFileUploadedEvent
payloadBytes, err := json.Marshal(payload)
if err != nil {
h.logger.Error("❌ 序列化事件数据失败", zap.Error(err))
return fmt.Errorf("序列化事件数据失败: %w", err)
}
h.logger.Info("📄 事件数据序列化成功",
zap.String("payload_json", string(payloadBytes)),
)
var fileUploadedEvent events.InvoiceFileUploadedEvent
err = json.Unmarshal(payloadBytes, &fileUploadedEvent)
if err != nil {
h.logger.Error("❌ 解析发票文件上传事件失败", zap.Error(err))
return fmt.Errorf("解析发票文件上传事件失败: %w", err)
}
h.logger.Info("✅ 事件数据解析成功",
zap.String("invoice_id", fileUploadedEvent.InvoiceID),
zap.String("user_id", fileUploadedEvent.UserID),
zap.String("receiving_email", fileUploadedEvent.ReceivingEmail),
zap.String("file_name", fileUploadedEvent.FileName),
zap.String("file_url", fileUploadedEvent.FileURL),
zap.String("company_name", fileUploadedEvent.CompanyName),
zap.String("amount", fileUploadedEvent.Amount.String()),
zap.String("invoice_type", string(fileUploadedEvent.InvoiceType)),
)
// 发送发票邮件给用户
return h.sendInvoiceEmail(ctx, &fileUploadedEvent)
}
// sendInvoiceEmail 发送发票邮件
func (h *InvoiceEventHandler) sendInvoiceEmail(ctx context.Context, event *events.InvoiceFileUploadedEvent) error {
h.logger.Info("📧 开始发送发票邮件",
zap.String("invoice_id", event.InvoiceID),
zap.String("user_id", event.UserID),
zap.String("receiving_email", event.ReceivingEmail),
zap.String("file_name", event.FileName),
zap.String("file_url", event.FileURL),
)
// 构建邮件数据
emailData := &email.InvoiceEmailData{
CompanyName: event.CompanyName,
Amount: event.Amount.String(),
InvoiceType: event.InvoiceType.GetDisplayName(),
FileURL: event.FileURL,
FileName: event.FileName,
ReceivingEmail: event.ReceivingEmail,
ApprovedAt: event.UploadedAt.Format("2006-01-02 15:04:05"),
}
h.logger.Info("📋 邮件数据构建完成",
zap.String("company_name", emailData.CompanyName),
zap.String("amount", emailData.Amount),
zap.String("invoice_type", emailData.InvoiceType),
zap.String("file_url", emailData.FileURL),
zap.String("file_name", emailData.FileName),
zap.String("receiving_email", emailData.ReceivingEmail),
zap.String("approved_at", emailData.ApprovedAt),
)
// 发送邮件
h.logger.Info("🚀 开始调用邮件服务发送邮件")
err := h.emailService.SendInvoiceEmail(ctx, emailData)
if err != nil {
h.logger.Error("❌ 发送发票邮件失败",
zap.String("invoice_id", event.InvoiceID),
zap.String("receiving_email", event.ReceivingEmail),
zap.Error(err),
)
return fmt.Errorf("发送发票邮件失败: %w", err)
}
h.logger.Info("✅ 发票邮件发送成功",
zap.String("invoice_id", event.InvoiceID),
zap.String("receiving_email", event.ReceivingEmail),
)
return nil
}

View File

@@ -0,0 +1,115 @@
package events
import (
"context"
"go.uber.org/zap"
"hyapi-server/internal/domains/finance/events"
"hyapi-server/internal/shared/interfaces"
)
// InvoiceEventPublisher 发票事件发布器实现
type InvoiceEventPublisher struct {
logger *zap.Logger
eventBus interfaces.EventBus
}
// NewInvoiceEventPublisher 创建发票事件发布器
func NewInvoiceEventPublisher(logger *zap.Logger, eventBus interfaces.EventBus) *InvoiceEventPublisher {
return &InvoiceEventPublisher{
logger: logger,
eventBus: eventBus,
}
}
// PublishInvoiceApplicationCreated 发布发票申请创建事件
func (p *InvoiceEventPublisher) PublishInvoiceApplicationCreated(ctx context.Context, event *events.InvoiceApplicationCreatedEvent) error {
p.logger.Info("发布发票申请创建事件",
zap.String("application_id", event.ApplicationID),
zap.String("user_id", event.UserID),
zap.String("invoice_type", string(event.InvoiceType)),
zap.String("amount", event.Amount.String()),
zap.String("company_name", event.CompanyName),
zap.String("receiving_email", event.ReceivingEmail),
)
// TODO: 实现实际的事件发布逻辑
// 例如:发送到消息队列、调用外部服务等
return nil
}
// PublishInvoiceApplicationApproved 发布发票申请通过事件
func (p *InvoiceEventPublisher) PublishInvoiceApplicationApproved(ctx context.Context, event *events.InvoiceApplicationApprovedEvent) error {
p.logger.Info("发布发票申请通过事件",
zap.String("application_id", event.ApplicationID),
zap.String("user_id", event.UserID),
zap.String("amount", event.Amount.String()),
zap.String("receiving_email", event.ReceivingEmail),
zap.Time("approved_at", event.ApprovedAt),
)
// TODO: 实现实际的事件发布逻辑
// 例如:发送邮件通知用户、更新统计数据等
return nil
}
// PublishInvoiceApplicationRejected 发布发票申请拒绝事件
func (p *InvoiceEventPublisher) PublishInvoiceApplicationRejected(ctx context.Context, event *events.InvoiceApplicationRejectedEvent) error {
p.logger.Info("发布发票申请拒绝事件",
zap.String("application_id", event.ApplicationID),
zap.String("user_id", event.UserID),
zap.String("reason", event.Reason),
zap.String("receiving_email", event.ReceivingEmail),
zap.Time("rejected_at", event.RejectedAt),
)
// TODO: 实现实际的事件发布逻辑
// 例如:发送邮件通知用户、记录拒绝原因等
return nil
}
// PublishInvoiceFileUploaded 发布发票文件上传事件
func (p *InvoiceEventPublisher) PublishInvoiceFileUploaded(ctx context.Context, event *events.InvoiceFileUploadedEvent) error {
p.logger.Info("📤 开始发布发票文件上传事件",
zap.String("invoice_id", event.InvoiceID),
zap.String("user_id", event.UserID),
zap.String("file_id", event.FileID),
zap.String("file_name", event.FileName),
zap.String("file_url", event.FileURL),
zap.String("receiving_email", event.ReceivingEmail),
zap.Time("uploaded_at", event.UploadedAt),
)
// 发布到事件总线
if p.eventBus != nil {
p.logger.Info("🚀 准备发布事件到事件总线",
zap.String("event_type", event.GetType()),
zap.String("event_id", event.GetID()),
)
if err := p.eventBus.Publish(ctx, event); err != nil {
p.logger.Error("❌ 发布发票文件上传事件到事件总线失败",
zap.String("invoice_id", event.InvoiceID),
zap.String("event_type", event.GetType()),
zap.String("event_id", event.GetID()),
zap.Error(err),
)
return err
}
p.logger.Info("✅ 发票文件上传事件已发布到事件总线",
zap.String("invoice_id", event.InvoiceID),
zap.String("event_type", event.GetType()),
zap.String("event_id", event.GetID()),
)
} else {
p.logger.Warn("⚠️ 事件总线未初始化,无法发布事件",
zap.String("invoice_id", event.InvoiceID),
)
}
return nil
}

View File

@@ -0,0 +1,123 @@
# 外部服务错误处理修复说明
## 问题描述
在外部服务WestDex、Yushan、Zhicha使用 `fmt.Errorf("%w: %s", ErrXXX, err)` 包装错误后,外层的 `errors.Is(err, ErrXXX)` 无法正确识别错误类型。
## 问题原因
`fmt.Errorf` 创建的包装错误虽然实现了 `Unwrap()` 接口,但没有实现 `Is()` 接口,因此 `errors.Is` 无法正确判断错误类型。
## 修复方案
统一使用 `errors.Join` 来组合错误,这是 Go 1.20+ 的标准做法,天然支持 `errors.Is` 判断。
## 修复内容
### 1. WestDex 服务 (`westdex_service.go`)
#### 修复前:
```go
// 无法被 errors.Is 识别的错误包装
err = fmt.Errorf("%w: %s", ErrSystem, marshalErr.Error())
err = fmt.Errorf("%w: %s", ErrDatasource, westDexResp.Message)
```
#### 修复后:
```go
// 可以被 errors.Is 正确识别的错误组合
err = errors.Join(ErrSystem, marshalErr)
err = errors.Join(ErrDatasource, fmt.Errorf(westDexResp.Message))
```
### 2. Yushan 服务 (`yushan_service.go`)
#### 修复前:
```go
// 无法被 errors.Is 识别的错误包装
err = fmt.Errorf("%w: %s", ErrSystem, err.Error())
err = fmt.Errorf("%w: %s", ErrDatasource, "羽山请求retdata为空")
```
#### 修复后:
```go
// 可以被 errors.Is 正确识别的错误组合
err = errors.Join(ErrSystem, err)
err = errors.Join(ErrDatasource, fmt.Errorf("羽山请求retdata为空"))
```
### 3. Zhicha 服务 (`zhicha_service.go`)
#### 修复前:
```go
// 无法被 errors.Is 识别的错误包装
err = fmt.Errorf("%w: %s", ErrSystem, marshalErr.Error())
err = fmt.Errorf("%w: %s", ErrDatasource, "HTTP状态码 %d", response.StatusCode)
```
#### 修复后:
```go
// 可以被 errors.Is 正确识别的错误组合
err = errors.Join(ErrSystem, marshalErr)
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
```
## 修复效果
### 修复前的问题:
```go
// 在应用服务层
if errors.Is(err, westdex.ErrDatasource) {
// 这里无法正确识别,因为 fmt.Errorf 包装的错误
// 没有实现 Is() 接口
return ErrDatasource
}
```
### 修复后的效果:
```go
// 在应用服务层
if errors.Is(err, westdex.ErrDatasource) {
// 现在可以正确识别了!
return ErrDatasource
}
if errors.Is(err, westdex.ErrSystem) {
// 系统错误也能正确识别
return ErrSystem
}
```
## 优势
1. **完全兼容**`errors.Is` 现在可以正确识别所有错误类型
2. **标准做法**:使用 Go 1.20+ 的 `errors.Join` 标准库功能
3. **性能优秀**:标准库实现,性能优于自定义解决方案
4. **维护简单**:无需自定义错误类型,代码更简洁
## 注意事项
1. **Go版本要求**:需要 Go 1.20 或更高版本(项目使用 Go 1.23.4,完全满足)
2. **错误消息格式**`errors.Join` 使用换行符分隔多个错误
3. **向后兼容**:现有的错误处理代码无需修改
## 测试验证
所有修复后的外部服务都能正确编译:
```bash
go build ./internal/infrastructure/external/westdex/...
go build ./internal/infrastructure/external/yushan/...
go build ./internal/infrastructure/external/zhicha/...
```
## 总结
通过统一使用 `errors.Join` 修复外部服务的错误处理,现在:
-`errors.Is(err, ErrDatasource)` 可以正确识别数据源异常
-`errors.Is(err, ErrSystem)` 可以正确识别系统异常
-`errors.Is(err, ErrNotFound)` 可以正确识别查询为空
- ✅ 错误处理逻辑更加清晰和可靠
- ✅ 符合 Go 1.20+ 的最佳实践
这个修复确保了整个系统的错误处理链路都能正确工作,提高了系统的可靠性和可维护性。

View File

@@ -0,0 +1,194 @@
# 阿里云二要素验证服务
这个服务提供了调用阿里云身份证二要素验证API的功能用于验证姓名和身份证号码是否匹配。
## 功能特性
- 身份证二要素验证(姓名 + 身份证号)
- 支持详细验证结果返回
- 支持简单布尔值判断
- 错误处理和中文错误信息
## 配置说明
### 必需配置
- `Host`: 阿里云API的域名地址
- `AppCode`: 阿里云市场应用的AppCode
### 配置示例
```go
host := "https://kzidcardv1.market.alicloudapi.com"
appCode := "您的AppCode"
```
## 使用方法
### 1. 创建服务实例
```go
service := NewAlicloudService(host, appCode)
```
### 2. 调用API
#### 身份证二要素验证示例
```go
// 构建请求参数
params := map[string]interface{}{
"name": "张三",
"idcard": "110101199001011234",
}
// 调用API
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
if err != nil {
log.Printf("验证失败: %v", err)
return
}
// 解析完整响应结构
var response struct {
Msg string `json:"msg"`
Success bool `json:"success"`
Code int `json:"code"`
Data struct {
Birthday string `json:"birthday"`
Result int `json:"result"`
Address string `json:"address"`
OrderNo string `json:"orderNo"`
Sex string `json:"sex"`
Desc string `json:"desc"`
} `json:"data"`
}
if err := json.Unmarshal(responseBody, &response); err != nil {
log.Printf("响应解析失败: %v", err)
return
}
// 检查响应状态
if response.Code != 200 {
log.Printf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
return
}
idCardData := response.Data
// 判断验证结果
if idCardData.Result == 1 {
fmt.Println("身份证信息验证通过")
} else {
fmt.Println("身份证信息验证失败")
}
```
#### 通用API调用
```go
// 调用其他阿里云API
params := map[string]interface{}{
"param1": "value1",
"param2": "value2",
}
responseBody, err := service.CallAPI("your/api/path", params)
if err != nil {
log.Printf("API调用失败: %v", err)
return
}
// 根据具体API的响应结构进行解析
// 每个API的响应结构可能不同需要根据API文档定义相应的结构体
var response struct {
Msg string `json:"msg"`
Code int `json:"code"`
Data interface{} `json:"data"`
}
if err := json.Unmarshal(responseBody, &response); err != nil {
log.Printf("响应解析失败: %v", err)
return
}
// 处理响应数据
fmt.Printf("响应数据: %s\n", string(responseBody))
```
## 响应格式
### 通用响应结构
```json
{
"msg": "成功",
"success": true,
"code": 200,
"data": {
// 具体的业务数据
}
}
```
### 身份证验证响应示例
#### 成功响应 (code: 200)
```json
{
"msg": "成功",
"success": true,
"code": 200,
"data": {
"birthday": "19840816",
"result": 1,
"address": "浙江省杭州市淳安县",
"orderNo": "202406271440416095174",
"sex": "男",
"desc": "不一致"
}
}
```
#### 参数错误响应 (code: 400)
```json
{
"msg": "请输入有效的身份证号码",
"code": 400,
"data": null
}
```
### 错误响应
```json
{
"msg": "AppCode无效",
"success": false,
"code": 400
}
```
## 错误处理
服务定义了以下错误类型:
- `ErrDatasource`: 数据源异常
- `ErrSystem`: 系统异常
- `ErrInvalid`: 身份证信息不匹配
## 注意事项
1. 请确保您的AppCode有效且有足够的调用额度
2. 身份证号码必须是18位有效格式
3. 姓名必须是真实有效的姓名
4. 建议在生产环境中添加适当的重试机制和超时设置
5. 请遵守阿里云API的使用规范和频率限制
## 依赖
- Go 1.16+
- 标准库:`net/http`, `encoding/json`, `net/url`

View File

@@ -0,0 +1,48 @@
package alicloud
import (
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewAlicloudServiceWithConfig 使用配置创建阿里云服务,并启用外部服务调用日志
func NewAlicloudServiceWithConfig(cfg *config.Config) (*AlicloudService, error) {
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: true,
LogDir: "./logs/external_services",
ServiceName: "alicloud",
UseDaily: false,
EnableLevelSeparation: true,
LevelConfigs: map[string]external_logger.ExternalServiceLevelFileConfig{
"info": {
MaxSize: 100,
MaxBackups: 3,
MaxAge: 28,
Compress: true,
},
"error": {
MaxSize: 100,
MaxBackups: 3,
MaxAge: 28,
Compress: true,
},
"warn": {
MaxSize: 100,
MaxBackups: 3,
MaxAge: 28,
Compress: true,
},
},
}
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
return NewAlicloudService(
cfg.Alicloud.Host,
cfg.Alicloud.AppCode,
logger,
), nil
}

View File

@@ -0,0 +1,142 @@
package alicloud
import (
"crypto/md5"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"hyapi-server/internal/shared/external_logger"
)
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
)
// AlicloudConfig 阿里云配置
type AlicloudConfig struct {
Host string
AppCode string
}
// AlicloudService 阿里云服务
type AlicloudService struct {
config AlicloudConfig
logger *external_logger.ExternalServiceLogger
}
// NewAlicloudService 创建阿里云服务实例
func NewAlicloudService(host, appCode string, logger ...*external_logger.ExternalServiceLogger) *AlicloudService {
var serviceLogger *external_logger.ExternalServiceLogger
if len(logger) > 0 {
serviceLogger = logger[0]
}
return &AlicloudService{
config: AlicloudConfig{
Host: host,
AppCode: appCode,
},
logger: serviceLogger,
}
}
// generateRequestID 生成请求ID
func (a *AlicloudService) generateRequestID() string {
timestamp := time.Now().UnixNano()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, a.config.Host)))
return fmt.Sprintf("alicloud_%x", hash[:8])
}
// CallAPI 调用阿里云API的通用方法
// path: API路径如 "api-mall/api/id_card/check"
// params: 请求参数
func (a *AlicloudService) CallAPI(path string, params map[string]interface{}) (respBytes []byte, err error) {
startTime := time.Now()
requestID := a.generateRequestID()
transactionID := ""
// 构建请求URL
reqURL := a.config.Host + "/" + path
// 记录请求日志
if a.logger != nil {
a.logger.LogRequest(requestID, transactionID, path, reqURL)
}
// 构建请求参数
formData := url.Values{}
for key, value := range params {
formData.Set(key, fmt.Sprintf("%v", value))
}
// 创建HTTP请求
req, err := http.NewRequest("POST", reqURL, strings.NewReader(formData.Encode()))
if err != nil {
if a.logger != nil {
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrSystem, err), params)
}
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
}
// 设置请求头
req.Header.Set("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8")
req.Header.Set("Authorization", "APPCODE "+a.config.AppCode)
// 发送请求超时时间设置为60秒
client := &http.Client{
Timeout: 60 * time.Second,
}
resp, err := client.Do(req)
if err != nil {
// 检查是否是超时错误
isTimeout := false
if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
if a.logger != nil {
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %s", err.Error())), params)
}
return nil, fmt.Errorf("%w: API请求超时: %s", ErrDatasource, err.Error())
}
if a.logger != nil {
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrSystem, err), params)
}
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
}
defer resp.Body.Close()
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
if a.logger != nil {
a.logger.LogError(requestID, transactionID, path, errors.Join(ErrSystem, err), params)
}
return nil, fmt.Errorf("%w: %s", ErrSystem, err.Error())
}
// 记录响应日志(不记录具体响应数据)
if a.logger != nil {
duration := time.Since(startTime)
a.logger.LogResponse(requestID, transactionID, path, resp.StatusCode, duration)
}
// 直接返回原始响应body让调用方自己处理
return body, nil
}
// GetConfig 获取配置信息
func (a *AlicloudService) GetConfig() AlicloudConfig {
return a.config
}

View File

@@ -0,0 +1,143 @@
package alicloud
import (
"encoding/json"
"fmt"
"testing"
)
func TestRealAlicloudAPI(t *testing.T) {
// 使用真实的阿里云API配置
host := "https://kzidcardv1.market.alicloudapi.com"
appCode := "d55b58829efb41c8aa8e86769cba4844"
service := NewAlicloudService(host, appCode)
// 测试真实的身份证验证
name := "张荣宏"
idCard := "45212220000827423X"
fmt.Printf("开始测试阿里云二要素验证API...\n")
fmt.Printf("姓名: %s\n", name)
fmt.Printf("身份证: %s\n", idCard)
// 构建请求参数
params := map[string]interface{}{
"name": name,
"idcard": idCard,
}
// 调用真实API
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
if err != nil {
t.Logf("API调用失败: %v", err)
fmt.Printf("错误详情: %v\n", err)
t.Fail()
return
}
// 打印原始响应数据
fmt.Printf("API响应成功!\n")
fmt.Printf("原始响应数据: %s\n", string(responseBody))
// 解析完整响应结构
var response struct {
Msg string `json:"msg"`
Success bool `json:"success"`
Code int `json:"code"`
Data struct {
Birthday string `json:"birthday"`
Result int `json:"result"`
Address string `json:"address"`
OrderNo string `json:"orderNo"`
Sex string `json:"sex"`
Desc string `json:"desc"`
} `json:"data"`
}
if err := json.Unmarshal(responseBody, &response); err != nil {
t.Logf("响应数据解析失败: %v", err)
t.Fail()
return
}
// 检查响应状态
if response.Code != 200 {
t.Logf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
t.Fail()
return
}
idCardData := response.Data
// 打印详细响应结果
fmt.Printf("验证结果: %d\n", idCardData.Result)
fmt.Printf("描述: %s\n", idCardData.Desc)
fmt.Printf("生日: %s\n", idCardData.Birthday)
fmt.Printf("性别: %s\n", idCardData.Sex)
fmt.Printf("地址: %s\n", idCardData.Address)
fmt.Printf("订单号: %s\n", idCardData.OrderNo)
// 将完整响应转换为JSON并打印
jsonResponse, _ := json.MarshalIndent(idCardData, "", " ")
fmt.Printf("完整响应JSON:\n%s\n", string(jsonResponse))
// 判断验证结果
if idCardData.Result == 1 {
fmt.Printf("验证结果: 通过\n")
} else {
fmt.Printf("验证结果: 失败\n")
}
}
// TestAlicloudAPIError 测试错误响应
func TestAlicloudAPIError(t *testing.T) {
// 使用真实的阿里云API配置
host := "https://kzidcardv1.market.alicloudapi.com"
appCode := "d55b58829efb41c8aa8e86769cba4844"
service := NewAlicloudService(host, appCode)
// 测试无效的身份证号码
name := "张三"
invalidIdCard := "123456789"
fmt.Printf("测试错误响应 - 无效身份证号\n")
fmt.Printf("姓名: %s\n", name)
fmt.Printf("身份证: %s\n", invalidIdCard)
// 构建请求参数
params := map[string]interface{}{
"name": name,
"idcard": invalidIdCard,
}
// 调用真实API
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
if err != nil {
fmt.Printf("网络请求错误: %v\n", err)
return
}
// 解析响应
var response struct {
Msg string `json:"msg"`
Code int `json:"code"`
Data interface{} `json:"data"`
}
if err := json.Unmarshal(responseBody, &response); err != nil {
fmt.Printf("响应解析失败: %v\n", err)
return
}
// 检查是否为错误响应
if response.Code != 200 {
fmt.Printf("预期的错误响应: code=%d, msg=%s\n", response.Code, response.Msg)
fmt.Printf("错误处理正确: API返回错误状态\n")
} else {
t.Error("期望返回错误,但实际成功")
}
}

View File

@@ -0,0 +1,76 @@
package alicloud
import (
"encoding/json"
"fmt"
"log"
)
// ExampleUsage 使用示例
func ExampleUsage() {
// 创建阿里云服务实例
// 请替换为您的实际配置
host := "https://kzidcardv1.market.alicloudapi.com"
appCode := "您的AppCode"
service := NewAlicloudService(host, appCode)
// 示例:验证身份证信息
name := "张三"
idCard := "110101199001011234"
// 构建请求参数
params := map[string]interface{}{
"name": name,
"idcard": idCard,
}
// 调用API
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
if err != nil {
log.Printf("验证失败: %v", err)
return
}
// 解析完整响应结构
var response struct {
Msg string `json:"msg"`
Success bool `json:"success"`
Code int `json:"code"`
Data struct {
Birthday string `json:"birthday"`
Result int `json:"result"`
Address string `json:"address"`
OrderNo string `json:"orderNo"`
Sex string `json:"sex"`
Desc string `json:"desc"`
} `json:"data"`
}
if err := json.Unmarshal(responseBody, &response); err != nil {
log.Printf("响应解析失败: %v", err)
return
}
// 检查响应状态
if response.Code != 200 {
log.Printf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
return
}
idCardData := response.Data
fmt.Printf("验证结果: %d\n", idCardData.Result)
fmt.Printf("描述: %s\n", idCardData.Desc)
fmt.Printf("生日: %s\n", idCardData.Birthday)
fmt.Printf("性别: %s\n", idCardData.Sex)
fmt.Printf("地址: %s\n", idCardData.Address)
fmt.Printf("订单号: %s\n", idCardData.OrderNo)
// 判断验证结果
if idCardData.Result == 1 {
fmt.Println("身份证信息验证通过")
} else {
fmt.Println("身份证信息验证失败")
}
}

View File

@@ -0,0 +1,162 @@
package alicloud
import (
"encoding/json"
"fmt"
"log"
)
// ExampleAdvancedUsage 高级使用示例
func ExampleAdvancedUsage() {
// 创建阿里云服务实例
host := "https://kzidcardv1.market.alicloudapi.com"
appCode := "您的AppCode"
service := NewAlicloudService(host, appCode)
// 示例1: 身份证二要素验证
fmt.Println("=== 示例1: 身份证二要素验证 ===")
exampleIdCardCheck(service)
// 示例2: 其他API调用假设
fmt.Println("\n=== 示例2: 其他API调用 ===")
exampleOtherAPI(service)
}
// exampleIdCardCheck 身份证验证示例
func exampleIdCardCheck(service *AlicloudService) {
// 构建请求参数
params := map[string]interface{}{
"name": "张三",
"idcard": "110101199001011234",
}
// 调用API
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
if err != nil {
log.Printf("身份证验证失败: %v", err)
return
}
// 解析完整响应结构
var response struct {
Msg string `json:"msg"`
Success bool `json:"success"`
Code int `json:"code"`
Data struct {
Birthday string `json:"birthday"`
Result int `json:"result"`
Address string `json:"address"`
OrderNo string `json:"orderNo"`
Sex string `json:"sex"`
Desc string `json:"desc"`
} `json:"data"`
}
if err := json.Unmarshal(responseBody, &response); err != nil {
log.Printf("响应解析失败: %v", err)
return
}
// 检查响应状态
if response.Code != 200 {
log.Printf("API返回错误: code=%d, msg=%s", response.Code, response.Msg)
return
}
idCardData := response.Data
// 处理验证结果
fmt.Printf("验证结果: %d (%s)\n", idCardData.Result, idCardData.Desc)
fmt.Printf("生日: %s\n", idCardData.Birthday)
fmt.Printf("性别: %s\n", idCardData.Sex)
fmt.Printf("地址: %s\n", idCardData.Address)
fmt.Printf("订单号: %s\n", idCardData.OrderNo)
if idCardData.Result == 1 {
fmt.Println("✅ 身份证信息验证通过")
} else {
fmt.Println("❌ 身份证信息验证失败")
}
}
// exampleOtherAPI 其他API调用示例
func exampleOtherAPI(service *AlicloudService) {
// 假设调用其他API
params := map[string]interface{}{
"param1": "value1",
"param2": "value2",
}
// 调用API
responseBody, err := service.CallAPI("other/api/path", params)
if err != nil {
log.Printf("API调用失败: %v", err)
return
}
// 根据具体API的响应结构进行解析
// 这里只是示例实际使用时需要根据API文档定义相应的结构体
fmt.Printf("API响应数据: %s\n", string(responseBody))
// 示例:解析通用响应结构
var genericData map[string]interface{}
if err := json.Unmarshal(responseBody, &genericData); err != nil {
log.Printf("响应解析失败: %v", err)
return
}
fmt.Printf("解析后的数据: %+v\n", genericData)
}
// ExampleErrorHandling 错误处理示例
func ExampleErrorHandling() {
host := "https://kzidcardv1.market.alicloudapi.com"
appCode := "您的AppCode"
service := NewAlicloudService(host, appCode)
// 测试各种错误情况
testCases := []struct {
name string
idCard string
desc string
}{
{"张三", "123456789", "无效身份证号"},
{"", "110101199001011234", "空姓名"},
{"张三", "", "空身份证号"},
}
for _, tc := range testCases {
fmt.Printf("\n测试: %s\n", tc.desc)
params := map[string]interface{}{
"name": tc.name,
"idcard": tc.idCard,
}
responseBody, err := service.CallAPI("api-mall/api/id_card/check", params)
if err != nil {
fmt.Printf("❌ 网络请求错误: %v\n", err)
continue
}
// 解析响应
var response struct {
Msg string `json:"msg"`
Code int `json:"code"`
Data interface{} `json:"data"`
}
if err := json.Unmarshal(responseBody, &response); err != nil {
fmt.Printf("❌ 响应解析失败: %v\n", err)
continue
}
if response.Code != 200 {
fmt.Printf("❌ 预期错误: code=%d, msg=%s\n", response.Code, response.Msg)
} else {
fmt.Printf("⚠️ 意外成功\n")
}
}
}

View File

@@ -0,0 +1,134 @@
package captcha
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
"time"
"github.com/alibabacloud-go/tea/tea"
captcha20230305 "github.com/alibabacloud-go/captcha-20230305/client"
openapi "github.com/alibabacloud-go/darabonba-openapi/v2/client"
)
var (
ErrCaptchaVerifyFailed = errors.New("图形验证码校验失败")
ErrCaptchaConfig = errors.New("验证码配置错误")
ErrCaptchaEncryptMissing = errors.New("加密模式需要配置 EncryptKey控制台 ekey")
)
// CaptchaConfig 阿里云验证码配置
type CaptchaConfig struct {
AccessKeyID string
AccessKeySecret string
EndpointURL string
SceneID string
// EncryptKey 加密模式使用的密钥(控制台 ekeyBase64 编码的 32 字节),用于生成 EncryptedSceneId
EncryptKey string
}
// CaptchaService 阿里云验证码服务
type CaptchaService struct {
config CaptchaConfig
}
// NewCaptchaService 创建验证码服务实例
func NewCaptchaService(config CaptchaConfig) *CaptchaService {
return &CaptchaService{
config: config,
}
}
// Verify 验证滑块验证码
func (s *CaptchaService) Verify(captchaVerifyParam string) error {
if captchaVerifyParam == "" {
return ErrCaptchaVerifyFailed
}
if s.config.AccessKeyID == "" || s.config.AccessKeySecret == "" {
return ErrCaptchaConfig
}
clientCfg := &openapi.Config{
AccessKeyId: tea.String(s.config.AccessKeyID),
AccessKeySecret: tea.String(s.config.AccessKeySecret),
}
clientCfg.Endpoint = tea.String(s.config.EndpointURL)
client, err := captcha20230305.NewClient(clientCfg)
if err != nil {
return errors.Join(ErrCaptchaConfig, err)
}
req := &captcha20230305.VerifyIntelligentCaptchaRequest{
SceneId: tea.String(s.config.SceneID),
CaptchaVerifyParam: tea.String(captchaVerifyParam),
}
resp, err := client.VerifyIntelligentCaptcha(req)
if err != nil {
return errors.Join(ErrCaptchaVerifyFailed, err)
}
if resp.Body == nil || !tea.BoolValue(resp.Body.Result.VerifyResult) {
return ErrCaptchaVerifyFailed
}
return nil
}
// GetEncryptedSceneId 生成加密场景 IDEncryptedSceneId供前端加密模式初始化验证码使用。
// 算法AES-256-CBC明文 sceneId&timestamp&expireTime密钥为控制台 ekeyBase64 解码后 32 字节)。
// expireTimeSec 有效期为 186400 秒。
func (s *CaptchaService) GetEncryptedSceneId(expireTimeSec int) (string, error) {
if expireTimeSec <= 0 || expireTimeSec > 86400 {
return "", fmt.Errorf("expireTimeSec 必须在 186400 之间")
}
if s.config.EncryptKey == "" {
return "", ErrCaptchaEncryptMissing
}
if s.config.SceneID == "" {
return "", ErrCaptchaConfig
}
keyBytes, err := base64.StdEncoding.DecodeString(s.config.EncryptKey)
if err != nil || len(keyBytes) != 32 {
return "", errors.Join(ErrCaptchaConfig, fmt.Errorf("EncryptKey 必须为 Base64 编码的 32 字节"))
}
plaintext := fmt.Sprintf("%s&%d&%d", s.config.SceneID, time.Now().Unix(), expireTimeSec)
plainBytes := []byte(plaintext)
plainBytes = pkcs7Pad(plainBytes, aes.BlockSize)
block, err := aes.NewCipher(keyBytes)
if err != nil {
return "", errors.Join(ErrCaptchaConfig, err)
}
iv := make([]byte, aes.BlockSize)
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return "", err
}
mode := cipher.NewCBCEncrypter(block, iv)
ciphertext := make([]byte, len(plainBytes))
mode.CryptBlocks(ciphertext, plainBytes)
result := make([]byte, len(iv)+len(ciphertext))
copy(result, iv)
copy(result[len(iv):], ciphertext)
return base64.StdEncoding.EncodeToString(result), nil
}
func pkcs7Pad(data []byte, blockSize int) []byte {
n := blockSize - (len(data) % blockSize)
pad := make([]byte, n)
for i := range pad {
pad[i] = byte(n)
}
return append(data, pad...)
}

View File

@@ -0,0 +1,712 @@
package email
import (
"context"
"crypto/tls"
"fmt"
"html/template"
"net"
"net/smtp"
"strings"
"time"
"go.uber.org/zap"
"hyapi-server/internal/config"
)
// QQEmailService QQ邮箱服务
type QQEmailService struct {
config config.EmailConfig
logger *zap.Logger
}
// EmailData 邮件数据
type EmailData struct {
To string `json:"to"`
Subject string `json:"subject"`
Content string `json:"content"`
Data map[string]interface{} `json:"data"`
}
// InvoiceEmailData 发票邮件数据
type InvoiceEmailData struct {
CompanyName string `json:"company_name"`
Amount string `json:"amount"`
InvoiceType string `json:"invoice_type"`
FileURL string `json:"file_url"`
FileName string `json:"file_name"`
ReceivingEmail string `json:"receiving_email"`
ApprovedAt string `json:"approved_at"`
}
// NewQQEmailService 创建QQ邮箱服务
func NewQQEmailService(config config.EmailConfig, logger *zap.Logger) *QQEmailService {
return &QQEmailService{
config: config,
logger: logger,
}
}
// SendEmail 发送邮件
func (s *QQEmailService) SendEmail(ctx context.Context, data *EmailData) error {
s.logger.Info("开始发送邮件",
zap.String("to", data.To),
zap.String("subject", data.Subject),
)
// 构建邮件内容
message := s.buildEmailMessage(data)
// 发送邮件
err := s.sendSMTP(data.To, data.Subject, message)
if err != nil {
s.logger.Error("发送邮件失败",
zap.String("to", data.To),
zap.String("subject", data.Subject),
zap.Error(err),
)
return fmt.Errorf("发送邮件失败: %w", err)
}
s.logger.Info("邮件发送成功",
zap.String("to", data.To),
zap.String("subject", data.Subject),
)
return nil
}
// SendInvoiceEmail 发送发票邮件
func (s *QQEmailService) SendInvoiceEmail(ctx context.Context, data *InvoiceEmailData) error {
s.logger.Info("开始发送发票邮件",
zap.String("to", data.ReceivingEmail),
zap.String("company_name", data.CompanyName),
zap.String("amount", data.Amount),
)
// 构建邮件内容
subject := "您的发票已开具成功"
content := s.buildInvoiceEmailContent(data)
emailData := &EmailData{
To: data.ReceivingEmail,
Subject: subject,
Content: content,
Data: map[string]interface{}{
"company_name": data.CompanyName,
"amount": data.Amount,
"invoice_type": data.InvoiceType,
"file_url": data.FileURL,
"file_name": data.FileName,
"approved_at": data.ApprovedAt,
},
}
return s.SendEmail(ctx, emailData)
}
// buildEmailMessage 构建邮件消息
func (s *QQEmailService) buildEmailMessage(data *EmailData) string {
headers := make(map[string]string)
headers["From"] = s.config.FromEmail
headers["To"] = data.To
headers["Subject"] = data.Subject
headers["MIME-Version"] = "1.0"
headers["Content-Type"] = "text/html; charset=UTF-8"
var message strings.Builder
for key, value := range headers {
message.WriteString(fmt.Sprintf("%s: %s\r\n", key, value))
}
message.WriteString("\r\n")
message.WriteString(data.Content)
return message.String()
}
// buildInvoiceEmailContent 构建发票邮件内容
func (s *QQEmailService) buildInvoiceEmailContent(data *InvoiceEmailData) string {
htmlTemplate := `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<title>发票开具成功通知</title>
<style>
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Inter', 'Microsoft YaHei', Arial, sans-serif;
line-height: 1.6;
color: #2d3748;
background: linear-gradient(135deg, #f7fafc 0%, #edf2f7 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 650px;
margin: 0 auto;
background: #ffffff;
border-radius: 24px;
box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.08);
overflow: hidden;
border: 1px solid #e2e8f0;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 50px 40px 40px;
text-align: center;
position: relative;
overflow: hidden;
}
.header::before {
content: '';
position: absolute;
top: -50%;
left: -50%;
width: 200%;
height: 200%;
background: radial-gradient(circle, rgba(255,255,255,0.08) 0%, transparent 70%);
animation: float 6s ease-in-out infinite;
}
@keyframes float {
0%, 100% { transform: translateY(0px) rotate(0deg); }
50% { transform: translateY(-20px) rotate(180deg); }
}
.success-icon {
font-size: 48px;
margin-bottom: 16px;
position: relative;
z-index: 1;
opacity: 0.9;
}
.header h1 {
font-size: 24px;
font-weight: 500;
margin: 0;
position: relative;
z-index: 1;
letter-spacing: 0.5px;
}
.content {
padding: 0;
}
.greeting {
padding: 40px 40px 20px;
text-align: center;
background: linear-gradient(135deg, #f8fafc 0%, #ffffff 100%);
}
.greeting p {
font-size: 16px;
color: #4a5568;
margin-bottom: 8px;
font-weight: 400;
}
.access-section {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 40px;
text-align: center;
position: relative;
overflow: hidden;
margin: 0 20px 30px;
border-radius: 20px;
}
.access-section::before {
content: '';
position: absolute;
top: -50%;
left: -50%;
width: 200%;
height: 200%;
background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%);
animation: shimmer 8s ease-in-out infinite;
}
@keyframes shimmer {
0%, 100% { transform: translateX(-100%) translateY(-100%) rotate(0deg); }
50% { transform: translateX(100%) translateY(100%) rotate(180deg); }
}
.access-section h3 {
color: white;
font-size: 22px;
font-weight: 600;
margin-bottom: 12px;
position: relative;
z-index: 1;
}
.access-section p {
color: rgba(255, 255, 255, 0.9);
margin-bottom: 25px;
position: relative;
z-index: 1;
font-size: 15px;
}
.access-btn {
display: inline-block;
background: rgba(255, 255, 255, 0.15);
color: white;
padding: 16px 32px;
text-decoration: none;
border-radius: 50px;
font-weight: 600;
font-size: 15px;
border: 2px solid rgba(255, 255, 255, 0.2);
transition: all 0.3s ease;
position: relative;
z-index: 1;
backdrop-filter: blur(10px);
letter-spacing: 0.3px;
}
.access-btn:hover {
background: rgba(255, 255, 255, 0.25);
transform: translateY(-3px);
box-shadow: 0 15px 35px rgba(0, 0, 0, 0.15);
border-color: rgba(255, 255, 255, 0.3);
}
.info-section {
padding: 0 40px 40px;
}
.info-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 20px;
margin: 30px 0;
}
.info-item {
background: linear-gradient(135deg, #f8fafc 0%, #ffffff 100%);
padding: 24px;
border-radius: 16px;
border: 1px solid #e2e8f0;
position: relative;
overflow: hidden;
transition: all 0.3s ease;
box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.05);
}
.info-item:hover {
transform: translateY(-2px);
box-shadow: 0 12px 25px -5px rgba(102, 126, 234, 0.1);
border-color: #cbd5e0;
}
.info-item::before {
content: '';
position: absolute;
top: 0;
left: 0;
width: 4px;
height: 100%;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border-radius: 0 2px 2px 0;
}
.info-label {
font-weight: 600;
color: #718096;
display: block;
margin-bottom: 8px;
font-size: 13px;
text-transform: uppercase;
letter-spacing: 0.8px;
}
.info-value {
color: #2d3748;
font-size: 16px;
font-weight: 500;
position: relative;
z-index: 1;
}
.notes-section {
background: linear-gradient(135deg, #f0fff4 0%, #ffffff 100%);
padding: 30px;
border-radius: 16px;
margin: 30px 0;
border: 1px solid #c6f6d5;
position: relative;
}
.notes-section::before {
content: '';
position: absolute;
top: 0;
left: 0;
width: 4px;
height: 100%;
background: linear-gradient(135deg, #48bb78 0%, #38a169 100%);
border-radius: 0 2px 2px 0;
}
.notes-section h4 {
color: #2f855a;
font-size: 16px;
font-weight: 600;
margin-bottom: 16px;
display: flex;
align-items: center;
}
.notes-section h4::before {
content: '📋';
margin-right: 8px;
font-size: 18px;
}
.notes-section ul {
list-style: none;
padding: 0;
}
.notes-section li {
color: #4a5568;
margin-bottom: 10px;
padding-left: 24px;
position: relative;
font-size: 14px;
}
.notes-section li::before {
content: '✓';
color: #48bb78;
font-weight: bold;
position: absolute;
left: 0;
font-size: 16px;
}
.footer {
background: linear-gradient(135deg, #2d3748 0%, #1a202c 100%);
color: rgba(255, 255, 255, 0.8);
padding: 35px 40px;
text-align: center;
font-size: 14px;
}
.footer p {
margin-bottom: 8px;
line-height: 1.5;
}
.footer-divider {
width: 60px;
height: 2px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
margin: 20px auto;
border-radius: 1px;
}
@media (max-width: 600px) {
.container {
margin: 10px;
border-radius: 20px;
}
.header {
padding: 40px 30px 30px;
}
.greeting {
padding: 30px 30px 20px;
}
.access-section {
margin: 0 15px 25px;
padding: 30px 25px;
}
.info-section {
padding: 0 30px 30px;
}
.info-grid {
grid-template-columns: 1fr;
gap: 16px;
}
.footer {
padding: 30px 30px;
}
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<div class="success-icon">✓</div>
<h1>发票已开具完成</h1>
</div>
<div class="content">
<div class="greeting">
<p>尊敬的用户,您好!</p>
<p>您的发票申请已审核通过,发票已成功开具。</p>
</div>
<div class="access-section">
<h3>📄 发票访问链接</h3>
<p>您的发票已准备就绪,请点击下方按钮访问查看页面</p>
<a href="{{.FileURL}}" class="access-btn" target="_blank">
🔗 访问发票页面
</a>
</div>
<div class="info-section">
<div class="info-grid">
<div class="info-item">
<span class="info-label">公司名称</span>
<span class="info-value">{{.CompanyName}}</span>
</div>
<div class="info-item">
<span class="info-label">发票金额</span>
<span class="info-value">¥{{.Amount}}</span>
</div>
<div class="info-item">
<span class="info-label">发票类型</span>
<span class="info-value">{{.InvoiceType}}</span>
</div>
<div class="info-item">
<span class="info-label">开具时间</span>
<span class="info-value">{{.ApprovedAt}}</span>
</div>
</div>
<div class="notes-section">
<h4>注意事项</h4>
<ul>
<li>访问页面后可在页面内下载发票文件</li>
<li>请妥善保管发票文件,建议打印存档</li>
<li>如有疑问,请回到我们平台进行下载</li>
</ul>
</div>
</div>
</div>
<div class="footer">
<p>此邮件由系统自动发送,请勿回复</p>
<div class="footer-divider"></div>
<p>海宇数据 API 服务平台</p>
<p>发送时间:{{.CurrentTime}}</p>
</div>
</div>
</body>
</html>`
// 解析模板
tmpl, err := template.New("invoice_email").Parse(htmlTemplate)
if err != nil {
s.logger.Error("解析邮件模板失败", zap.Error(err))
return s.buildSimpleInvoiceEmail(data)
}
// 准备模板数据
templateData := struct {
CompanyName string
Amount string
InvoiceType string
FileURL string
FileName string
ApprovedAt string
CurrentTime string
Domain string
}{
CompanyName: data.CompanyName,
Amount: data.Amount,
InvoiceType: data.InvoiceType,
FileURL: data.FileURL,
FileName: data.FileName,
ApprovedAt: data.ApprovedAt,
CurrentTime: time.Now().Format("2006-01-02 15:04:05"),
Domain: s.config.Domain,
}
// 执行模板
var content strings.Builder
err = tmpl.Execute(&content, templateData)
if err != nil {
s.logger.Error("执行邮件模板失败", zap.Error(err))
return s.buildSimpleInvoiceEmail(data)
}
return content.String()
}
// buildSimpleInvoiceEmail 构建简单的发票邮件内容(备用方案)
func (s *QQEmailService) buildSimpleInvoiceEmail(data *InvoiceEmailData) string {
return fmt.Sprintf(`
发票开具成功通知
尊敬的用户,您好!
您的发票申请已审核通过,发票已成功开具。
发票信息:
- 公司名称:%s
- 发票金额:¥%s
- 发票类型:%s
- 开具时间:%s
发票文件下载链接:%s
文件名:%s
如有疑问请访问控制台查看详细信息https://%s
海宇数据 API 服务平台
%s
`, data.CompanyName, data.Amount, data.InvoiceType, data.ApprovedAt, data.FileURL, data.FileName, s.config.Domain, time.Now().Format("2006-01-02 15:04:05"))
}
// sendSMTP 通过SMTP发送邮件
func (s *QQEmailService) sendSMTP(to, subject, message string) error {
// 构建认证信息
auth := smtp.PlainAuth("", s.config.Username, s.config.Password, s.config.Host)
// 构建收件人列表
toList := []string{to}
// 发送邮件
if s.config.UseSSL {
// QQ邮箱587端口使用STARTTLS465端口使用直接SSL
if s.config.Port == 587 {
// 使用STARTTLS (587端口)
conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port))
if err != nil {
return fmt.Errorf("连接SMTP服务器失败: %w", err)
}
defer conn.Close()
client, err := smtp.NewClient(conn, s.config.Host)
if err != nil {
return fmt.Errorf("创建SMTP客户端失败: %w", err)
}
defer client.Close()
// 启用STARTTLS
if err = client.StartTLS(&tls.Config{
ServerName: s.config.Host,
InsecureSkipVerify: false,
}); err != nil {
return fmt.Errorf("启用STARTTLS失败: %w", err)
}
// 认证
if err = client.Auth(auth); err != nil {
return fmt.Errorf("SMTP认证失败: %w", err)
}
// 设置发件人
if err = client.Mail(s.config.FromEmail); err != nil {
return fmt.Errorf("设置发件人失败: %w", err)
}
// 设置收件人
for _, recipient := range toList {
if err = client.Rcpt(recipient); err != nil {
return fmt.Errorf("设置收件人失败: %w", err)
}
}
// 发送邮件内容
writer, err := client.Data()
if err != nil {
return fmt.Errorf("准备发送邮件内容失败: %w", err)
}
defer writer.Close()
_, err = writer.Write([]byte(message))
if err != nil {
return fmt.Errorf("发送邮件内容失败: %w", err)
}
} else {
// 使用直接SSL连接 (465端口)
tlsConfig := &tls.Config{
ServerName: s.config.Host,
InsecureSkipVerify: false,
}
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", s.config.Host, s.config.Port), tlsConfig)
if err != nil {
return fmt.Errorf("连接SMTP服务器失败: %w", err)
}
defer conn.Close()
client, err := smtp.NewClient(conn, s.config.Host)
if err != nil {
return fmt.Errorf("创建SMTP客户端失败: %w", err)
}
defer client.Close()
// 认证
if err = client.Auth(auth); err != nil {
return fmt.Errorf("SMTP认证失败: %w", err)
}
// 设置发件人
if err = client.Mail(s.config.FromEmail); err != nil {
return fmt.Errorf("设置发件人失败: %w", err)
}
// 设置收件人
for _, recipient := range toList {
if err = client.Rcpt(recipient); err != nil {
return fmt.Errorf("设置收件人失败: %w", err)
}
}
// 发送邮件内容
writer, err := client.Data()
if err != nil {
return fmt.Errorf("准备发送邮件内容失败: %w", err)
}
defer writer.Close()
_, err = writer.Write([]byte(message))
if err != nil {
return fmt.Errorf("发送邮件内容失败: %w", err)
}
}
} else {
// 使用普通连接
err := smtp.SendMail(
fmt.Sprintf("%s:%d", s.config.Host, s.config.Port),
auth,
s.config.FromEmail,
toList,
[]byte(message),
)
if err != nil {
return fmt.Errorf("发送邮件失败: %w", err)
}
}
return nil
}

View File

@@ -0,0 +1,301 @@
package esign
import (
"context"
"fmt"
"time"
"go.uber.org/zap"
"hyapi-server/internal/domains/certification/entities/value_objects"
"hyapi-server/internal/domains/certification/enums"
"hyapi-server/internal/domains/certification/repositories"
"hyapi-server/internal/shared/esign"
)
// ================ 常量定义 ================
const (
// 企业认证超时时间
EnterpriseAuthTimeout = 30 * time.Minute
// 合同签署超时时间
ContractSignTimeout = 7 * 24 * time.Hour // 7天
// 回调重试次数
MaxCallbackRetries = 3
)
// ================ 服务实现 ================
// CertificationEsignService 认证e签宝服务实现
//
// 业务职责:
// - 处理企业认证流程
// - 处理合同生成和签署
// - 处理e签宝回调
// - 管理认证状态更新
type CertificationEsignService struct {
esignClient *esign.Client
commandRepo repositories.CertificationCommandRepository
queryRepo repositories.CertificationQueryRepository
logger *zap.Logger
}
// NewCertificationEsignService 创建认证e签宝服务
func NewCertificationEsignService(
esignClient *esign.Client,
commandRepo repositories.CertificationCommandRepository,
queryRepo repositories.CertificationQueryRepository,
logger *zap.Logger,
) *CertificationEsignService {
return &CertificationEsignService{
esignClient: esignClient,
commandRepo: commandRepo,
queryRepo: queryRepo,
logger: logger,
}
}
// ================ 企业认证流程 ================
// StartEnterpriseAuth 开始企业认证
//
// 业务流程:
// 1. 调用e签宝企业认证API
// 2. 更新认证记录的auth_flow_id
// 3. 更新状态为企业认证中
//
// 参数:
// - ctx: 上下文
// - certificationID: 认证ID
// - enterpriseInfo: 企业信息
//
// 返回:
// - authURL: 认证URL
// - error: 错误信息
func (s *CertificationEsignService) StartEnterpriseAuth(
ctx context.Context,
certificationID string,
enterpriseInfo *value_objects.EnterpriseInfo,
) (string, error) {
s.logger.Info("开始企业认证",
zap.String("certification_id", certificationID),
zap.String("company_name", enterpriseInfo.CompanyName))
// TODO: 实现e签宝企业认证API调用
// 暂时使用模拟响应
authFlowID := fmt.Sprintf("auth_%s_%d", certificationID, time.Now().Unix())
authURL := fmt.Sprintf("https://esign.example.com/auth/%s", authFlowID)
s.logger.Info("模拟调用e签宝企业认证API",
zap.String("auth_flow_id", authFlowID),
zap.String("auth_url", authURL))
// 更新认证记录
if err := s.commandRepo.UpdateAuthFlowID(ctx, certificationID, authFlowID); err != nil {
s.logger.Error("更新认证流程ID失败", zap.Error(err))
return "", fmt.Errorf("更新认证流程ID失败: %w", err)
}
s.logger.Info("企业认证启动成功",
zap.String("certification_id", certificationID),
zap.String("auth_flow_id", authFlowID))
return authURL, nil
}
// HandleEnterpriseAuthCallback 处理企业认证回调
//
// 业务流程:
// 1. 根据回调信息查找认证记录
// 2. 根据回调状态更新认证状态
// 3. 如果成功,继续合同生成流程
//
// 参数:
// - ctx: 上下文
// - authFlowID: 认证流程ID
// - success: 是否成功
// - message: 回调消息
//
// 返回:
// - error: 错误信息
func (s *CertificationEsignService) HandleEnterpriseAuthCallback(
ctx context.Context,
authFlowID string,
success bool,
message string,
) error {
s.logger.Info("处理企业认证回调",
zap.String("auth_flow_id", authFlowID),
zap.Bool("success", success))
// 查找认证记录
cert, err := s.queryRepo.FindByAuthFlowID(ctx, authFlowID)
if err != nil {
s.logger.Error("根据认证流程ID查找认证记录失败", zap.Error(err))
return fmt.Errorf("查找认证记录失败: %w", err)
}
if success {
// 企业认证成功,更新状态
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusEnterpriseVerified); err != nil {
s.logger.Error("更新认证状态失败", zap.Error(err))
return fmt.Errorf("更新认证状态失败: %w", err)
}
s.logger.Info("企业认证成功", zap.String("certification_id", cert.ID))
} else {
// 企业认证失败,更新状态
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusInfoRejected); err != nil {
s.logger.Error("更新认证状态失败", zap.Error(err))
return fmt.Errorf("更新认证状态失败: %w", err)
}
s.logger.Info("企业认证失败", zap.String("certification_id", cert.ID), zap.String("reason", message))
}
return nil
}
// ================ 合同管理流程 ================
// GenerateContract 生成认证合同
//
// 业务流程:
// 1. 调用e签宝合同生成API
// 2. 更新认证记录的合同信息
// 3. 更新状态为合同已生成
//
// 参数:
// - ctx: 上下文
// - certificationID: 认证ID
//
// 返回:
// - contractSignURL: 合同签署URL
// - error: 错误信息
func (s *CertificationEsignService) GenerateContract(
ctx context.Context,
certificationID string,
) (string, error) {
s.logger.Info("生成认证合同", zap.String("certification_id", certificationID))
// TODO: 实现e签宝合同生成API调用
// 暂时使用模拟响应
contractFileID := fmt.Sprintf("contract_%s_%d", certificationID, time.Now().Unix())
esignFlowID := fmt.Sprintf("flow_%s_%d", certificationID, time.Now().Unix())
contractURL := fmt.Sprintf("https://esign.example.com/contract/%s", contractFileID)
contractSignURL := fmt.Sprintf("https://esign.example.com/sign/%s", esignFlowID)
s.logger.Info("模拟调用e签宝合同生成API",
zap.String("contract_file_id", contractFileID),
zap.String("esign_flow_id", esignFlowID))
// 更新认证记录
if err := s.commandRepo.UpdateContractInfo(
ctx,
certificationID,
contractFileID,
esignFlowID,
contractURL,
contractSignURL,
); err != nil {
s.logger.Error("更新合同信息失败", zap.Error(err))
return "", fmt.Errorf("更新合同信息失败: %w", err)
}
// 更新状态
if err := s.commandRepo.UpdateStatus(ctx, certificationID, enums.StatusContractApplied); err != nil {
s.logger.Error("更新认证状态失败", zap.Error(err))
return "", fmt.Errorf("更新认证状态失败: %w", err)
}
s.logger.Info("认证合同生成成功",
zap.String("certification_id", certificationID),
zap.String("contract_file_id", contractFileID))
return contractSignURL, nil
}
// HandleContractSignCallback 处理合同签署回调
//
// 业务流程:
// 1. 根据回调信息查找认证记录
// 2. 根据回调状态更新认证状态
// 3. 如果成功,认证流程完成
//
// 参数:
// - ctx: 上下文
// - esignFlowID: e签宝流程ID
// - success: 是否成功
// - signedFileURL: 已签署文件URL
//
// 返回:
// - error: 错误信息
func (s *CertificationEsignService) HandleContractSignCallback(
ctx context.Context,
esignFlowID string,
success bool,
signedFileURL string,
) error {
s.logger.Info("处理合同签署回调",
zap.String("esign_flow_id", esignFlowID),
zap.Bool("success", success))
// 查找认证记录
cert, err := s.queryRepo.FindByEsignFlowID(ctx, esignFlowID)
if err != nil {
s.logger.Error("根据e签宝流程ID查找认证记录失败", zap.Error(err))
return fmt.Errorf("查找认证记录失败: %w", err)
}
if success {
// 合同签署成功更新合同URL
if err := s.commandRepo.UpdateContractInfo(ctx, cert.ID, cert.ContractFileID, cert.EsignFlowID, signedFileURL, cert.ContractSignURL); err != nil {
s.logger.Error("更新合同URL失败", zap.Error(err))
return fmt.Errorf("更新合同URL失败: %w", err)
}
// 更新状态到合同已签署
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusContractSigned); err != nil {
s.logger.Error("更新认证状态失败", zap.Error(err))
return fmt.Errorf("更新认证状态失败: %w", err)
}
s.logger.Info("合同签署成功", zap.String("certification_id", cert.ID))
} else {
// 合同签署失败
if err := s.commandRepo.UpdateStatus(ctx, cert.ID, enums.StatusContractRejected); err != nil {
s.logger.Error("更新认证状态失败", zap.Error(err))
return fmt.Errorf("更新认证状态失败: %w", err)
}
s.logger.Info("合同签署失败", zap.String("certification_id", cert.ID))
}
return nil
}
// ================ 辅助方法 ================
// GetContractSignURL 获取合同签署URL
//
// 参数:
// - ctx: 上下文
// - certificationID: 认证ID
//
// 返回:
// - signURL: 签署URL
// - error: 错误信息
func (s *CertificationEsignService) GetContractSignURL(ctx context.Context, certificationID string) (string, error) {
cert, err := s.queryRepo.GetByID(ctx, certificationID)
if err != nil {
return "", fmt.Errorf("获取认证信息失败: %w", err)
}
if cert.ContractSignURL == "" {
return "", fmt.Errorf("合同签署URL尚未生成")
}
return cert.ContractSignURL, nil
}

View File

@@ -0,0 +1,48 @@
package jiguang
import (
"crypto/hmac"
"crypto/md5"
"encoding/hex"
"fmt"
"strings"
)
// SignMethod 签名方法类型
type SignMethod string
const (
SignMethodMD5 SignMethod = "md5"
SignMethodHMACMD5 SignMethod = "hmac"
)
// GenerateSign 生成签名
// 根据 signMethod 参数选择使用 MD5 或 HMAC-MD5 算法
// MD5: md5(timestamp + "&appSecret=" + appSecret),然后转大写十六进制
// HMAC-MD5: hmac_md5(timestamp, appSecret),然后转大写十六进制
func GenerateSign(timestamp string, appSecret string, signMethod SignMethod) (string, error) {
var hashBytes []byte
switch signMethod {
case SignMethodMD5:
// MD5算法在待签名字符串后面加上 &appSecret=xxx 再进行计算
signStr := timestamp + "&appSecret=" + appSecret
hash := md5.Sum([]byte(signStr))
hashBytes = hash[:]
case SignMethodHMACMD5:
// HMAC-MD5算法使用 appSecret 初始化摘要算法再进行计算
mac := hmac.New(md5.New, []byte(appSecret))
mac.Write([]byte(timestamp))
hashBytes = mac.Sum(nil)
default:
return "", fmt.Errorf("不支持的签名方法: %s", signMethod)
}
// 将二进制转化为大写的十六进制正确签名应该为32大写字符串
return strings.ToUpper(hex.EncodeToString(hashBytes)), nil
}
// GenerateSignWithDefault 使用默认的 HMAC-MD5 方法生成签名
func GenerateSignWithDefault(timestamp string, appSecret string) (string, error) {
return GenerateSign(timestamp, appSecret, SignMethodHMACMD5)
}

View File

@@ -0,0 +1,149 @@
package jiguang
import (
"fmt"
)
// JiguangError 极光服务错误
type JiguangError struct {
Code int `json:"code"`
Message string `json:"message"`
}
// Error 实现error接口
func (e *JiguangError) Error() string {
return fmt.Sprintf("极光错误 [%d]: %s", e.Code, e.Message)
}
// IsSuccess 检查是否成功
func (e *JiguangError) IsSuccess() bool {
return e.Code == 0
}
// IsQueryFailed 检查是否查询失败
func (e *JiguangError) IsQueryFailed() bool {
return e.Code == 922
}
// IsNoRecord 检查是否查无记录
func (e *JiguangError) IsNoRecord() bool {
return e.Code == 921
}
// IsParamError 检查是否是参数相关错误
func (e *JiguangError) IsParamError() bool {
return e.Code == 400 || e.Code == 906 || e.Code == 914 || e.Code == 918
}
// IsAuthError 检查是否是认证相关错误
func (e *JiguangError) IsAuthError() bool {
return e.Code == 902 || e.Code == 903 || e.Code == 904 || e.Code == 905
}
// IsSystemError 检查是否是系统错误
func (e *JiguangError) IsSystemError() bool {
return e.Code == 405 || e.Code == 911 || e.Code == 912 || e.Code == 915 || e.Code == 916 || e.Code == 917 || e.Code == 919 || e.Code == 923
}
// 预定义错误常量
var (
// 成功状态
ErrSuccess = &JiguangError{Code: 0, Message: "请求成功"}
// 参数错误
ErrParamInvalid = &JiguangError{Code: 400, Message: "请求参数不正确QCXGGB2Q查询为空"}
ErrMethodInvalid = &JiguangError{Code: 405, Message: "请求方法不正确"}
ErrParamFormInvalid = &JiguangError{Code: 906, Message: "请求参数形式不正确"}
ErrBodyIncomplete = &JiguangError{Code: 914, Message: "Body 请求参数不完整"}
ErrBodyNotSupported = &JiguangError{Code: 918, Message: "Body 请求参数不支持"}
// 认证错误
ErrAppIDInvalid = &JiguangError{Code: 902, Message: "错误的 appId/账户已删除"}
ErrTimestampInvalid = &JiguangError{Code: 903, Message: "错误的时间戳/时间误差大于 10 分钟"}
ErrSignMethodInvalid = &JiguangError{Code: 904, Message: "无法识别的签名方法"}
ErrSignInvalid = &JiguangError{Code: 905, Message: "签名不合法"}
// 系统错误
ErrAccountStatusError = &JiguangError{Code: 911, Message: "账户状态异常"}
ErrInterfaceDisabled = &JiguangError{Code: 912, Message: "接口状态不可用"}
ErrAPICallError = &JiguangError{Code: 915, Message: "API 接口调用有误"}
ErrInternalError = &JiguangError{Code: 916, Message: "内部接口调用错误,请联系相关人员"}
ErrTimeout = &JiguangError{Code: 917, Message: "请求超时"}
ErrBusinessDisabled = &JiguangError{Code: 919, Message: "业务状态不可用"}
ErrInterfaceException = &JiguangError{Code: 923, Message: "接口异常"}
// 业务错误
ErrNoRecord = &JiguangError{Code: 921, Message: "查无记录"}
ErrQueryFailed = &JiguangError{Code: 922, Message: "查询失败"}
)
// NewJiguangError 创建新的极光错误
func NewJiguangError(code int, message string) *JiguangError {
return &JiguangError{
Code: code,
Message: message,
}
}
// NewJiguangErrorFromCode 根据状态码创建错误
func NewJiguangErrorFromCode(code int) *JiguangError {
switch code {
case 0:
return ErrSuccess
case 400:
return ErrParamInvalid
case 405:
return ErrMethodInvalid
case 902:
return ErrAppIDInvalid
case 903:
return ErrTimestampInvalid
case 904:
return ErrSignMethodInvalid
case 905:
return ErrSignInvalid
case 906:
return ErrParamFormInvalid
case 911:
return ErrAccountStatusError
case 912:
return ErrInterfaceDisabled
case 914:
return ErrBodyIncomplete
case 915:
return ErrAPICallError
case 916:
return ErrInternalError
case 917:
return ErrTimeout
case 918:
return ErrBodyNotSupported
case 919:
return ErrBusinessDisabled
case 921:
return ErrNoRecord
case 922:
return ErrQueryFailed
case 923:
return ErrInterfaceException
default:
return &JiguangError{
Code: code,
Message: fmt.Sprintf("未知错误码: %d", code),
}
}
}
// IsJiguangError 检查是否是极光错误
func IsJiguangError(err error) bool {
_, ok := err.(*JiguangError)
return ok
}
// GetJiguangError 获取极光错误
func GetJiguangError(err error) *JiguangError {
if jiguangErr, ok := err.(*JiguangError); ok {
return jiguangErr
}
return nil
}

View File

@@ -0,0 +1,85 @@
package jiguang
import (
"time"
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewJiguangServiceWithConfig 使用配置创建极光服务
func NewJiguangServiceWithConfig(cfg *config.Config) (*JiguangService, error) {
// 将配置类型转换为通用外部服务日志配置
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.Jiguang.Logging.Enabled,
LogDir: cfg.Jiguang.Logging.LogDir,
ServiceName: "jiguang",
UseDaily: cfg.Jiguang.Logging.UseDaily,
EnableLevelSeparation: cfg.Jiguang.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
// 转换级别配置
for key, value := range cfg.Jiguang.Logging.LevelConfigs {
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: value.MaxSize,
MaxBackups: value.MaxBackups,
MaxAge: value.MaxAge,
Compress: value.Compress,
}
}
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 解析签名方法
var signMethod SignMethod
if cfg.Jiguang.SignMethod == "md5" {
signMethod = SignMethodMD5
} else {
signMethod = SignMethodHMACMD5 // 默认使用 HMAC-MD5
}
// 解析超时时间
timeout := 60 * time.Second
if cfg.Jiguang.Timeout > 0 {
timeout = cfg.Jiguang.Timeout
}
// 创建极光服务
service := NewJiguangService(
cfg.Jiguang.URL,
cfg.Jiguang.AppID,
cfg.Jiguang.AppSecret,
signMethod,
timeout,
logger,
)
return service, nil
}
// NewJiguangServiceWithLogging 使用自定义日志配置创建极光服务
func NewJiguangServiceWithLogging(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig) (*JiguangService, error) {
// 设置服务名称
loggingConfig.ServiceName = "jiguang"
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建极光服务
service := NewJiguangService(url, appID, appSecret, signMethod, timeout, logger)
return service, nil
}
// NewJiguangServiceSimple 创建简单的极光服务(无日志)
func NewJiguangServiceSimple(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration) *JiguangService {
return NewJiguangService(url, appID, appSecret, signMethod, timeout, nil)
}

View File

@@ -0,0 +1,316 @@
package jiguang
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"hyapi-server/internal/shared/external_logger"
)
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
ErrNotFound = errors.New("查询为空")
)
// JiguangResponse 极光API响应结构兼容两套字段命名
//
// 格式一ordernum、message、result定位/查询类接口常见)
// 格式二order_id、msg、data文档中的 code/msg/order_id/data
type JiguangResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Message string `json:"message"`
OrderID string `json:"order_id"`
OrderNum string `json:"ordernum"`
Data interface{} `json:"data"`
Result interface{} `json:"result"`
}
// normalize 将异名字段合并到 OrderID、Msg便于后续统一分支使用
func (r *JiguangResponse) normalize() {
if r == nil {
return
}
if r.OrderID == "" && r.OrderNum != "" {
r.OrderID = r.OrderNum
}
if r.Msg == "" && r.Message != "" {
r.Msg = r.Message
}
}
// JiguangConfig 极光服务配置
type JiguangConfig struct {
URL string
AppID string
AppSecret string
SignMethod SignMethod // 签名方法md5 或 hmac
Timeout time.Duration
}
// JiguangService 极光服务
type JiguangService struct {
config JiguangConfig
logger *external_logger.ExternalServiceLogger
}
// NewJiguangService 创建一个新的极光服务实例
func NewJiguangService(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *JiguangService {
// 如果没有指定签名方法,默认使用 HMAC-MD5
if signMethod == "" {
signMethod = SignMethodHMACMD5
}
// 如果没有指定超时时间,默认使用 60 秒
if timeout == 0 {
timeout = 60 * time.Second
}
return &JiguangService{
config: JiguangConfig{
URL: url,
AppID: appID,
AppSecret: appSecret,
SignMethod: signMethod,
Timeout: timeout,
},
logger: logger,
}
}
// generateRequestID 生成请求ID
func (j *JiguangService) generateRequestID() string {
timestamp := time.Now().UnixNano()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, j.config.AppID)))
return fmt.Sprintf("jiguang_%x", hash[:8])
}
// CallAPI 调用极光API
// apiCode: API服务编码如 marriage-single-v2用于请求头
// apiPath: API路径如 marriage/single-v2用于URL路径
// params: 请求参数会作为JSON body发送
func (j *JiguangService) CallAPI(ctx context.Context, apiCode string, apiPath string, params map[string]interface{}) (resp []byte, err error) {
startTime := time.Now()
requestID := j.generateRequestID()
// 生成时间戳(毫秒)
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
// 从ctx中获取transactionId
var transactionID string
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
transactionID = ctxTransactionID
}
// 生成签名
sign, signErr := GenerateSign(timestamp, j.config.AppSecret, j.config.SignMethod)
if signErr != nil {
err = errors.Join(ErrSystem, fmt.Errorf("生成签名失败: %w", signErr))
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, err, params)
}
return nil, err
}
// 构建完整的请求URL使用apiPath作为路径
requestURL := strings.TrimSuffix(j.config.URL, "/") + "/" + strings.TrimPrefix(apiPath, "/")
// 记录请求日志
if j.logger != nil {
j.logger.LogRequest(requestID, transactionID, apiCode, requestURL)
}
// 将请求参数转换为JSON
jsonData, marshalErr := json.Marshal(params)
if marshalErr != nil {
err = errors.Join(ErrSystem, marshalErr)
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, err, params)
}
return nil, err
}
// 创建HTTP POST请求
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", requestURL, bytes.NewBuffer(jsonData))
if newRequestErr != nil {
err = errors.Join(ErrSystem, newRequestErr)
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, err, params)
}
return nil, err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("appId", j.config.AppID)
req.Header.Set("apiCode", apiCode)
req.Header.Set("timestamp", timestamp)
req.Header.Set("signMethod", string(j.config.SignMethod))
req.Header.Set("sign", sign)
// 创建HTTP客户端
client := &http.Client{
Timeout: j.config.Timeout,
}
// 发送请求
httpResp, clientDoErr := client.Do(req)
if clientDoErr != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
} else {
err = errors.Join(ErrSystem, clientDoErr)
}
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, err, params)
}
return nil, err
}
defer func(Body io.ReadCloser) {
closeErr := Body.Close()
if closeErr != nil {
// 记录关闭错误
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), params)
}
}
}(httpResp.Body)
// 计算请求耗时
duration := time.Since(startTime)
// 读取响应体
bodyBytes, readErr := io.ReadAll(httpResp.Body)
if readErr != nil {
err = errors.Join(ErrSystem, readErr)
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, err, params)
}
return nil, err
}
// 检查HTTP状态码
if httpResp.StatusCode != http.StatusOK {
err = errors.Join(ErrSystem, fmt.Errorf("极光请求失败,状态码: %d", httpResp.StatusCode))
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, err, params)
}
return nil, err
}
// 解析响应结构
var jiguangResp JiguangResponse
if err := json.Unmarshal(bodyBytes, &jiguangResp); err != nil {
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
if j.logger != nil {
j.logger.LogError(requestID, transactionID, apiCode, err, params)
}
return nil, err
}
jiguangResp.normalize()
// 记录响应日志(不记录具体响应数据)
if j.logger != nil {
if jiguangResp.OrderID != "" {
j.logger.LogResponseWithID(requestID, transactionID, apiCode, httpResp.StatusCode, duration, jiguangResp.OrderID)
} else {
j.logger.LogResponse(requestID, transactionID, apiCode, httpResp.StatusCode, duration)
}
}
// 检查业务状态码
if jiguangResp.Code != 0 && jiguangResp.Code != 200 {
// 创建极光错误
jiguangErr := NewJiguangErrorFromCode(jiguangResp.Code)
if jiguangErr.Message == fmt.Sprintf("未知错误码: %d", jiguangResp.Code) {
if jiguangResp.Msg != "" {
jiguangErr.Message = jiguangResp.Msg
} else if jiguangResp.Message != "" {
jiguangErr.Message = jiguangResp.Message
}
}
// 根据错误类型返回不同的错误
if jiguangErr.IsNoRecord() {
// 从context中获取apiCode判断是否需要抛出异常
var processorCode string
if ctxProcessorCode, ok := ctx.Value("api_code").(string); ok {
processorCode = ctxProcessorCode
}
// 定义不需要抛出异常的处理器列表(默认情况下查无记录时抛出异常)
processorsNotToThrowError := map[string]bool{
// 在这个列表中的处理器,查无记录时返回空数组,不抛出异常
// 示例:如果需要添加某个处理器,取消下面的注释
// "QCXG9P1C": true,
}
// 如果是不需要抛出异常的处理器,返回空数组;否则(默认)抛出异常
if processorsNotToThrowError[processorCode] {
// 查无记录时返回空数组API调用记录为成功
return []byte("[]"), nil
}
// 记录错误日志
if j.logger != nil {
j.logger.LogErrorWithResponseID(requestID, transactionID, apiCode, jiguangErr, params, jiguangResp.OrderID)
}
return nil, errors.Join(ErrNotFound, jiguangErr)
}
// 记录错误日志(查无记录的情况不记录错误日志)
if j.logger != nil {
j.logger.LogErrorWithResponseID(requestID, transactionID, apiCode, jiguangErr, params, jiguangResp.OrderID)
}
if jiguangErr.IsQueryFailed() {
return nil, errors.Join(ErrDatasource, jiguangErr)
} else if jiguangErr.IsSystemError() {
return nil, errors.Join(ErrSystem, jiguangErr)
} else {
return nil, errors.Join(ErrDatasource, jiguangErr)
}
}
// 成功时业务体在 data 或 result
payload := jiguangResp.Data
if payload == nil {
payload = jiguangResp.Result
}
if payload == nil {
return []byte("{}"), nil
}
dataBytes, err := json.Marshal(payload)
if err != nil {
err = errors.Join(ErrSystem, fmt.Errorf("业务数据序列化失败: %w", err))
if j.logger != nil {
j.logger.LogErrorWithResponseID(requestID, transactionID, apiCode, err, params, jiguangResp.OrderID)
}
return nil, err
}
return dataBytes, nil
}
// GetConfig 获取配置信息
func (j *JiguangService) GetConfig() JiguangConfig {
return j.config
}

View File

@@ -0,0 +1,25 @@
package muzi
import "fmt"
// MuziError 木子数据业务错误
type MuziError struct {
Code int
Message string
}
// Error implements error interface.
func (e *MuziError) Error() string {
return fmt.Sprintf("木子数据返回错误,代码: %d信息: %s", e.Code, e.Message)
}
// NewMuziError 根据错误码创建业务错误
func NewMuziError(code int, message string) *MuziError {
if message == "" {
message = "木子数据返回未知错误"
}
return &MuziError{
Code: code,
Message: message,
}
}

View File

@@ -0,0 +1,61 @@
package muzi
import (
"time"
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewMuziServiceWithConfig 使用配置创建木子数据服务
func NewMuziServiceWithConfig(cfg *config.Config) (*MuziService, error) {
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.Muzi.Logging.Enabled,
LogDir: cfg.Muzi.Logging.LogDir,
ServiceName: "muzi",
UseDaily: cfg.Muzi.Logging.UseDaily,
EnableLevelSeparation: cfg.Muzi.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
for level, levelCfg := range cfg.Muzi.Logging.LevelConfigs {
loggingConfig.LevelConfigs[level] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: levelCfg.MaxSize,
MaxBackups: levelCfg.MaxBackups,
MaxAge: levelCfg.MaxAge,
Compress: levelCfg.Compress,
}
}
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
service := NewMuziService(
cfg.Muzi.URL,
cfg.Muzi.AppID,
cfg.Muzi.AppSecret,
cfg.Muzi.Timeout,
logger,
)
return service, nil
}
// NewMuziServiceWithLogging 使用自定义日志配置创建木子数据服务
func NewMuziServiceWithLogging(url, appID, appSecret string, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig) (*MuziService, error) {
loggingConfig.ServiceName = "muzi"
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
return NewMuziService(url, appID, appSecret, timeout, logger), nil
}
// NewMuziServiceSimple 创建无日志的木子数据服务
func NewMuziServiceSimple(url, appID, appSecret string, timeout time.Duration) *MuziService {
return NewMuziService(url, appID, appSecret, timeout, nil)
}

View File

@@ -0,0 +1,406 @@
package muzi
import (
"bytes"
"context"
"crypto/aes"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"reflect"
"sort"
"strconv"
"strings"
"time"
"hyapi-server/internal/shared/external_logger"
)
const defaultRequestTimeout = 60 * time.Second
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
)
// Muzi状态码常量
const (
CodeSuccess = 0 // 成功查询
CodeSystemError = 500 // 系统异常
CodeParamMissing = 601 // 参数不全
CodeInterfaceExpired = 602 // 接口已过期
CodeVerifyFailed = 603 // 接口校验失败
CodeIPNotInWhitelist = 604 // IP不在白名单中
CodeProductNotFound = 701 // 产品编号不存在
CodeUserNotFound = 702 // 用户名不存在
CodeUnauthorizedAPI = 703 // 接口未授权
CodeInsufficientFund = 704 // 商户余额不足
)
// MuziResponse 木子数据接口通用响应
type MuziResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data json.RawMessage `json:"data"`
Timestamp int64 `json:"timestamp"`
ExecuteTime int64 `json:"executeTime"`
}
// MuziConfig 木子数据接口配置
type MuziConfig struct {
URL string
AppID string
AppSecret string
Timeout time.Duration
}
// MuziService 木子数据接口服务封装
type MuziService struct {
config MuziConfig
logger *external_logger.ExternalServiceLogger
}
// NewMuziService 创建木子数据服务实例
func NewMuziService(url, appID, appSecret string, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *MuziService {
if timeout <= 0 {
timeout = defaultRequestTimeout
}
return &MuziService{
config: MuziConfig{
URL: url,
AppID: appID,
AppSecret: appSecret,
Timeout: timeout,
},
logger: logger,
}
}
// generateRequestID 生成请求ID
func (m *MuziService) generateRequestID() string {
timestamp := time.Now().UnixNano()
raw := fmt.Sprintf("%d_%s", timestamp, m.config.AppID)
sum := md5.Sum([]byte(raw))
return fmt.Sprintf("muzi_%x", sum[:8])
}
// CallAPI 调用木子数据接口
func (m *MuziService) CallAPI(ctx context.Context, prodCode string, path string, params map[string]interface{},paramSign map[string]interface{}) (json.RawMessage, error) {
requestID := m.generateRequestID()
now := time.Now()
timestamp := strconv.FormatInt(now.UnixMilli(), 10)
flatParams := flattenParams(params)
signParts := collectSignatureValues(paramSign)
signature := m.GenerateSignature(prodCode, timestamp, signParts...)
// 从上下文获取链路ID
var transactionID string
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
transactionID = ctxTransactionID
}
requestBody := map[string]interface{}{
"appId": m.config.AppID,
"prodCode": prodCode,
"timestamp": timestamp,
"signature": signature,
}
for key, value := range flatParams {
requestBody[key] = value
}
if m.logger != nil {
m.logger.LogRequest(requestID, transactionID, prodCode, m.config.URL)
}
bodyBytes, marshalErr := json.Marshal(requestBody)
if marshalErr != nil {
err := errors.Join(ErrSystem, marshalErr)
if m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
}
return nil, err
}
// 构建完整的URL拼接路径参数
fullURL := m.config.URL
if path != "" {
// 确保路径以/开头
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
// 确保URL不以/结尾,避免双斜杠
if strings.HasSuffix(fullURL, "/") {
fullURL = fullURL[:len(fullURL)-1]
}
fullURL += path
}
req, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, fullURL, bytes.NewBuffer(bodyBytes))
if reqErr != nil {
err := errors.Join(ErrSystem, reqErr)
if m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
}
return nil, err
}
req.Header.Set("Content-Type", "application/json")
client := &http.Client{
Timeout: m.config.Timeout,
}
resp, httpErr := client.Do(req)
if httpErr != nil {
err := wrapHTTPError(httpErr)
if errors.Is(err, ErrDatasource) {
err = errors.Join(err, fmt.Errorf("API请求超时: %v", httpErr))
}
if m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
}
return nil, err
}
defer func(body io.ReadCloser) {
closeErr := body.Close()
if closeErr != nil && m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), requestBody)
}
}(resp.Body)
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
err := errors.Join(ErrSystem, readErr)
if m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
}
return nil, err
}
if m.logger != nil {
// 记录响应日志(不记录具体响应数据)
m.logger.LogResponse(requestID, transactionID, prodCode, resp.StatusCode, time.Since(now))
}
if resp.StatusCode != http.StatusOK {
err := errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", resp.StatusCode))
if m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
}
return nil, err
}
var muziResp MuziResponse
if err := json.Unmarshal(respBody, &muziResp); err != nil {
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %v", err))
if m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, err, requestBody)
}
return nil, err
}
if muziResp.Code != CodeSuccess {
muziErr := NewMuziError(muziResp.Code, muziResp.Msg)
var resultErr error
switch muziResp.Code {
case CodeSystemError:
resultErr = errors.Join(ErrDatasource, muziErr)
default:
resultErr = errors.Join(ErrSystem, muziErr)
}
if m.logger != nil {
m.logger.LogError(requestID, transactionID, prodCode, muziErr, requestBody)
}
return nil, resultErr
}
return muziResp.Data, nil
}
func wrapHTTPError(err error) error {
var timeout bool
if err == context.DeadlineExceeded {
timeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
timeout = true
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
timeout = true
}
if timeout {
return errors.Join(ErrDatasource, err)
}
return errors.Join(ErrSystem, err)
}
func pkcs5Padding(src []byte, blockSize int) []byte {
padding := blockSize - len(src)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
func flattenParams(params map[string]interface{}) map[string]interface{} {
result := make(map[string]interface{})
if params == nil {
return result
}
for key, value := range params {
flattenValue(key, value, result)
}
return result
}
func flattenValue(prefix string, value interface{}, out map[string]interface{}) {
switch val := value.(type) {
case map[string]interface{}:
for k, v := range val {
flattenValue(combinePrefix(prefix, k), v, out)
}
case map[interface{}]interface{}:
for k, v := range val {
keyStr := fmt.Sprint(k)
flattenValue(combinePrefix(prefix, keyStr), v, out)
}
case []interface{}:
for i, item := range val {
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
flattenValue(nextPrefix, item, out)
}
case []string:
for i, item := range val {
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
flattenValue(nextPrefix, item, out)
}
case []int:
for i, item := range val {
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
flattenValue(nextPrefix, item, out)
}
case []float64:
for i, item := range val {
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
flattenValue(nextPrefix, item, out)
}
case []bool:
for i, item := range val {
nextPrefix := fmt.Sprintf("%s[%d]", prefix, i)
flattenValue(nextPrefix, item, out)
}
default:
out[prefix] = val
}
}
func combinePrefix(prefix, key string) string {
if prefix == "" {
return key
}
return prefix + "." + key
}
// Encrypt 使用 AES/ECB/PKCS5Padding 对单个字符串进行加密并返回 Base64 结果
func (m *MuziService) Encrypt(value string) (string, error) {
if len(m.config.AppSecret) != 32 {
return "", fmt.Errorf("AppSecret长度必须为32位")
}
block, err := aes.NewCipher([]byte(m.config.AppSecret))
if err != nil {
return "", fmt.Errorf("初始化加密器失败: %w", err)
}
padded := pkcs5Padding([]byte(value), block.BlockSize())
encrypted := make([]byte, len(padded))
for bs, be := 0, block.BlockSize(); bs < len(padded); bs, be = bs+block.BlockSize(), be+block.BlockSize() {
block.Encrypt(encrypted[bs:be], padded[bs:be])
}
return base64.StdEncoding.EncodeToString(encrypted), nil
}
// GenerateSignature 根据协议生成签名extraValues 会按顺序追加在待签名字符串之后
func (m *MuziService) GenerateSignature(prodCode, timestamp string, extraValues ...string) string {
signStr := m.config.AppID + prodCode + timestamp
for _, extra := range extraValues {
signStr += extra
}
hash := md5.Sum([]byte(signStr))
return hex.EncodeToString(hash[:])
}
// GenerateTimestamp 生成当前毫秒级时间戳字符串
func (m *MuziService) GenerateTimestamp() string {
return strconv.FormatInt(time.Now().UnixMilli(), 10)
}
// FlattenParams 将嵌套参数展平为一维键值对
func (m *MuziService) FlattenParams(params map[string]interface{}) map[string]interface{} {
return flattenParams(params)
}
func collectSignatureValues(data interface{}) []string {
var result []string
collectSignatureValuesRecursive(reflect.ValueOf(data), &result)
return result
}
func collectSignatureValuesRecursive(value reflect.Value, result *[]string) {
if !value.IsValid() {
*result = append(*result, "")
return
}
switch value.Kind() {
case reflect.Pointer, reflect.Interface:
if value.IsNil() {
*result = append(*result, "")
return
}
collectSignatureValuesRecursive(value.Elem(), result)
case reflect.Map:
keys := value.MapKeys()
sort.Slice(keys, func(i, j int) bool {
return fmt.Sprint(keys[i].Interface()) < fmt.Sprint(keys[j].Interface())
})
for _, key := range keys {
collectSignatureValuesRecursive(value.MapIndex(key), result)
}
case reflect.Slice, reflect.Array:
for i := 0; i < value.Len(); i++ {
collectSignatureValuesRecursive(value.Index(i), result)
}
case reflect.Struct:
typeInfo := value.Type()
fieldNames := make([]string, 0, value.NumField())
fieldIndices := make(map[string]int, value.NumField())
for i := 0; i < value.NumField(); i++ {
field := typeInfo.Field(i)
if field.PkgPath != "" {
continue
}
fieldNames = append(fieldNames, field.Name)
fieldIndices[field.Name] = i
}
sort.Strings(fieldNames)
for _, name := range fieldNames {
collectSignatureValuesRecursive(value.Field(fieldIndices[name]), result)
}
default:
*result = append(*result, fmt.Sprint(value.Interface()))
}
}

View File

@@ -0,0 +1,573 @@
package notification
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
)
// WeChatWorkService 企业微信通知服务
type WeChatWorkService struct {
webhookURL string
secret string
timeout time.Duration
logger *zap.Logger
}
// WechatWorkConfig 企业微信配置
type WechatWorkConfig struct {
WebhookURL string `yaml:"webhook_url"`
Timeout time.Duration `yaml:"timeout"`
}
// WechatWorkMessage 企业微信消息
type WechatWorkMessage struct {
MsgType string `json:"msgtype"`
Text *WechatWorkText `json:"text,omitempty"`
Markdown *WechatWorkMarkdown `json:"markdown,omitempty"`
}
// WechatWorkText 文本消息
type WechatWorkText struct {
Content string `json:"content"`
MentionedList []string `json:"mentioned_list,omitempty"`
MentionedMobileList []string `json:"mentioned_mobile_list,omitempty"`
}
// WechatWorkMarkdown Markdown消息
type WechatWorkMarkdown struct {
Content string `json:"content"`
}
// NewWeChatWorkService 创建企业微信通知服务
func NewWeChatWorkService(webhookURL, secret string, logger *zap.Logger) *WeChatWorkService {
return &WeChatWorkService{
webhookURL: webhookURL,
secret: secret,
timeout: 60 * time.Second,
logger: logger,
}
}
// SendTextMessage 发送文本消息
func (s *WeChatWorkService) SendTextMessage(ctx context.Context, content string, mentionedList []string, mentionedMobileList []string) error {
s.logger.Info("发送企业微信文本消息",
zap.String("content", content),
zap.Strings("mentioned_list", mentionedList),
)
message := map[string]interface{}{
"msgtype": "text",
"text": map[string]interface{}{
"content": content,
"mentioned_list": mentionedList,
"mentioned_mobile_list": mentionedMobileList,
},
}
return s.sendMessage(ctx, message)
}
// SendMarkdownMessage 发送Markdown消息
func (s *WeChatWorkService) SendMarkdownMessage(ctx context.Context, content string) error {
s.logger.Info("发送企业微信Markdown消息", zap.String("content", content))
message := map[string]interface{}{
"msgtype": "markdown",
"markdown": map[string]interface{}{
"content": content,
},
}
return s.sendMessage(ctx, message)
}
// SendCardMessage 发送卡片消息
func (s *WeChatWorkService) SendCardMessage(ctx context.Context, title, description, url string, btnText string) error {
s.logger.Info("发送企业微信卡片消息",
zap.String("title", title),
zap.String("description", description),
)
message := map[string]interface{}{
"msgtype": "template_card",
"template_card": map[string]interface{}{
"card_type": "text_notice",
"source": map[string]interface{}{
"icon_url": "https://example.com/icon.png",
"desc": "企业认证系统",
},
"main_title": map[string]interface{}{
"title": title,
},
"horizontal_content_list": []map[string]interface{}{
{
"keyname": "描述",
"value": description,
},
},
"jump_list": []map[string]interface{}{
{
"type": "1",
"title": btnText,
"url": url,
},
},
},
}
return s.sendMessage(ctx, message)
}
// SendCertificationNotification 发送认证相关通知
func (s *WeChatWorkService) SendCertificationNotification(ctx context.Context, notificationType string, data map[string]interface{}) error {
s.logger.Info("发送认证通知", zap.String("type", notificationType))
switch notificationType {
case "new_application":
return s.sendNewApplicationNotification(ctx, data)
case "pending_manual_review":
return s.sendPendingManualReviewNotification(ctx, data)
case "ocr_success":
return s.sendOCRSuccessNotification(ctx, data)
case "ocr_failed":
return s.sendOCRFailedNotification(ctx, data)
case "face_verify_success":
return s.sendFaceVerifySuccessNotification(ctx, data)
case "face_verify_failed":
return s.sendFaceVerifyFailedNotification(ctx, data)
case "admin_approved":
return s.sendAdminApprovedNotification(ctx, data)
case "admin_rejected":
return s.sendAdminRejectedNotification(ctx, data)
case "contract_signed":
return s.sendContractSignedNotification(ctx, data)
case "certification_completed":
return s.sendCertificationCompletedNotification(ctx, data)
default:
return fmt.Errorf("不支持的通知类型: %s", notificationType)
}
}
// sendNewApplicationNotification 发送新申请通知
func (s *WeChatWorkService) sendNewApplicationNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicantName := data["applicant_name"].(string)
applicationID := data["application_id"].(string)
content := fmt.Sprintf(`## 【海宇数据】🆕 新的企业认证申请
**企业名称**: %s
**申请人**: %s
**申请ID**: %s
**申请时间**: %s
请管理员及时审核处理。`,
companyName,
applicantName,
applicationID,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendPendingManualReviewNotification 用户已提交企业信息,待管理员人工审核(三真审核前序步骤)
func (s *WeChatWorkService) sendPendingManualReviewNotification(ctx context.Context, data map[string]interface{}) error {
companyName := fmt.Sprint(data["company_name"])
legalPersonName := fmt.Sprint(data["legal_person_name"])
authorizedRepName := fmt.Sprint(data["authorized_rep_name"])
contactPhone := fmt.Sprint(data["contact_phone"])
apiUsage := fmt.Sprint(data["api_usage"])
submitAt := fmt.Sprint(data["submit_at"])
if authorizedRepName == "" || authorizedRepName == "<nil>" {
authorizedRepName = "—"
}
if apiUsage == "" || apiUsage == "<nil>" {
apiUsage = "—"
}
if contactPhone == "" || contactPhone == "<nil>" {
contactPhone = "—"
}
content := fmt.Sprintf(`## 【海宇数据】📋 企业信息待人工审核
**企业名称**: %s
**法人**: %s
**授权申请人**: %s
**联系电话**: %s
**应用场景说明**: %s
**提交时间**: %s
> 请管理员登录后台 **企业审核** 通过审核后,用户方可进行 e签宝企业认证。`,
companyName,
legalPersonName,
authorizedRepName,
contactPhone,
apiUsage,
submitAt)
return s.SendMarkdownMessage(ctx, content)
}
// sendOCRSuccessNotification 发送OCR识别成功通知
func (s *WeChatWorkService) sendOCRSuccessNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
confidence := data["confidence"].(float64)
applicationID := data["application_id"].(string)
content := fmt.Sprintf(`## 【海宇数据】✅ OCR识别成功
**企业名称**: %s
**识别置信度**: %.2f%%
**申请ID**: %s
**识别时间**: %s
营业执照信息已自动提取,请用户确认信息。`,
companyName,
confidence*100,
applicationID,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendOCRFailedNotification 发送OCR识别失败通知
func (s *WeChatWorkService) sendOCRFailedNotification(ctx context.Context, data map[string]interface{}) error {
applicationID := data["application_id"].(string)
errorMsg := data["error_message"].(string)
content := fmt.Sprintf(`## 【海宇数据】❌ OCR识别失败
**申请ID**: %s
**错误信息**: %s
**失败时间**: %s
请检查营业执照图片质量或联系技术支持。`,
applicationID,
errorMsg,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendFaceVerifySuccessNotification 发送人脸识别成功通知
func (s *WeChatWorkService) sendFaceVerifySuccessNotification(ctx context.Context, data map[string]interface{}) error {
applicantName := data["applicant_name"].(string)
applicationID := data["application_id"].(string)
confidence := data["confidence"].(float64)
content := fmt.Sprintf(`## 【海宇数据】✅ 人脸识别成功
**申请人**: %s
**申请ID**: %s
**识别置信度**: %.2f%%
**识别时间**: %s
身份验证通过,可以进行下一步操作。`,
applicantName,
applicationID,
confidence*100,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendFaceVerifyFailedNotification 发送人脸识别失败通知
func (s *WeChatWorkService) sendFaceVerifyFailedNotification(ctx context.Context, data map[string]interface{}) error {
applicantName := data["applicant_name"].(string)
applicationID := data["application_id"].(string)
errorMsg := data["error_message"].(string)
content := fmt.Sprintf(`## 【海宇数据】❌ 人脸识别失败
**申请人**: %s
**申请ID**: %s
**错误信息**: %s
**失败时间**: %s
请重新进行人脸识别或联系技术支持。`,
applicantName,
applicationID,
errorMsg,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendAdminApprovedNotification 发送管理员审核通过通知
func (s *WeChatWorkService) sendAdminApprovedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
adminName := data["admin_name"].(string)
comment := data["comment"].(string)
content := fmt.Sprintf(`## 【海宇数据】✅ 管理员审核通过
**企业名称**: %s
**申请ID**: %s
**审核人**: %s
**审核意见**: %s
**审核时间**: %s
认证申请已通过审核,请用户签署电子合同。`,
companyName,
applicationID,
adminName,
comment,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendAdminRejectedNotification 发送管理员审核拒绝通知
func (s *WeChatWorkService) sendAdminRejectedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
adminName := data["admin_name"].(string)
reason := data["reason"].(string)
content := fmt.Sprintf(`## 【海宇数据】❌ 管理员审核拒绝
**企业名称**: %s
**申请ID**: %s
**审核人**: %s
**拒绝原因**: %s
**审核时间**: %s
认证申请被拒绝,请根据反馈意见重新提交。`,
companyName,
applicationID,
adminName,
reason,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendContractSignedNotification 发送合同签署通知
func (s *WeChatWorkService) sendContractSignedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
signerName := data["signer_name"].(string)
content := fmt.Sprintf(`## 【海宇数据】📝 电子合同已签署
**企业名称**: %s
**申请ID**: %s
**签署人**: %s
**签署时间**: %s
电子合同签署完成系统将自动生成钱包和Access Key。`,
companyName,
applicationID,
signerName,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendCertificationCompletedNotification 发送认证完成通知
func (s *WeChatWorkService) sendCertificationCompletedNotification(ctx context.Context, data map[string]interface{}) error {
companyName := data["company_name"].(string)
applicationID := data["application_id"].(string)
walletAddress := data["wallet_address"].(string)
content := fmt.Sprintf(`## 【海宇数据】🎉 企业认证完成
**企业名称**: %s
**申请ID**: %s
**钱包地址**: %s
**完成时间**: %s
恭喜企业认证流程已完成钱包和Access Key已生成。`,
companyName,
applicationID,
walletAddress,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// sendMessage 发送消息到企业微信
func (s *WeChatWorkService) sendMessage(ctx context.Context, message map[string]interface{}) error {
// 生成签名URL
signedURL := s.generateSignedURL()
// 序列化消息
messageBytes, err := json.Marshal(message)
if err != nil {
return fmt.Errorf("序列化消息失败: %w", err)
}
// 创建HTTP客户端
client := &http.Client{
Timeout: s.timeout,
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "POST", signedURL, bytes.NewBuffer(messageBytes))
if err != nil {
return fmt.Errorf("创建请求失败: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "hyapi-server/1.0")
// 发送请求
resp, err := client.Do(req)
if err != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
errorMsg := "发送请求失败"
if isTimeout {
errorMsg = "发送请求超时"
}
return fmt.Errorf("%s: %w", errorMsg, err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
}
// 解析响应
var response map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
return fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误码
if errCode, ok := response["errcode"].(float64); ok && errCode != 0 {
errmsg := response["errmsg"].(string)
return fmt.Errorf("企业微信API错误: %d - %s", int(errCode), errmsg)
}
s.logger.Info("企业微信消息发送成功", zap.Any("response", response))
return nil
}
// generateSignedURL 生成带签名的URL
func (s *WeChatWorkService) generateSignedURL() string {
if s.secret == "" {
return s.webhookURL
}
// 生成时间戳
timestamp := time.Now().Unix()
// 生成随机字符串(这里简化处理,实际应该使用随机字符串)
nonce := fmt.Sprintf("%d", timestamp)
// 构建签名字符串
signStr := fmt.Sprintf("%d\n%s", timestamp, s.secret)
// 计算签名
h := hmac.New(sha256.New, []byte(s.secret))
h.Write([]byte(signStr))
signature := base64.StdEncoding.EncodeToString(h.Sum(nil))
// 构建签名URL
return fmt.Sprintf("%s&timestamp=%d&nonce=%s&sign=%s",
s.webhookURL, timestamp, nonce, signature)
}
// SendSystemAlert 发送系统告警
func (s *WeChatWorkService) SendSystemAlert(ctx context.Context, level, title, message string) error {
s.logger.Info("发送系统告警",
zap.String("level", level),
zap.String("title", title),
)
// 根据告警级别选择图标
var icon string
switch level {
case "info":
icon = ""
case "warning":
icon = "⚠️"
case "error":
icon = "🚨"
case "critical":
icon = "💥"
default:
icon = "📢"
}
content := fmt.Sprintf(`## 【海宇数据】%s 系统告警
**级别**: %s
**标题**: %s
**消息**: %s
**时间**: %s
请相关人员及时处理。`,
icon,
level,
title,
message,
time.Now().Format("2006-01-02 15:04:05"))
return s.SendMarkdownMessage(ctx, content)
}
// SendDailyReport 发送每日报告
func (s *WeChatWorkService) SendDailyReport(ctx context.Context, reportData map[string]interface{}) error {
s.logger.Info("发送每日报告")
content := fmt.Sprintf(`## 【海宇数据】📊 企业认证系统每日报告
**报告日期**: %s
### 统计数据
- **新增申请**: %d
- **OCR识别成功**: %d
- **OCR识别失败**: %d
- **人脸识别成功**: %d
- **人脸识别失败**: %d
- **审核通过**: %d
- **审核拒绝**: %d
- **认证完成**: %d
### 系统状态
- **系统运行时间**: %s
- **API调用次数**: %d
- **错误次数**: %d
祝您工作愉快!`,
time.Now().Format("2006-01-02"),
reportData["new_applications"],
reportData["ocr_success"],
reportData["ocr_failed"],
reportData["face_verify_success"],
reportData["face_verify_failed"],
reportData["admin_approved"],
reportData["admin_rejected"],
reportData["certification_completed"],
reportData["uptime"],
reportData["api_calls"],
reportData["errors"])
return s.SendMarkdownMessage(ctx, content)
}

View File

@@ -0,0 +1,147 @@
package notification_test
import (
"context"
"fmt"
"os"
"testing"
"time"
"go.uber.org/zap"
"hyapi-server/internal/infrastructure/external/notification"
)
// newTestWeChatWorkService 创建用于测试的企业微信服务实例
// 默认使用环境变量 WECOM_WEBHOOK若未设置则使用项目配置中的 webhook。
func newTestWeChatWorkService(t *testing.T) *notification.WeChatWorkService {
t.Helper()
webhook := os.Getenv("WECOM_WEBHOOK")
if webhook == "" {
// 使用你提供的 webhook 地址
webhook = "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=649bf737-28ca-4f30-ad5f-cfb65b2af113"
}
logger, _ := zap.NewDevelopment()
return notification.NewWeChatWorkService(webhook, "", logger)
}
// TestWeChatWork_SendAllBusinessNotifications
// 手动运行该用例,将依次向企业微信群推送 5 种业务场景的通知:
// 1. 用户充值成功
// 2. 用户申请开发票
// 3. 用户企业认证成功
// 4. 用户余额低于阈值
// 5. 用户余额欠费
//
// 注意:
// - 通知中只使用企业名称和手机号码不展示用户ID
// - 默认使用示例企业名称和手机号,实际使用时请根据需要修改
func TestWeChatWork_SendAllBusinessNotifications(t *testing.T) {
svc := newTestWeChatWorkService(t)
ctx := context.Background()
// 示例企业信息(实际可按需修改)
enterpriseName := "测试企业有限公司"
phone := "13800000000"
now := time.Now().Format("2006-01-02 15:04:05")
tests := []struct {
name string
content string
}{
{
name: "recharge_success",
content: fmt.Sprintf(
"### 【海宇数据】用户充值成功通知\n"+
"> 企业名称:%s\n"+
"> 联系手机:%s\n"+
"> 充值金额:%s 元\n"+
"> 入账总额:%s 元(含赠送)\n"+
"> 时间:%s\n",
enterpriseName,
phone,
"1000.00",
"1050.00",
now,
),
},
{
name: "invoice_applied",
content: fmt.Sprintf(
"### 【海宇数据】用户申请开发票\n"+
"> 企业名称:%s\n"+
"> 联系手机:%s\n"+
"> 申请开票金额:%s 元\n"+
"> 发票类型:%s\n"+
"> 申请时间:%s\n"+
"\n请财务尽快审核并开具发票。",
enterpriseName,
phone,
"500.00",
"增值税专用发票",
now,
),
},
{
name: "certification_completed",
content: fmt.Sprintf(
"### 【海宇数据】企业认证成功\n"+
"> 企业名称:%s\n"+
"> 联系手机:%s\n"+
"> 完成时间:%s\n"+
"\n该企业已完成认证请相关同事同步更新内部系统并关注后续接入情况。",
enterpriseName,
phone,
now,
),
},
{
name: "low_balance_alert",
content: fmt.Sprintf(
"### 【海宇数据】用户余额预警\n"+
"<font color=\"warning\">用户余额已低于预警阈值,请及时跟进。</font>\n"+
"> 企业名称:%s\n"+
"> 联系手机:%s\n"+
"> 当前余额:%s 元\n"+
"> 预警阈值:%s 元\n"+
"> 时间:%s\n",
enterpriseName,
phone,
"180.00",
"200.00",
now,
),
},
{
name: "arrears_alert",
content: fmt.Sprintf(
"### 【海宇数据】用户余额欠费告警\n"+
"<font color=\"warning\">该企业已发生欠费,请及时联系并处理。</font>\n"+
"> 企业名称:%s\n"+
"> 联系手机:%s\n"+
"> 当前余额:%s 元\n"+
"> 欠费金额:%s 元\n"+
"> 时间:%s\n",
enterpriseName,
phone,
"-50.00",
"50.00",
now,
),
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
if err := svc.SendMarkdownMessage(ctx, tc.content); err != nil {
t.Fatalf("发送场景[%s]通知失败: %v", tc.name, err)
}
// 简单间隔,避免瞬时发送过多消息
time.Sleep(500 * time.Millisecond)
})
}
}

View File

@@ -0,0 +1,531 @@
package ocr
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"go.uber.org/zap"
"hyapi-server/internal/application/certification/dto/responses"
)
// BaiduOCRService 百度OCR服务
type BaiduOCRService struct {
apiKey string
secretKey string
endpoint string
timeout time.Duration
logger *zap.Logger
}
// NewBaiduOCRService 创建百度OCR服务
func NewBaiduOCRService(apiKey, secretKey string, logger *zap.Logger) *BaiduOCRService {
return &BaiduOCRService{
apiKey: apiKey,
secretKey: secretKey,
endpoint: "https://aip.baidubce.com",
timeout: 60 * time.Second,
logger: logger,
}
}
// RecognizeBusinessLicense 识别营业执照
func (s *BaiduOCRService) RecognizeBusinessLicense(ctx context.Context, imageBytes []byte) (*responses.BusinessLicenseResult, error) {
s.logger.Info("开始识别营业执照", zap.Int("image_size", len(imageBytes)))
// 获取访问令牌
accessToken, err := s.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
}
// 将图片转换为base64并进行URL编码
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
// 构建请求URL只包含access_token
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/business_license?access_token=%s", s.endpoint, accessToken)
// 构建POST请求体
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
if err != nil {
return nil, fmt.Errorf("营业执照识别请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
errorMsg := result["error_msg"].(string)
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
}
// 解析识别结果
licenseResult := s.parseBusinessLicenseResult(result)
s.logger.Info("营业执照识别成功",
zap.String("company_name", licenseResult.CompanyName),
zap.String("legal_representative", licenseResult.LegalPersonName),
zap.String("registered_capital", licenseResult.RegisteredCapital),
)
return licenseResult, nil
}
// RecognizeIDCard 识别身份证
func (s *BaiduOCRService) RecognizeIDCard(ctx context.Context, imageBytes []byte, side string) (*responses.IDCardResult, error) {
s.logger.Info("开始识别身份证", zap.String("side", side), zap.Int("image_size", len(imageBytes)))
// 获取访问令牌
accessToken, err := s.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
}
// 将图片转换为base64并进行URL编码
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
// 构建请求URL只包含access_token
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/idcard?access_token=%s", s.endpoint, accessToken)
// 构建POST请求体
payload := strings.NewReader(fmt.Sprintf("image=%s&side=%s", imageBase64UrlEncoded, side))
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
if err != nil {
return nil, fmt.Errorf("身份证识别请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
errorMsg := result["error_msg"].(string)
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
}
// 解析识别结果
idCardResult := s.parseIDCardResult(result, side)
s.logger.Info("身份证识别成功",
zap.String("name", idCardResult.Name),
zap.String("id_number", idCardResult.IDCardNumber),
zap.String("side", side),
)
return idCardResult, nil
}
// RecognizeGeneralText 通用文字识别
func (s *BaiduOCRService) RecognizeGeneralText(ctx context.Context, imageBytes []byte) (*responses.GeneralTextResult, error) {
s.logger.Info("开始通用文字识别", zap.Int("image_size", len(imageBytes)))
// 获取访问令牌
accessToken, err := s.getAccessToken(ctx)
if err != nil {
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
}
// 将图片转换为base64并进行URL编码
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
// 构建请求URL只包含access_token
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/general_basic?access_token=%s", s.endpoint, accessToken)
// 构建POST请求体
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
if err != nil {
return nil, fmt.Errorf("通用文字识别请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return nil, fmt.Errorf("解析响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
errorMsg := result["error_msg"].(string)
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
}
// 解析识别结果
textResult := s.parseGeneralTextResult(result)
s.logger.Info("通用文字识别成功",
zap.Int("word_count", len(textResult.Words)),
zap.Float64("confidence", textResult.Confidence),
)
return textResult, nil
}
// RecognizeFromURL 从URL识别图片
func (s *BaiduOCRService) RecognizeFromURL(ctx context.Context, imageURL string, ocrType string) (interface{}, error) {
s.logger.Info("从URL识别图片", zap.String("url", imageURL), zap.String("type", ocrType))
// 下载图片
imageBytes, err := s.downloadImage(ctx, imageURL)
if err != nil {
s.logger.Error("下载图片失败", zap.Error(err))
return nil, fmt.Errorf("下载图片失败: %w", err)
}
// 根据类型调用相应的识别方法
switch ocrType {
case "business_license":
return s.RecognizeBusinessLicense(ctx, imageBytes)
case "idcard_front":
return s.RecognizeIDCard(ctx, imageBytes, "front")
case "idcard_back":
return s.RecognizeIDCard(ctx, imageBytes, "back")
case "general_text":
return s.RecognizeGeneralText(ctx, imageBytes)
default:
return nil, fmt.Errorf("不支持的OCR类型: %s", ocrType)
}
}
// getAccessToken 获取百度API访问令牌
func (s *BaiduOCRService) getAccessToken(ctx context.Context) (string, error) {
// 构建获取访问令牌的URL
tokenURL := fmt.Sprintf("%s/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
s.endpoint, s.apiKey, s.secretKey)
// 发送请求
resp, err := s.sendRequest(ctx, "POST", tokenURL, nil)
if err != nil {
return "", fmt.Errorf("获取访问令牌请求失败: %w", err)
}
// 解析响应
var result map[string]interface{}
if err := json.Unmarshal(resp, &result); err != nil {
return "", fmt.Errorf("解析访问令牌响应失败: %w", err)
}
// 检查错误
if errCode, ok := result["error"].(string); ok && errCode != "" {
errorDesc := result["error_description"].(string)
return "", fmt.Errorf("获取访问令牌失败: %s - %s", errCode, errorDesc)
}
// 提取访问令牌
accessToken, ok := result["access_token"].(string)
if !ok {
return "", fmt.Errorf("响应中未找到访问令牌")
}
return accessToken, nil
}
// sendRequest 发送HTTP请求
func (s *BaiduOCRService) sendRequest(ctx context.Context, method, url string, body io.Reader) ([]byte, error) {
// 创建HTTP客户端
client := &http.Client{
Timeout: s.timeout,
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 设置请求头
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("User-Agent", "hyapi-server/1.0")
// 发送请求
resp, err := client.Do(req)
if err != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := err.Error();
errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
return nil, fmt.Errorf("API请求超时: %w", err)
}
return nil, fmt.Errorf("发送请求失败: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
}
// 读取响应内容
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取响应内容失败: %w", err)
}
return responseBody, nil
}
// parseBusinessLicenseResult 解析营业执照识别结果
func (s *BaiduOCRService) parseBusinessLicenseResult(result map[string]interface{}) *responses.BusinessLicenseResult {
wordsResult := result["words_result"].(map[string]interface{})
// 提取企业信息
companyName := ""
if companyNameObj, ok := wordsResult["单位名称"].(map[string]interface{}); ok {
companyName = companyNameObj["words"].(string)
}
unifiedSocialCode := ""
if socialCreditCodeObj, ok := wordsResult["社会信用代码"].(map[string]interface{}); ok {
unifiedSocialCode = socialCreditCodeObj["words"].(string)
}
legalPersonName := ""
if legalPersonObj, ok := wordsResult["法人"].(map[string]interface{}); ok {
legalPersonName = legalPersonObj["words"].(string)
}
// 提取注册资本等其他信息
registeredCapital := ""
if registeredCapitalObj, ok := wordsResult["注册资本"].(map[string]interface{}); ok {
registeredCapital = registeredCapitalObj["words"].(string)
}
// 提取企业地址
address := ""
if addressObj, ok := wordsResult["地址"].(map[string]interface{}); ok {
address = addressObj["words"].(string)
}
// 计算置信度这里简化处理实际应该从OCR结果中获取
confidence := 0.9 // 默认置信度
return &responses.BusinessLicenseResult{
CompanyName: companyName,
UnifiedSocialCode: unifiedSocialCode,
LegalPersonName: legalPersonName,
LegalPersonID: "", // 营业执照上没有法人身份证号
RegisteredCapital: registeredCapital,
Address: address,
Confidence: confidence,
ProcessedAt: time.Now(),
}
}
// parseIDCardResult 解析身份证识别结果
func (s *BaiduOCRService) parseIDCardResult(result map[string]interface{}, side string) *responses.IDCardResult {
wordsResult := result["words_result"].(map[string]interface{})
idCardResult := &responses.IDCardResult{
Side: side,
Confidence: s.extractConfidence(result),
}
if side == "front" {
if name, ok := wordsResult["姓名"]; ok {
if word, ok := name.(map[string]interface{}); ok {
idCardResult.Name = word["words"].(string)
}
}
if gender, ok := wordsResult["性别"]; ok {
if word, ok := gender.(map[string]interface{}); ok {
idCardResult.Gender = word["words"].(string)
}
}
if nation, ok := wordsResult["民族"]; ok {
if word, ok := nation.(map[string]interface{}); ok {
idCardResult.Nation = word["words"].(string)
}
}
if birthday, ok := wordsResult["出生"]; ok {
if word, ok := birthday.(map[string]interface{}); ok {
idCardResult.Birthday = word["words"].(string)
}
}
if address, ok := wordsResult["住址"]; ok {
if word, ok := address.(map[string]interface{}); ok {
idCardResult.Address = word["words"].(string)
}
}
if idNumber, ok := wordsResult["公民身份号码"]; ok {
if word, ok := idNumber.(map[string]interface{}); ok {
idCardResult.IDCardNumber = word["words"].(string)
}
}
} else {
if issuingAgency, ok := wordsResult["签发机关"]; ok {
if word, ok := issuingAgency.(map[string]interface{}); ok {
idCardResult.IssuingAgency = word["words"].(string)
}
}
if validPeriod, ok := wordsResult["有效期限"]; ok {
if word, ok := validPeriod.(map[string]interface{}); ok {
idCardResult.ValidPeriod = word["words"].(string)
}
}
}
return idCardResult
}
// parseGeneralTextResult 解析通用文字识别结果
func (s *BaiduOCRService) parseGeneralTextResult(result map[string]interface{}) *responses.GeneralTextResult {
wordsResult := result["words_result"].([]interface{})
textResult := &responses.GeneralTextResult{
Confidence: s.extractConfidence(result),
Words: make([]responses.TextLine, 0, len(wordsResult)),
}
for _, word := range wordsResult {
if wordMap, ok := word.(map[string]interface{}); ok {
line := responses.TextLine{
Text: wordMap["words"].(string),
Confidence: 1.0, // 百度返回的通用文字识别没有单独置信度
}
textResult.Words = append(textResult.Words, line)
}
}
return textResult
}
// extractConfidence 提取置信度
func (s *BaiduOCRService) extractConfidence(result map[string]interface{}) float64 {
if confidence, ok := result["confidence"].(float64); ok {
return confidence
}
return 0.0
}
// extractWords 提取识别的文字
func (s *BaiduOCRService) extractWords(result map[string]interface{}) []string {
words := make([]string, 0)
if wordsResult, ok := result["words_result"]; ok {
switch v := wordsResult.(type) {
case map[string]interface{}:
// 营业执照等结构化文档
for _, word := range v {
if wordMap, ok := word.(map[string]interface{}); ok {
if wordsStr, ok := wordMap["words"].(string); ok {
words = append(words, wordsStr)
}
}
}
case []interface{}:
// 通用文字识别
for _, word := range v {
if wordMap, ok := word.(map[string]interface{}); ok {
if wordsStr, ok := wordMap["words"].(string); ok {
words = append(words, wordsStr)
}
}
}
}
}
return words
}
// downloadImage 下载图片
func (s *BaiduOCRService) downloadImage(ctx context.Context, imageURL string) ([]byte, error) {
// 创建HTTP客户端
client := &http.Client{
Timeout: 60 * time.Second,
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 发送请求
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("下载图片失败: %w", err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("下载图片失败,状态码: %d", resp.StatusCode)
}
// 读取响应内容
imageBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取图片内容失败: %w", err)
}
return imageBytes, nil
}
// ValidateBusinessLicense 验证营业执照识别结果
func (s *BaiduOCRService) ValidateBusinessLicense(result *responses.BusinessLicenseResult) error {
if result.Confidence < 0.8 {
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
}
if result.CompanyName == "" {
return fmt.Errorf("未能识别公司名称")
}
if result.LegalPersonName == "" {
return fmt.Errorf("未能识别法定代表人")
}
if result.UnifiedSocialCode == "" {
return fmt.Errorf("未能识别统一社会信用代码")
}
return nil
}
// ValidateIDCard 验证身份证识别结果
func (s *BaiduOCRService) ValidateIDCard(result *responses.IDCardResult) error {
if result.Confidence < 0.8 {
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
}
if result.Side == "front" {
if result.Name == "" {
return fmt.Errorf("未能识别姓名")
}
if result.IDCardNumber == "" {
return fmt.Errorf("未能识别身份证号码")
}
} else {
if result.IssuingAgency == "" {
return fmt.Errorf("未能识别签发机关")
}
if result.ValidPeriod == "" {
return fmt.Errorf("未能识别有效期限")
}
}
return nil
}

View File

@@ -0,0 +1,164 @@
package pdfgen
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
"hyapi-server/internal/config"
"go.uber.org/zap"
)
// PDFGenService PDF生成服务客户端
type PDFGenService struct {
baseURL string
apiPath string
logger *zap.Logger
client *http.Client
}
// NewPDFGenService 创建PDF生成服务客户端
func NewPDFGenService(cfg *config.Config, logger *zap.Logger) *PDFGenService {
// 根据环境选择服务地址
var baseURL string
if cfg.App.IsProduction() {
baseURL = cfg.PDFGen.ProductionURL
} else {
baseURL = cfg.PDFGen.DevelopmentURL
}
// 如果配置为空,使用默认值
if baseURL == "" {
if cfg.App.IsProduction() {
baseURL = "http://localhost:15990"
} else {
baseURL = "http://101.43.41.217:15990"
}
}
// 获取API路径如果为空使用默认值
apiPath := cfg.PDFGen.APIPath
if apiPath == "" {
apiPath = "/api/v1/generate/guangzhou"
}
// 获取超时时间如果为0使用默认值
timeout := cfg.PDFGen.Timeout
if timeout == 0 {
timeout = 120 * time.Second
}
logger.Info("PDF生成服务已初始化",
zap.String("base_url", baseURL),
zap.String("api_path", apiPath),
zap.Duration("timeout", timeout),
)
return &PDFGenService{
baseURL: baseURL,
apiPath: apiPath,
logger: logger,
client: &http.Client{
Timeout: timeout,
Transport: &http.Transport{
Proxy: nil, // 不使用任何代理
},
},
}
}
// GeneratePDFRequest PDF生成请求
type GeneratePDFRequest struct {
Data []map[string]interface{} `json:"data"`
ReportNumber string `json:"report_number,omitempty"`
GenerateTime string `json:"generate_time,omitempty"`
}
// GeneratePDFResponse PDF生成响应
type GeneratePDFResponse struct {
PDFBytes []byte
FileName string
}
// GenerateGuangzhouPDF 生成广州大数据租赁风险PDF报告
func (s *PDFGenService) GenerateGuangzhouPDF(ctx context.Context, req *GeneratePDFRequest) (*GeneratePDFResponse, error) {
// 构建请求体
reqBody, err := json.Marshal(req)
if err != nil {
return nil, fmt.Errorf("序列化请求失败: %w", err)
}
// 构建请求URL
url := fmt.Sprintf("%s%s", s.baseURL, s.apiPath)
// 创建HTTP请求
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(reqBody))
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 设置请求头
httpReq.Header.Set("Content-Type", "application/json")
start := time.Now()
// 发送请求
s.logger.Info("开始调用PDF生成服务",
zap.String("url", url),
zap.Int("data_count", len(req.Data)),
zap.ByteString("reqBody", reqBody),
)
resp, err := s.client.Do(httpReq)
if err != nil {
s.logger.Error("调用PDF生成服务失败",
zap.String("url", url),
zap.Duration("duration", time.Since(start)),
zap.Error(err),
)
return nil, fmt.Errorf("调用PDF生成服务失败: %w", err)
}
defer resp.Body.Close()
// 检查HTTP状态码
if resp.StatusCode != http.StatusOK {
// 尝试读取错误信息
errorBody, _ := io.ReadAll(resp.Body)
s.logger.Error("PDF生成服务返回错误",
zap.String("url", url),
zap.Int("status_code", resp.StatusCode),
zap.Duration("duration", time.Since(start)),
zap.String("error_body", string(errorBody)),
)
return nil, fmt.Errorf("PDF生成失败状态码: %d, 错误: %s", resp.StatusCode, string(errorBody))
}
// 读取PDF文件
pdfBytes, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取PDF文件失败: %w", err)
}
// 生成文件名
fileName := "大数据租赁风险报告.pdf"
if req.ReportNumber != "" {
fileName = fmt.Sprintf("%s.pdf", req.ReportNumber)
}
s.logger.Info("PDF生成成功",
zap.String("url", url),
zap.String("file_name", fileName),
zap.Int("file_size", len(pdfBytes)),
zap.Duration("duration", time.Since(start)),
)
return &GeneratePDFResponse{
PDFBytes: pdfBytes,
FileName: fileName,
}, nil
}

View File

@@ -0,0 +1,47 @@
package shujubao
import (
"crypto/hmac"
"crypto/md5"
"encoding/hex"
"strings"
)
// SignMethod 签名方法
type SignMethod string
const (
SignMethodMD5 SignMethod = "md5"
SignMethodHMACMD5 SignMethod = "hmac"
)
// GenerateSignMD5 使用 MD5 生成签名md5(app_secret + timestamp)32 位小写
func GenerateSignMD5(appSecret, timestamp string) string {
h := md5.Sum([]byte(appSecret + timestamp))
sign := strings.ToLower(hex.EncodeToString(h[:]))
return sign
}
// GenerateSignHMAC 使用 HMAC-MD5 生成签名(仅 timestamp兼容旧逻辑
func GenerateSignHMAC(appSecret, timestamp string) string {
mac := hmac.New(md5.New, []byte(appSecret))
mac.Write([]byte(timestamp))
sign := strings.ToLower(hex.EncodeToString(mac.Sum(nil)))
return sign
}
// GenerateSignFromParamsMD5 根据入参生成签名:入参按 ASCII 排序组合后与 app_secret 做 MD5。
// sortedParamStr 格式为 key1=value1&key2=value2&...key 按字母序)。
func GenerateSignFromParamsMD5(appSecret, sortedParamStr string) string {
h := md5.Sum([]byte(appSecret + sortedParamStr))
sign := strings.ToLower(hex.EncodeToString(h[:]))
return sign
}
// GenerateSignFromParamsHMAC 根据入参生成签名:入参按 ASCII 排序组合后与 app_secret 做 HMAC-MD5。
func GenerateSignFromParamsHMAC(appSecret, sortedParamStr string) string {
mac := hmac.New(md5.New, []byte(appSecret))
mac.Write([]byte(sortedParamStr))
sign := strings.ToLower(hex.EncodeToString(mac.Sum(nil)))
return sign
}

View File

@@ -0,0 +1,135 @@
package shujubao
import (
"fmt"
)
// GetQueryEmptyErrByCode 将数据宝错误码归类为“查询为空/不扣费”错误。
// 说明:上游通常依赖 errors.Is(err, ErrQueryEmpty) 来决定是否扣费。
func GetQueryEmptyErrByCode(code string) error {
switch code {
case "10001", "10006":
return ErrQueryEmpty
default:
return nil
}
}
// ShujubaoError 数据宝服务错误
type ShujubaoError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Error 实现 error 接口
func (e *ShujubaoError) Error() string {
return fmt.Sprintf("数据宝错误 [%s]: %s", e.Code, e.Message)
}
// IsSuccess 检查是否成功
func (e *ShujubaoError) IsSuccess() bool {
return e.Code == "200" || e.Code == "0" || e.Code == "10000"
}
// NewShujubaoError 创建新的数据宝错误
func NewShujubaoError(code, message string) *ShujubaoError {
return &ShujubaoError{
Code: code,
Message: message,
}
}
// 数据宝全系统错误码与描述映射Code -> Desc
var systemErrorCodeDesc = map[string]string{
"10000": "成功",
"10001": "查空",
"10002": "查询失败",
"10003": "系统处理异常",
"10004": "系统处理超时",
"10005": "服务异常",
"10006": "查无",
"10017": "查询失败",
"10018": "参数错误",
"10019": "系统异常",
"10020": "同一参数请求次数超限",
"99999": "其他错误",
"999": "接口处理异常",
"000": "key参数不能为空",
"001": "找不到这个key",
"002": "调用次数已用完",
"003": "用户该接口状态不可用",
"004": "接口信息不存在",
"005": "你没有认证信息",
"008": "当前接口只允许“企业认证”通过的账户进行调用,请在数据宝官网个人中心进行企业认证后再进行调用,谢谢!",
"009": "触发风控",
"011": "接口缺少参数",
"012": "没有ip访问权限",
"013": "接口模板不存在",
"015": "该接口已下架",
"020": "调用第三方产生异常",
"022": "调用第三方返回的数据格式错误",
"025": "你没有购买此接口",
"026": "用户信息不存在",
"027": "请求第三方地址超时,请稍后再试",
"028": "请求第三方地址被拒绝,请稍后再试",
"034": "签名不合法",
"035": "请求参数加密有误",
"036": "验签失败",
"037": "timestamp不能为空",
"038": "请求繁忙,请稍后联系管理员再试",
"039": "请在个人中心接口设置加密状态",
"040": "timestamp不合法",
"041": "timestamp已过期",
"042": "身份证手机号姓名银行卡等不符合规则",
"043": "该号段不支持验证",
"047": "请在个人中心获取密钥",
"048": "找不到这个secretKey",
"049": "用户还未申购该产品",
"050": "请联系客服开启验签",
"051": "超过当日调用次数",
"052": "机房限制调用,请联系客服切换其他机房",
"053": "系统错误",
"054": "token无效",
"055": "配置信息未完善,请联系数据宝工作人员",
"056": "apiName参数不能为空",
"057": "并发量超过限制,请联系客服",
"058": "撞库风控预警,请联系客服",
}
// GetSystemErrorDesc 根据错误码获取系统错误描述(支持带 SYSTEM_ 前缀或纯数字)
func GetSystemErrorDesc(code string) string {
// 去掉 SYSTEM_ 前缀
key := code
if len(code) > 7 && code[:7] == "SYSTEM_" {
key = code[7:]
}
if desc, ok := systemErrorCodeDesc[key]; ok {
return desc
}
return ""
}
// NewShujubaoErrorFromCode 根据状态码创建错误
func NewShujubaoErrorFromCode(code, message string) *ShujubaoError {
if message != "" {
return NewShujubaoError(code, message)
}
if desc := GetSystemErrorDesc(code); desc != "" {
return NewShujubaoError(code, desc)
}
return NewShujubaoError(code, "未知错误")
}
// IsShujubaoError 检查是否是数据宝错误
func IsShujubaoError(err error) bool {
_, ok := err.(*ShujubaoError)
return ok
}
// GetShujubaoError 获取数据宝错误
func GetShujubaoError(err error) *ShujubaoError {
if shujubaoErr, ok := err.(*ShujubaoError); ok {
return shujubaoErr
}
return nil
}

View File

@@ -0,0 +1,66 @@
package shujubao
import (
"time"
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewShujubaoServiceWithConfig 使用配置创建数据宝服务
func NewShujubaoServiceWithConfig(cfg *config.Config) (*ShujubaoService, error) {
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.Shujubao.Logging.Enabled,
LogDir: cfg.Shujubao.Logging.LogDir,
ServiceName: "shujubao",
UseDaily: cfg.Shujubao.Logging.UseDaily,
EnableLevelSeparation: cfg.Shujubao.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
for k, v := range cfg.Shujubao.Logging.LevelConfigs {
loggingConfig.LevelConfigs[k] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: v.MaxSize,
MaxBackups: v.MaxBackups,
MaxAge: v.MaxAge,
Compress: v.Compress,
}
}
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
var signMethod SignMethod
if cfg.Shujubao.SignMethod == "md5" {
signMethod = SignMethodMD5
} else {
signMethod = SignMethodHMACMD5
}
timeout := 60 * time.Second
if cfg.Shujubao.Timeout > 0 {
timeout = cfg.Shujubao.Timeout
}
return NewShujubaoService(
cfg.Shujubao.URL,
cfg.Shujubao.AppSecret,
signMethod,
timeout,
logger,
), nil
}
// NewShujubaoServiceWithLogging 使用自定义日志配置创建数据宝服务
func NewShujubaoServiceWithLogging(url, appSecret string, signMethod SignMethod, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig) (*ShujubaoService, error) {
loggingConfig.ServiceName = "shujubao"
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
return NewShujubaoService(url, appSecret, signMethod, timeout, logger), nil
}
// NewShujubaoServiceSimple 创建无日志的数据宝服务
func NewShujubaoServiceSimple(url, appSecret string, signMethod SignMethod, timeout time.Duration) *ShujubaoService {
return NewShujubaoService(url, appSecret, signMethod, timeout, nil)
}

View File

@@ -0,0 +1,313 @@
package shujubao
import (
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
"hyapi-server/internal/shared/external_logger"
)
const (
// 错误日志中单条入参值的最大长度,避免 base64 等长内容打满日志
maxLogParamValueLen = 300
)
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
ErrQueryEmpty = errors.New("查询为空")
)
// truncateForLog 将字符串截断到指定长度,用于错误日志,避免 base64 等过长内容
func truncateForLog(s string, maxLen int) string {
if maxLen <= 0 {
return s
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...[truncated, total " + strconv.Itoa(len(s)) + " chars]"
}
// paramsForLog 返回适合写入错误日志的入参副本(长字符串会被截断)
func paramsForLog(params map[string]interface{}) map[string]interface{} {
if params == nil {
return nil
}
out := make(map[string]interface{}, len(params))
for k, v := range params {
if v == nil {
out[k] = nil
continue
}
switch val := v.(type) {
case string:
out[k] = truncateForLog(val, maxLogParamValueLen)
default:
s := fmt.Sprint(v)
out[k] = truncateForLog(s, maxLogParamValueLen)
}
}
return out
}
// ShujubaoResp 数据宝 API 通用响应(按实际文档调整)
type ShujubaoResp struct {
Code string `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data"`
Success bool `json:"success"`
}
// ShujubaoConfig 数据宝服务配置
type ShujubaoConfig struct {
URL string
AppSecret string
SignMethod SignMethod
Timeout time.Duration
}
// ShujubaoService 数据宝服务
type ShujubaoService struct {
config ShujubaoConfig
logger *external_logger.ExternalServiceLogger
}
// NewShujubaoService 创建数据宝服务实例
func NewShujubaoService(url, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger) *ShujubaoService {
if signMethod == "" {
signMethod = SignMethodHMACMD5
}
if timeout == 0 {
timeout = 60 * time.Second
}
return &ShujubaoService{
config: ShujubaoConfig{
URL: url,
AppSecret: appSecret,
SignMethod: signMethod,
Timeout: timeout,
},
logger: logger,
}
}
// generateRequestID 生成请求 ID
func (s *ShujubaoService) generateRequestID() string {
timestamp := time.Now().UnixNano()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, s.config.AppSecret)))
return fmt.Sprintf("shujubao_%x", hash[:8])
}
// buildSortedParamStr 将入参按 key 的 ASCII 排序组合为 key1=value1&key2=value2&...
func buildSortedParamStr(params map[string]interface{}) string {
if len(params) == 0 {
return ""
}
keys := make([]string, 0, len(params))
for k := range params {
keys = append(keys, k)
}
sort.Strings(keys)
var b strings.Builder
for i, k := range keys {
if i > 0 {
b.WriteByte('&')
}
v := params[k]
var vs string
switch val := v.(type) {
case string:
vs = val
case nil:
vs = ""
default:
vs = fmt.Sprint(val)
}
b.WriteString(k)
b.WriteByte('=')
b.WriteString(vs)
}
return b.String()
}
// buildFormUrlEncodedBody 按 key 的 ASCII 排序构建 application/x-www-form-urlencoded 请求体(键与值均已 URL 编码)
func buildFormUrlEncodedBody(params map[string]interface{}) string {
if len(params) == 0 {
return ""
}
keys := make([]string, 0, len(params))
for k := range params {
keys = append(keys, k)
}
sort.Strings(keys)
var b strings.Builder
for i, k := range keys {
if i > 0 {
b.WriteByte('&')
}
v := params[k]
var vs string
switch val := v.(type) {
case string:
vs = val
case nil:
vs = ""
default:
vs = fmt.Sprint(val)
}
b.WriteString(url.QueryEscape(k))
b.WriteByte('=')
b.WriteString(url.QueryEscape(vs))
}
return b.String()
}
// generateSign 根据入参与时间戳生成签名。入参按 ASCII 排序组合后与 app_secret 做 MD5/HMAC。
// 对于开启了加密的接口需传 sign 与 timestamp明文传输的接口则无需传这两个参数。
func (s *ShujubaoService) generateSign(timestamp string, params map[string]interface{}) string {
// 合并 timestamp 到入参后参与排序
merged := make(map[string]interface{}, len(params)+1)
for k, v := range params {
merged[k] = v
}
merged["timestamp"] = timestamp
sortedStr := buildSortedParamStr(merged)
switch s.config.SignMethod {
case SignMethodMD5:
return GenerateSignFromParamsMD5(s.config.AppSecret, sortedStr)
default:
return GenerateSignFromParamsHMAC(s.config.AppSecret, sortedStr)
}
}
// buildRequestURL 拼接接口地址得到最终请求 URL如 https://api.chinadatapay.com/communication/personal/197
func (s *ShujubaoService) buildRequestURL(apiPath string) string {
base := strings.TrimSuffix(s.config.URL, "/")
if apiPath == "" {
return base
}
return base + "/" + strings.TrimPrefix(apiPath, "/")
}
// CallAPI 调用数据宝 APIPOST。最终请求地址 = url + 拼接接口地址值body 为业务参数sign、timestamp 按原样传 header。
func (s *ShujubaoService) CallAPI(ctx context.Context, apiPath string, params map[string]interface{}) (data interface{}, err error) {
startTime := time.Now()
requestID := s.generateRequestID()
timestamp := strconv.FormatInt(time.Now().Unix(), 10)
// 最终请求 URL = https://api.chinadatapay.com/communication + 拼接接口地址值,如 /personal/197
requestURL := s.buildRequestURL(apiPath)
var transactionID string
if id, ok := ctx.Value("transaction_id").(string); ok {
transactionID = id
}
if s.logger != nil {
s.logger.LogRequest(requestID, transactionID, apiPath, requestURL)
}
// 使用 application/x-www-form-urlencoded贵司接口暂不支持 JSON 入参
formBody := buildFormUrlEncodedBody(params)
req, err := http.NewRequestWithContext(ctx, "POST", requestURL, strings.NewReader(formBody))
if err != nil {
err = errors.Join(ErrSystem, err)
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
}
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("timestamp", timestamp)
req.Header.Set("sign", s.generateSign(timestamp, params))
client := &http.Client{Timeout: s.config.Timeout}
response, err := client.Do(req)
if err != nil {
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
} else {
err = errors.Join(ErrSystem, err)
}
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
}
return nil, err
}
defer response.Body.Close()
respBody, err := io.ReadAll(response.Body)
if err != nil {
err = errors.Join(ErrSystem, err)
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
}
return nil, err
}
if s.logger != nil {
duration := time.Since(startTime)
s.logger.LogResponse(requestID, transactionID, apiPath, response.StatusCode, duration)
}
if response.StatusCode != http.StatusOK {
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
}
return nil, err
}
var shujubaoResp ShujubaoResp
if err := json.Unmarshal(respBody, &shujubaoResp); err != nil {
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
}
return nil, err
}
code := shujubaoResp.Code
// 成功码只有这三类:其它 code 都走统一错误映射返回
if code != "10000" && code != "200" && code != "0" {
shujubaoErr := NewShujubaoErrorFromCode(code, shujubaoResp.Message)
if queryEmptyErr := GetQueryEmptyErrByCode(code); queryEmptyErr != nil {
err = errors.Join(queryEmptyErr, shujubaoErr)
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, paramsForLog(params))
}
return nil, err
}
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, shujubaoErr, paramsForLog(params))
}
return nil, errors.Join(ErrDatasource, shujubaoErr)
}
return shujubaoResp.Data, nil
}

View File

@@ -0,0 +1,199 @@
package shumai
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"errors"
"strings"
)
// SignMethod 签名方法
type SignMethod string
const (
SignMethodMD5 SignMethod = "md5"
SignMethodHMACMD5 SignMethod = "hmac"
)
// GenerateSignForm 生成表单接口签名appid & timestamp & app_security
// 拼接规则appid + "&" + timestamp + "&" + app_security对拼接串做 MD532 位小写十六进制;
// 不足 32 位左侧补 0。
func GenerateSignForm(appid, timestamp, appSecret string) string {
str := appid + "&" + timestamp + "&" + appSecret
hash := md5.Sum([]byte(str))
sign := strings.ToLower(hex.EncodeToString(hash[:]))
if n := 32 - len(sign); n > 0 {
sign = strings.Repeat("0", n) + sign
}
return sign
}
// app_secret: "BnJWo61hUgNEa5fqBCueiT1IZ1e0DxPU"
// Encrypt 使用 AES/ECB/PKCS5Padding 加密数据
// 加密算法AES工作模式ECB无初始向量填充方式PKCS5Padding
// 加密 key 是服务商分配的 app_securityAES 加密之后再进行 base64 编码
func Encrypt(data, appSecurity string) (string, error) {
key := prepareAESKey([]byte(appSecurity))
ciphertext, err := aesEncryptECB([]byte(data), key)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt 解密 base64 编码的 AES/ECB/PKCS5Padding 加密数据
func Decrypt(encodedData, appSecurity string) ([]byte, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
if err != nil {
return nil, err
}
key := prepareAESKey([]byte(appSecurity))
plaintext, err := aesDecryptECB(ciphertext, key)
if err != nil {
return nil, err
}
return plaintext, nil
}
// prepareAESKey 准备 AES 密钥,确保长度为 16/24/32 字节
// 如果 key 长度不足,用 0 填充;如果过长,截取前 32 字节
func prepareAESKey(key []byte) []byte {
keyLen := len(key)
if keyLen == 16 || keyLen == 24 || keyLen == 32 {
return key
}
if keyLen < 16 {
// 不足 16 字节,用 0 填充到 16 字节AES-128
padded := make([]byte, 16)
copy(padded, key)
return padded
}
if keyLen < 24 {
// 不足 24 字节,用 0 填充到 24 字节AES-192
padded := make([]byte, 24)
copy(padded, key)
return padded
}
if keyLen < 32 {
// 不足 32 字节,用 0 填充到 32 字节AES-256
padded := make([]byte, 32)
copy(padded, key)
return padded
}
// 超过 32 字节,截取前 32 字节AES-256
return key[:32]
}
// aesEncryptECB 使用 AES ECB 模式加密PKCS5 填充
func aesEncryptECB(plaintext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
ciphertext := make([]byte, len(paddedPlaintext))
mode := newECBEncrypter(block)
mode.CryptBlocks(ciphertext, paddedPlaintext)
return ciphertext, nil
}
// aesDecryptECB 使用 AES ECB 模式解密PKCS5 去填充
func aesDecryptECB(ciphertext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
if len(ciphertext)%block.BlockSize() != 0 {
return nil, errors.New("ciphertext length is not a multiple of block size")
}
plaintext := make([]byte, len(ciphertext))
mode := newECBDecrypter(block)
mode.CryptBlocks(plaintext, ciphertext)
return pkcs5Unpadding(plaintext), nil
}
// pkcs5Padding PKCS5 填充
func pkcs5Padding(src []byte, blockSize int) []byte {
padding := blockSize - len(src)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
// pkcs5Unpadding 去除 PKCS5 填充
func pkcs5Unpadding(src []byte) []byte {
length := len(src)
if length == 0 {
return src
}
unpadding := int(src[length-1])
if unpadding > length {
return src
}
return src[:length-unpadding]
}
// ECB 模式加密/解密实现
type ecb struct {
b cipher.Block
blockSize int
}
func newECB(b cipher.Block) *ecb {
return &ecb{
b: b,
blockSize: b.BlockSize(),
}
}
type ecbEncrypter ecb
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b))
}
func (x *ecbEncrypter) BlockSize() int {
return x.blockSize
}
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Encrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
type ecbDecrypter ecb
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b))
}
func (x *ecbDecrypter) BlockSize() int {
return x.blockSize
}
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Decrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}

View File

@@ -0,0 +1,108 @@
package shumai
import (
"fmt"
)
// ShumaiError 数脉服务错误
type ShumaiError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Error 实现 error 接口
func (e *ShumaiError) Error() string {
return fmt.Sprintf("数脉错误 [%s]: %s", e.Code, e.Message)
}
// IsSuccess 是否成功
func (e *ShumaiError) IsSuccess() bool {
return e.Code == "0" || e.Code == "200"
}
// IsNoRecord 是否查无记录
func (e *ShumaiError) IsNoRecord() bool {
return e.Code == "404"
}
// IsParamError 是否参数错误
func (e *ShumaiError) IsParamError() bool {
return e.Code == "400"
}
// IsAuthError 是否认证错误
func (e *ShumaiError) IsAuthError() bool {
return e.Code == "601" || e.Code == "602"
}
// IsSystemError 是否系统错误
func (e *ShumaiError) IsSystemError() bool {
return e.Code == "500" || e.Code == "501"
}
// 预定义错误
var (
ErrSuccess = &ShumaiError{Code: "200", Message: "成功"}
ErrParamError = &ShumaiError{Code: "400", Message: "参数错误"}
ErrNoRecord = &ShumaiError{Code: "404", Message: "请求资源不存在"}
ErrSystemError = &ShumaiError{Code: "500", Message: "系统内部错误,请联系服务商"}
ErrThirdPartyError = &ShumaiError{Code: "501", Message: "第三方服务异常"}
ErrNoPermission = &ShumaiError{Code: "601", Message: "服务商未开通接口权限"}
ErrAccountDisabled = &ShumaiError{Code: "602", Message: "账号停用"}
ErrInsufficientBalance = &ShumaiError{Code: "603", Message: "余额不足请充值"}
ErrInterfaceDisabled = &ShumaiError{Code: "604", Message: "接口停用"}
ErrInsufficientQuota = &ShumaiError{Code: "605", Message: "次数不足,请购买套餐"}
ErrRateLimitExceeded = &ShumaiError{Code: "606", Message: "调用超限,请联系服务商"}
ErrOther = &ShumaiError{Code: "1001", Message: "其他,以实际返回为准"}
)
// NewShumaiError 创建数脉错误
func NewShumaiError(code, message string) *ShumaiError {
return &ShumaiError{Code: code, Message: message}
}
// NewShumaiErrorFromCode 根据状态码创建错误
func NewShumaiErrorFromCode(code string) *ShumaiError {
switch code {
case "0", "200":
return ErrSuccess
case "400":
return ErrParamError
case "404":
return ErrNoRecord
case "500":
return ErrSystemError
case "501":
return ErrThirdPartyError
case "601":
return ErrNoPermission
case "602":
return ErrAccountDisabled
case "603":
return ErrInsufficientBalance
case "604":
return ErrInterfaceDisabled
case "605":
return ErrInsufficientQuota
case "606":
return ErrRateLimitExceeded
case "1001":
return ErrOther
default:
return &ShumaiError{Code: code, Message: "未知错误"}
}
}
// IsShumaiError 是否为数脉错误
func IsShumaiError(err error) bool {
_, ok := err.(*ShumaiError)
return ok
}
// GetShumaiError 获取数脉错误
func GetShumaiError(err error) *ShumaiError {
if e, ok := err.(*ShumaiError); ok {
return e
}
return nil
}

View File

@@ -0,0 +1,69 @@
package shumai
import (
"time"
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewShumaiServiceWithConfig 使用 config 创建数脉服务
func NewShumaiServiceWithConfig(cfg *config.Config) (*ShumaiService, error) {
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.Shumai.Logging.Enabled,
LogDir: cfg.Shumai.Logging.LogDir,
ServiceName: "shumai",
UseDaily: cfg.Shumai.Logging.UseDaily,
EnableLevelSeparation: cfg.Shumai.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
for k, v := range cfg.Shumai.Logging.LevelConfigs {
loggingConfig.LevelConfigs[k] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: v.MaxSize,
MaxBackups: v.MaxBackups,
MaxAge: v.MaxAge,
Compress: v.Compress,
}
}
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
var signMethod SignMethod
if cfg.Shumai.SignMethod == "md5" {
signMethod = SignMethodMD5
} else {
signMethod = SignMethodHMACMD5
}
timeout := 60 * time.Second
if cfg.Shumai.Timeout > 0 {
timeout = cfg.Shumai.Timeout
}
return NewShumaiService(
cfg.Shumai.URL,
cfg.Shumai.AppID,
cfg.Shumai.AppSecret,
signMethod,
timeout,
logger,
cfg.Shumai.AppID2, // 走政务接口使用这个
cfg.Shumai.AppSecret2, // 走政务接口使用这个
), nil
}
// NewShumaiServiceWithLogging 使用自定义日志配置创建数脉服务
func NewShumaiServiceWithLogging(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, loggingConfig external_logger.ExternalServiceLoggingConfig, appID2, appSecret2 string) (*ShumaiService, error) {
loggingConfig.ServiceName = "shumai"
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
return NewShumaiService(url, appID, appSecret, signMethod, timeout, logger, appID2, appSecret2), nil
}
// NewShumaiServiceSimple 创建无数脉日志的数脉服务
func NewShumaiServiceSimple(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, appID2, appSecret2 string) *ShumaiService {
return NewShumaiService(url, appID, appSecret, signMethod, timeout, nil, appID2, appSecret2)
}

View File

@@ -0,0 +1,360 @@
package shumai
import (
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"hyapi-server/internal/shared/external_logger"
)
const (
// 错误日志中单条入参值的最大长度,避免 base64 等长内容打满日志
maxLogParamValueLen = 300
// 错误日志中 response_body 的最大长度
maxLogResponseBodyLen = 500
)
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
ErrNotFound = errors.New("查询为空")
)
// truncateForLog 将字符串截断到指定长度,用于错误日志,避免 base64 等过长内容
func truncateForLog(s string, maxLen int) string {
if maxLen <= 0 {
return s
}
if len(s) <= maxLen {
return s
}
return s[:maxLen] + "...[truncated, total " + strconv.Itoa(len(s)) + " chars]"
}
// requestParamsForLog 返回适合写入错误日志的入参副本(长字符串会被截断)
func requestParamsForLog(reqFormData map[string]interface{}) map[string]interface{} {
if reqFormData == nil {
return nil
}
out := make(map[string]interface{}, len(reqFormData))
for k, v := range reqFormData {
if v == nil {
out[k] = nil
continue
}
switch val := v.(type) {
case string:
out[k] = truncateForLog(val, maxLogParamValueLen)
default:
s := fmt.Sprint(v)
out[k] = truncateForLog(s, maxLogParamValueLen)
}
}
return out
}
// ShumaiResponse 数脉 API 通用响应(占位,按实际文档调整)
type ShumaiResponse struct {
Code int `json:"code"` // 状态码
Msg string `json:"msg"`
Message string `json:"message"`
Data interface{} `json:"data"`
}
// ShumaiConfig 数脉服务配置
type ShumaiConfig struct {
URL string
AppID string
AppSecret string
AppID2 string // 走政务接口使用这个
AppSecret2 string // 走政务接口使用这个
SignMethod SignMethod
Timeout time.Duration
}
// ShumaiService 数脉服务
type ShumaiService struct {
config ShumaiConfig
logger *external_logger.ExternalServiceLogger
useGovernment bool // 是否使用政务接口app_id2
}
// NewShumaiService 创建数脉服务实例
// appID2 和 appSecret2 用于政务接口,如果为空则只使用普通接口
func NewShumaiService(url, appID, appSecret string, signMethod SignMethod, timeout time.Duration, logger *external_logger.ExternalServiceLogger, appID2, appSecret2 string) *ShumaiService {
if signMethod == "" {
signMethod = SignMethodHMACMD5
}
if timeout == 0 {
timeout = 60 * time.Second
}
return &ShumaiService{
config: ShumaiConfig{
URL: url,
AppID: appID,
AppSecret: appSecret,
AppID2: appID2, // 走政务接口使用这个
AppSecret2: appSecret2, // 走政务接口使用这个
SignMethod: signMethod,
Timeout: timeout,
},
logger: logger,
useGovernment: false,
}
}
func (s *ShumaiService) generateRequestID() string {
timestamp := time.Now().UnixNano()
appID := s.getCurrentAppID()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, appID)))
return fmt.Sprintf("shumai_%x", hash[:8])
}
// generateRequestIDWithAppID 根据指定的 AppID 生成请求ID用于不依赖全局状态的情况
func (s *ShumaiService) generateRequestIDWithAppID(appID string) string {
timestamp := time.Now().UnixNano()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, appID)))
return fmt.Sprintf("shumai_%x", hash[:8])
}
// getCurrentAppID 获取当前使用的 AppID
func (s *ShumaiService) getCurrentAppID() string {
if s.useGovernment && s.config.AppID2 != "" {
return s.config.AppID2
}
return s.config.AppID
}
// getCurrentAppSecret 获取当前使用的 AppSecret
func (s *ShumaiService) getCurrentAppSecret() string {
if s.useGovernment && s.config.AppSecret2 != "" {
return s.config.AppSecret2
}
return s.config.AppSecret
}
// UseGovernment 切换到政务接口(使用 app_id2 和 app_secret2
func (s *ShumaiService) UseGovernment() {
s.useGovernment = true
}
// UseNormal 切换到普通接口(使用 app_id 和 app_secret
func (s *ShumaiService) UseNormal() {
s.useGovernment = false
}
// IsUsingGovernment 检查是否正在使用政务接口
func (s *ShumaiService) IsUsingGovernment() bool {
return s.useGovernment
}
// GetConfig 返回当前配置
func (s *ShumaiService) GetConfig() ShumaiConfig {
return s.config
}
// CallAPIForm 以表单方式调用数脉 APIapplication/x-www-form-urlencoded
// 在方法内部将 reqFormData 转为表单:先写入业务参数,再追加 appid、timestamp、sign。
// 签名算法md5(appid&timestamp&app_security)32 位小写,不足补 0。
// useGovernment 可选参数true 表示使用政务接口app_id2false 表示使用实时接口app_id
// 如果未提供参数,则使用全局状态(通过 UseGovernment()/UseNormal() 设置)
func (s *ShumaiService) CallAPIForm(ctx context.Context, apiPath string, reqFormData map[string]interface{}, useGovernment ...bool) ([]byte, error) {
// 确定是否使用政务接口:如果提供了参数则使用参数值,否则使用全局状态
var useGov bool
if len(useGovernment) > 0 {
useGov = useGovernment[0]
} else {
// 未提供参数时,使用全局状态以保持向后兼容
useGov = s.useGovernment
}
startTime := time.Now()
timestamp := strconv.FormatInt(time.Now().UnixMilli(), 10)
// 根据参数选择使用的 AppID 和 AppSecret而不是依赖全局状态
var appID, appSecret string
if useGov && s.config.AppID2 != "" {
appID = s.config.AppID2
appSecret = s.config.AppSecret2
} else {
appID = s.config.AppID
appSecret = s.config.AppSecret
}
// 使用指定的 AppID 生成请求ID
requestID := s.generateRequestIDWithAppID(appID)
sign := GenerateSignForm(appID, timestamp, appSecret)
var transactionID string
if id, ok := ctx.Value("transaction_id").(string); ok {
transactionID = id
}
form := url.Values{}
form.Set("appid", appID)
form.Set("timestamp", timestamp)
form.Set("sign", sign)
for k, v := range reqFormData {
if v == nil {
continue
}
form.Set(k, fmt.Sprint(v))
}
body := form.Encode()
baseURL := strings.TrimSuffix(s.config.URL, "/")
reqURL := baseURL
if apiPath != "" {
reqURL = baseURL + "/" + strings.TrimPrefix(apiPath, "/")
}
if apiPath == "" {
apiPath = "shumai_form"
}
if s.logger != nil {
s.logger.LogRequest(requestID, transactionID, apiPath, reqURL)
}
req, err := http.NewRequestWithContext(ctx, "POST", reqURL, strings.NewReader(body))
if err != nil {
err = errors.Join(ErrSystem, err)
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, map[string]interface{}{"request_params": requestParamsForLog(reqFormData)})
}
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{Timeout: s.config.Timeout}
resp, err := client.Do(req)
if err != nil {
isTimeout := ctx.Err() == context.DeadlineExceeded
if !isTimeout {
if te, ok := err.(interface{ Timeout() bool }); ok && te.Timeout() {
isTimeout = true
}
}
if !isTimeout {
es := err.Error()
if strings.Contains(es, "deadline exceeded") || strings.Contains(es, "timeout") || strings.Contains(es, "canceled") {
isTimeout = true
}
}
if isTimeout {
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
} else {
err = errors.Join(ErrSystem, err)
}
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, map[string]interface{}{"request_params": requestParamsForLog(reqFormData)})
}
return nil, err
}
defer resp.Body.Close()
duration := time.Since(startTime)
raw, err := io.ReadAll(resp.Body)
if err != nil {
err = errors.Join(ErrSystem, err)
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, err, map[string]interface{}{"request_params": requestParamsForLog(reqFormData)})
}
return nil, err
}
if resp.StatusCode != http.StatusOK {
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP %d", resp.StatusCode))
if s.logger != nil {
errorPayload := map[string]interface{}{
"request_params": requestParamsForLog(reqFormData),
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
}
s.logger.LogError(requestID, transactionID, apiPath, err, errorPayload)
}
return nil, err
}
if s.logger != nil {
s.logger.LogResponse(requestID, transactionID, apiPath, resp.StatusCode, duration)
}
var shumaiResp ShumaiResponse
if err := json.Unmarshal(raw, &shumaiResp); err != nil {
parseErr := errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, parseErr, map[string]interface{}{
"request_params": requestParamsForLog(reqFormData),
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
})
}
return nil, parseErr
}
codeStr := strconv.Itoa(shumaiResp.Code)
msg := shumaiResp.Msg
if msg == "" {
msg = shumaiResp.Message
}
shumaiErr := NewShumaiErrorFromCode(codeStr)
if !shumaiErr.IsSuccess() {
if shumaiErr.Message == "未知错误" && msg != "" {
shumaiErr = NewShumaiError(codeStr, msg)
}
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, shumaiErr, map[string]interface{}{
"request_params": requestParamsForLog(reqFormData),
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
})
}
if shumaiErr.IsNoRecord() {
return nil, errors.Join(ErrNotFound, shumaiErr)
}
return nil, errors.Join(ErrDatasource, shumaiErr)
}
if shumaiResp.Data == nil {
return []byte("{}"), nil
}
dataBytes, err := json.Marshal(shumaiResp.Data)
if err != nil {
marshalErr := errors.Join(ErrSystem, fmt.Errorf("data 序列化失败: %w", err))
if s.logger != nil {
s.logger.LogError(requestID, transactionID, apiPath, marshalErr, map[string]interface{}{
"request_params": requestParamsForLog(reqFormData),
"response_body": truncateForLog(string(raw), maxLogResponseBodyLen),
})
}
return nil, marshalErr
}
return dataBytes, nil
}
func (s *ShumaiService) Encrypt(data string) (string, error) {
appSecret := s.getCurrentAppSecret()
encryptedValue, err := Encrypt(data, appSecret)
if err != nil {
return "", ErrSystem
}
return encryptedValue, nil
}
func (s *ShumaiService) Decrypt(encodedData string) ([]byte, error) {
appSecret := s.getCurrentAppSecret()
decryptedValue, err := Decrypt(encodedData, appSecret)
if err != nil {
return nil, ErrSystem
}
return decryptedValue, nil
}

View File

@@ -0,0 +1,148 @@
package sms
import (
"context"
"crypto/rand"
"encoding/json"
"fmt"
"math/big"
"time"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
"go.uber.org/zap"
"hyapi-server/internal/config"
)
// AliSMSService 阿里云短信服务
type AliSMSService struct {
client *dysmsapi.Client
config config.SMSConfig
logger *zap.Logger
}
// NewAliSMSService 创建阿里云短信服务
func NewAliSMSService(cfg config.SMSConfig, logger *zap.Logger) (*AliSMSService, error) {
client, err := dysmsapi.NewClientWithAccessKey("cn-hangzhou", cfg.AccessKeyID, cfg.AccessKeySecret)
if err != nil {
return nil, fmt.Errorf("创建短信客户端失败: %w", err)
}
return &AliSMSService{
client: client,
config: cfg,
logger: logger,
}, nil
}
// SendVerificationCode 发送验证码
func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.PhoneNumbers = phone
request.SignName = s.config.SignName
request.TemplateCode = s.config.TemplateCode
request.TemplateParam = fmt.Sprintf(`{"code":"%s"}`, code)
response, err := s.client.SendSms(request)
if err != nil {
s.logger.Error("Failed to send SMS",
zap.String("phone", phone),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
if response.Code != "OK" {
s.logger.Error("SMS send failed",
zap.String("phone", phone),
zap.String("code", response.Code),
zap.String("message", response.Message))
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
}
s.logger.Info("SMS sent successfully",
zap.String("phone", phone),
zap.String("bizId", response.BizId))
return nil
}
// SendBalanceAlert 发送余额预警短信(低余额与欠费共用 balance_alert_template_code模板需包含 name、time、money
func (s *AliSMSService) SendBalanceAlert(ctx context.Context, phone string, balance float64, threshold float64, alertType string, enterpriseName ...string) error {
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.PhoneNumbers = phone
request.SignName = s.config.SignName
name := "海宇数据用户"
if len(enterpriseName) > 0 && enterpriseName[0] != "" {
name = enterpriseName[0]
}
t := time.Now().Format("2006-01-02 15:04:05")
var money float64
if alertType == "low_balance" {
money = threshold
} else {
money = balance
}
templateCode := s.config.BalanceAlertTemplateCode
if templateCode == "" {
templateCode = "SMS_500565339"
}
tp, err := json.Marshal(struct {
Name string `json:"name"`
Time string `json:"time"`
Money string `json:"money"`
}{Name: name, Time: t, Money: fmt.Sprintf("%.2f", money)})
if err != nil {
return fmt.Errorf("构建短信模板参数失败: %w", err)
}
request.TemplateCode = templateCode
request.TemplateParam = string(tp)
response, err := s.client.SendSms(request)
if err != nil {
s.logger.Error("发送余额预警短信失败",
zap.String("phone", phone),
zap.String("alert_type", alertType),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
if response.Code != "OK" {
s.logger.Error("余额预警短信发送失败",
zap.String("phone", phone),
zap.String("alert_type", alertType),
zap.String("code", response.Code),
zap.String("message", response.Message))
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
}
s.logger.Info("余额预警短信发送成功",
zap.String("phone", phone),
zap.String("alert_type", alertType),
zap.String("bizId", response.BizId))
return nil
}
// GenerateCode 生成验证码
func (s *AliSMSService) GenerateCode(length int) string {
if length <= 0 {
length = 6
}
max := big.NewInt(int64(pow10(length)))
n, _ := rand.Int(rand.Reader, max)
format := fmt.Sprintf("%%0%dd", length)
return fmt.Sprintf(format, n.Int64())
}
func pow10(n int) int {
result := 1
for i := 0; i < n; i++ {
result *= 10
}
return result
}

View File

@@ -0,0 +1,48 @@
package sms
import (
"context"
"go.uber.org/zap"
)
// MockSMSService 模拟短信服务(用于开发和测试)
type MockSMSService struct {
logger *zap.Logger
}
// NewMockSMSService 创建模拟短信服务
func NewMockSMSService(logger *zap.Logger) *MockSMSService {
return &MockSMSService{
logger: logger,
}
}
// SendVerificationCode 模拟发送验证码
func (s *MockSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
s.logger.Info("Mock SMS sent",
zap.String("phone", phone),
zap.String("code", code))
return nil
}
// SendBalanceAlert 模拟余额预警
func (s *MockSMSService) SendBalanceAlert(ctx context.Context, phone string, balance float64, threshold float64, alertType string, enterpriseName ...string) error {
s.logger.Info("Mock balance alert SMS",
zap.String("phone", phone),
zap.Float64("balance", balance),
zap.String("alert_type", alertType))
return nil
}
// GenerateCode 生成验证码
func (s *MockSMSService) GenerateCode(length int) string {
if length <= 0 {
length = 6
}
result := ""
for i := 0; i < length; i++ {
result += "1"
}
return result
}

View File

@@ -0,0 +1,38 @@
package sms
import (
"context"
"fmt"
"strings"
"go.uber.org/zap"
"hyapi-server/internal/config"
)
// SMSSender 短信发送抽象(验证码 + 余额预警),支持阿里云与腾讯云等实现。
type SMSSender interface {
SendVerificationCode(ctx context.Context, phone, code string) error
SendBalanceAlert(ctx context.Context, phone string, balance, threshold float64, alertType string, enterpriseName ...string) error
GenerateCode(length int) string
}
// NewSMSSender 根据 sms.provider 创建实现mock_enabled 时返回模拟发送器。
// provider 为空时默认 tencent。
func NewSMSSender(cfg config.SMSConfig, logger *zap.Logger) (SMSSender, error) {
if cfg.MockEnabled {
return NewMockSMSService(logger), nil
}
p := strings.ToLower(strings.TrimSpace(cfg.Provider))
if p == "" {
p = "tencent"
}
switch p {
case "tencent":
return NewTencentSMSService(cfg, logger)
case "aliyun", "alicloud", "ali":
return NewAliSMSService(cfg, logger)
default:
return nil, fmt.Errorf("不支持的短信服务商: %s支持 aliyun、tencent", cfg.Provider)
}
}

View File

@@ -0,0 +1,187 @@
package sms
import (
"context"
"crypto/rand"
"fmt"
"math/big"
"strings"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
sms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111"
"go.uber.org/zap"
"hyapi-server/internal/config"
)
// TencentSMSService 腾讯云短信(与 bdqr-server 接入方式一致)
type TencentSMSService struct {
client *sms.Client
cfg config.SMSConfig
logger *zap.Logger
}
// NewTencentSMSService 创建腾讯云短信客户端
func NewTencentSMSService(cfg config.SMSConfig, logger *zap.Logger) (*TencentSMSService, error) {
tc := cfg.TencentCloud
if tc.SecretId == "" || tc.SecretKey == "" {
return nil, fmt.Errorf("腾讯云短信未配置 secret_id / secret_key")
}
credential := common.NewCredential(tc.SecretId, tc.SecretKey)
cpf := profile.NewClientProfile()
cpf.HttpProfile.ReqMethod = "POST"
cpf.HttpProfile.ReqTimeout = 10
cpf.HttpProfile.Endpoint = "sms.tencentcloudapi.com"
if tc.Endpoint != "" {
cpf.HttpProfile.Endpoint = tc.Endpoint
}
region := tc.Region
if region == "" {
region = "ap-guangzhou"
}
client, err := sms.NewClient(credential, region, cpf)
if err != nil {
return nil, fmt.Errorf("创建腾讯云短信客户端失败: %w", err)
}
return &TencentSMSService{
client: client,
cfg: cfg,
logger: logger,
}, nil
}
func normalizeTencentPhone(phone string) string {
if strings.HasPrefix(phone, "+86") {
return phone
}
return "+86" + phone
}
// SendVerificationCode 发送验证码(模板参数为单个验证码,与 bdqr 一致)
func (s *TencentSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
tc := s.cfg.TencentCloud
request := &sms.SendSmsRequest{}
request.SmsSdkAppId = common.StringPtr(tc.SmsSdkAppId)
request.SignName = common.StringPtr(tc.SignName)
request.TemplateId = common.StringPtr(tc.TemplateID)
request.TemplateParamSet = common.StringPtrs([]string{code})
request.PhoneNumberSet = common.StringPtrs([]string{normalizeTencentPhone(phone)})
response, err := s.client.SendSms(request)
if err != nil {
s.logger.Error("腾讯云短信发送失败",
zap.String("phone", phone),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
if response.Response == nil || len(response.Response.SendStatusSet) == 0 {
return fmt.Errorf("腾讯云短信返回空响应")
}
st := response.Response.SendStatusSet[0]
if st.Code == nil || *st.Code != "Ok" {
msg := ""
if st.Message != nil {
msg = *st.Message
}
s.logger.Error("腾讯云短信业务失败",
zap.String("phone", phone),
zap.String("message", msg))
return fmt.Errorf("短信发送失败: %s", msg)
}
s.logger.Info("腾讯云短信发送成功",
zap.String("phone", phone),
zap.String("serial_no", safeStrPtr(st.SerialNo)))
return nil
}
// SendBalanceAlert 发送余额类预警。低余额与欠费使用不同模板(见 low_balance_template_id / arrears_template_id
// 若未分别配置则回退 balance_alert_template_id。除验证码外腾讯云短信按无变量模板发送。
func (s *TencentSMSService) SendBalanceAlert(ctx context.Context, phone string, balance float64, threshold float64, alertType string, enterpriseName ...string) error {
tc := s.cfg.TencentCloud
tplID := resolveTencentBalanceTemplateID(tc, alertType)
if tplID == "" {
return fmt.Errorf("腾讯云余额类短信模板未配置(请设置 sms.tencent_cloud.low_balance_template_id 与 arrears_template_id或回退 balance_alert_template_id")
}
request := &sms.SendSmsRequest{}
request.SmsSdkAppId = common.StringPtr(tc.SmsSdkAppId)
request.SignName = common.StringPtr(tc.SignName)
request.TemplateId = common.StringPtr(tplID)
request.PhoneNumberSet = common.StringPtrs([]string{normalizeTencentPhone(phone)})
response, err := s.client.SendSms(request)
if err != nil {
s.logger.Error("腾讯云余额预警短信失败",
zap.String("phone", phone),
zap.String("alert_type", alertType),
zap.Error(err))
return fmt.Errorf("短信发送失败: %w", err)
}
if response.Response == nil || len(response.Response.SendStatusSet) == 0 {
return fmt.Errorf("腾讯云短信返回空响应")
}
st := response.Response.SendStatusSet[0]
if st.Code == nil || *st.Code != "Ok" {
msg := ""
if st.Message != nil {
msg = *st.Message
}
return fmt.Errorf("短信发送失败: %s", msg)
}
s.logger.Info("腾讯云余额预警短信发送成功",
zap.String("phone", phone),
zap.String("alert_type", alertType))
return nil
}
// GenerateCode 生成数字验证码
func (s *TencentSMSService) GenerateCode(length int) string {
if length <= 0 {
length = 6
}
max := big.NewInt(int64(pow10Tencent(length)))
n, _ := rand.Int(rand.Reader, max)
format := fmt.Sprintf("%%0%dd", length)
return fmt.Sprintf(format, n.Int64())
}
func safeStrPtr(p *string) string {
if p == nil {
return ""
}
return *p
}
func pow10Tencent(n int) int {
result := 1
for i := 0; i < n; i++ {
result *= 10
}
return result
}
func resolveTencentBalanceTemplateID(tc config.TencentSMSConfig, alertType string) string {
switch alertType {
case "low_balance":
if tc.LowBalanceTemplateID != "" {
return tc.LowBalanceTemplateID
}
case "arrears":
if tc.ArrearsTemplateID != "" {
return tc.ArrearsTemplateID
}
}
return tc.BalanceAlertTemplateID
}

View File

@@ -0,0 +1,115 @@
package storage
import (
"context"
"fmt"
"io"
"mime/multipart"
"os"
"path/filepath"
"go.uber.org/zap"
)
// LocalFileStorageService 本地文件存储服务
type LocalFileStorageService struct {
basePath string
logger *zap.Logger
}
// LocalFileStorageConfig 本地文件存储配置
type LocalFileStorageConfig struct {
BasePath string `yaml:"base_path"`
}
// NewLocalFileStorageService 创建本地文件存储服务
func NewLocalFileStorageService(basePath string, logger *zap.Logger) *LocalFileStorageService {
// 确保基础路径存在
if err := os.MkdirAll(basePath, 0755); err != nil {
logger.Error("创建基础存储目录失败", zap.Error(err), zap.String("path", basePath))
}
return &LocalFileStorageService{
basePath: basePath,
logger: logger,
}
}
// StoreFile 存储文件
func (s *LocalFileStorageService) StoreFile(ctx context.Context, file io.Reader, filename string) (string, error) {
// 构建完整文件路径
fullPath := filepath.Join(s.basePath, filename)
// 确保目录存在
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
s.logger.Error("创建目录失败", zap.Error(err), zap.String("dir", dir))
return "", fmt.Errorf("创建目录失败: %w", err)
}
// 创建文件
dst, err := os.Create(fullPath)
if err != nil {
s.logger.Error("创建文件失败", zap.Error(err), zap.String("path", fullPath))
return "", fmt.Errorf("创建文件失败: %w", err)
}
defer dst.Close()
// 复制文件内容
if _, err := io.Copy(dst, file); err != nil {
s.logger.Error("写入文件失败", zap.Error(err), zap.String("path", fullPath))
// 删除部分写入的文件
_ = os.Remove(fullPath)
return "", fmt.Errorf("写入文件失败: %w", err)
}
s.logger.Info("文件存储成功", zap.String("path", fullPath))
return fullPath, nil
}
// StoreMultipartFile 存储multipart文件
func (s *LocalFileStorageService) StoreMultipartFile(ctx context.Context, file *multipart.FileHeader, filename string) (string, error) {
src, err := file.Open()
if err != nil {
return "", fmt.Errorf("打开上传文件失败: %w", err)
}
defer src.Close()
return s.StoreFile(ctx, src, filename)
}
// GetFileURL 获取文件URL
func (s *LocalFileStorageService) GetFileURL(ctx context.Context, filePath string) (string, error) {
// 检查文件是否存在
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return "", fmt.Errorf("文件不存在: %s", filePath)
}
// 返回文件路径在实际应用中这里应该返回可访问的URL
return filePath, nil
}
// DeleteFile 删除文件
func (s *LocalFileStorageService) DeleteFile(ctx context.Context, filePath string) error {
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
// 文件不存在,不视为错误
return nil
}
s.logger.Error("删除文件失败", zap.Error(err), zap.String("path", filePath))
return fmt.Errorf("删除文件失败: %w", err)
}
s.logger.Info("文件删除成功", zap.String("path", filePath))
return nil
}
// GetFileReader 获取文件读取器
func (s *LocalFileStorageService) GetFileReader(ctx context.Context, filePath string) (io.ReadCloser, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("打开文件失败: %w", err)
}
return file, nil
}

View File

@@ -0,0 +1,110 @@
package storage
import (
"context"
"fmt"
"io"
"mime/multipart"
"os"
"path/filepath"
"go.uber.org/zap"
)
// LocalFileStorageServiceImpl 本地文件存储服务实现
type LocalFileStorageServiceImpl struct {
basePath string
logger *zap.Logger
}
// NewLocalFileStorageServiceImpl 创建本地文件存储服务实现
func NewLocalFileStorageServiceImpl(basePath string, logger *zap.Logger) *LocalFileStorageServiceImpl {
// 确保基础路径存在
if err := os.MkdirAll(basePath, 0755); err != nil {
logger.Error("创建基础存储目录失败", zap.Error(err), zap.String("path", basePath))
}
return &LocalFileStorageServiceImpl{
basePath: basePath,
logger: logger,
}
}
// StoreFile 存储文件
func (s *LocalFileStorageServiceImpl) StoreFile(ctx context.Context, file io.Reader, filename string) (string, error) {
// 构建完整文件路径
fullPath := filepath.Join(s.basePath, filename)
// 确保目录存在
dir := filepath.Dir(fullPath)
if err := os.MkdirAll(dir, 0755); err != nil {
s.logger.Error("创建目录失败", zap.Error(err), zap.String("dir", dir))
return "", fmt.Errorf("创建目录失败: %w", err)
}
// 创建文件
dst, err := os.Create(fullPath)
if err != nil {
s.logger.Error("创建文件失败", zap.Error(err), zap.String("path", fullPath))
return "", fmt.Errorf("创建文件失败: %w", err)
}
defer dst.Close()
// 复制文件内容
if _, err := io.Copy(dst, file); err != nil {
s.logger.Error("写入文件失败", zap.Error(err), zap.String("path", fullPath))
// 删除部分写入的文件
_ = os.Remove(fullPath)
return "", fmt.Errorf("写入文件失败: %w", err)
}
s.logger.Info("文件存储成功", zap.String("path", fullPath))
return fullPath, nil
}
// StoreMultipartFile 存储multipart文件
func (s *LocalFileStorageServiceImpl) StoreMultipartFile(ctx context.Context, file *multipart.FileHeader, filename string) (string, error) {
src, err := file.Open()
if err != nil {
return "", fmt.Errorf("打开上传文件失败: %w", err)
}
defer src.Close()
return s.StoreFile(ctx, src, filename)
}
// GetFileURL 获取文件URL
func (s *LocalFileStorageServiceImpl) GetFileURL(ctx context.Context, filePath string) (string, error) {
// 检查文件是否存在
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return "", fmt.Errorf("文件不存在: %s", filePath)
}
// 返回文件路径在实际应用中这里应该返回可访问的URL
return filePath, nil
}
// DeleteFile 删除文件
func (s *LocalFileStorageServiceImpl) DeleteFile(ctx context.Context, filePath string) error {
if err := os.Remove(filePath); err != nil {
if os.IsNotExist(err) {
// 文件不存在,不视为错误
return nil
}
s.logger.Error("删除文件失败", zap.Error(err), zap.String("path", filePath))
return fmt.Errorf("删除文件失败: %w", err)
}
s.logger.Info("文件删除成功", zap.String("path", filePath))
return nil
}
// GetFileReader 获取文件读取器
func (s *LocalFileStorageServiceImpl) GetFileReader(ctx context.Context, filePath string) (io.ReadCloser, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, fmt.Errorf("打开文件失败: %w", err)
}
return file, nil
}

View File

@@ -0,0 +1,353 @@
package storage
import (
"context"
"crypto/hmac"
"crypto/sha1"
"encoding/base64"
"fmt"
"io"
"net/http"
"path/filepath"
"strings"
"time"
"github.com/qiniu/go-sdk/v7/auth/qbox"
"github.com/qiniu/go-sdk/v7/storage"
"go.uber.org/zap"
sharedStorage "hyapi-server/internal/shared/storage"
)
// QiNiuStorageService 七牛云存储服务
type QiNiuStorageService struct {
accessKey string
secretKey string
bucket string
domain string
logger *zap.Logger
mac *qbox.Mac
bucketManager *storage.BucketManager
}
// QiNiuStorageConfig 七牛云存储配置
type QiNiuStorageConfig struct {
AccessKey string `yaml:"access_key"`
SecretKey string `yaml:"secret_key"`
Bucket string `yaml:"bucket"`
Domain string `yaml:"domain"`
}
// NewQiNiuStorageService 创建七牛云存储服务
func NewQiNiuStorageService(accessKey, secretKey, bucket, domain string, logger *zap.Logger) *QiNiuStorageService {
mac := qbox.NewMac(accessKey, secretKey)
// 使用默认配置不需要指定region
cfg := storage.Config{}
bucketManager := storage.NewBucketManager(mac, &cfg)
return &QiNiuStorageService{
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
domain: domain,
logger: logger,
mac: mac,
bucketManager: bucketManager,
}
}
// UploadFile 上传文件到七牛云
func (s *QiNiuStorageService) UploadFile(ctx context.Context, fileBytes []byte, fileName string) (*sharedStorage.UploadResult, error) {
s.logger.Info("开始上传文件到七牛云",
zap.String("file_name", fileName),
zap.Int("file_size", len(fileBytes)),
)
// 生成唯一的文件key
key := s.generateFileKey(fileName)
// 创建上传凭证
putPolicy := storage.PutPolicy{
Scope: s.bucket,
}
upToken := putPolicy.UploadToken(s.mac)
// 配置上传参数
cfg := storage.Config{}
formUploader := storage.NewFormUploader(&cfg)
ret := storage.PutRet{}
// 上传文件
err := formUploader.Put(ctx, &ret, upToken, key, strings.NewReader(string(fileBytes)), int64(len(fileBytes)), &storage.PutExtra{})
if err != nil {
s.logger.Error("文件上传失败",
zap.String("file_name", fileName),
zap.String("key", key),
zap.Error(err),
)
return nil, fmt.Errorf("文件上传失败: %w", err)
}
// 构建文件URL
fileURL := s.GetFileURL(ctx, key)
s.logger.Info("文件上传成功",
zap.String("file_name", fileName),
zap.String("key", key),
zap.String("url", fileURL),
)
return &sharedStorage.UploadResult{
Key: key,
URL: fileURL,
MimeType: s.getMimeType(fileName),
Size: int64(len(fileBytes)),
Hash: ret.Hash,
}, nil
}
// GenerateUploadToken 生成上传凭证
func (s *QiNiuStorageService) GenerateUploadToken(ctx context.Context, key string) (string, error) {
putPolicy := storage.PutPolicy{
Scope: s.bucket,
// 设置过期时间1小时
Expires: uint64(time.Now().Add(time.Hour).Unix()),
}
token := putPolicy.UploadToken(s.mac)
return token, nil
}
// GetFileURL 获取文件访问URL
func (s *QiNiuStorageService) GetFileURL(ctx context.Context, key string) string {
// 如果是私有空间需要生成带签名的URL
if s.isPrivateBucket() {
deadline := time.Now().Add(time.Hour).Unix() // 1小时过期
privateAccessURL := storage.MakePrivateURL(s.mac, s.domain, key, deadline)
return privateAccessURL
}
// 公开空间直接返回URL
return fmt.Sprintf("%s/%s", s.domain, key)
}
// GetPrivateFileURL 获取私有文件访问URL
func (s *QiNiuStorageService) GetPrivateFileURL(ctx context.Context, key string, expires int64) (string, error) {
baseURL := s.GetFileURL(ctx, key)
// TODO: 实际集成七牛云SDK生成私有URL
s.logger.Info("生成七牛云私有文件URL",
zap.String("key", key),
zap.Int64("expires", expires),
)
// 模拟返回私有URL
return fmt.Sprintf("%s?token=mock_private_token&expires=%d", baseURL, expires), nil
}
// DeleteFile 删除文件
func (s *QiNiuStorageService) DeleteFile(ctx context.Context, key string) error {
s.logger.Info("删除七牛云文件", zap.String("key", key))
err := s.bucketManager.Delete(s.bucket, key)
if err != nil {
s.logger.Error("删除文件失败",
zap.String("key", key),
zap.Error(err),
)
return fmt.Errorf("删除文件失败: %w", err)
}
s.logger.Info("文件删除成功", zap.String("key", key))
return nil
}
// FileExists 检查文件是否存在
func (s *QiNiuStorageService) FileExists(ctx context.Context, key string) (bool, error) {
// TODO: 实际集成七牛云SDK检查文件存在性
s.logger.Info("检查七牛云文件存在性", zap.String("key", key))
// 模拟文件存在
return true, nil
}
// GetFileInfo 获取文件信息
func (s *QiNiuStorageService) GetFileInfo(ctx context.Context, key string) (*sharedStorage.FileInfo, error) {
fileInfo, err := s.bucketManager.Stat(s.bucket, key)
if err != nil {
s.logger.Error("获取文件信息失败",
zap.String("key", key),
zap.Error(err),
)
return nil, fmt.Errorf("获取文件信息失败: %w", err)
}
return &sharedStorage.FileInfo{
Key: key,
Size: fileInfo.Fsize,
MimeType: fileInfo.MimeType,
Hash: fileInfo.Hash,
PutTime: fileInfo.PutTime,
}, nil
}
// ListFiles 列出文件
func (s *QiNiuStorageService) ListFiles(ctx context.Context, prefix string, limit int) ([]*sharedStorage.FileInfo, error) {
entries, _, _, hasMore, err := s.bucketManager.ListFiles(s.bucket, prefix, "", "", limit)
if err != nil {
s.logger.Error("列出文件失败",
zap.String("prefix", prefix),
zap.Error(err),
)
return nil, fmt.Errorf("列出文件失败: %w", err)
}
var fileInfos []*sharedStorage.FileInfo
for _, entry := range entries {
fileInfo := &sharedStorage.FileInfo{
Key: entry.Key,
Size: entry.Fsize,
MimeType: entry.MimeType,
Hash: entry.Hash,
PutTime: entry.PutTime,
}
fileInfos = append(fileInfos, fileInfo)
}
_ = hasMore // 暂时忽略hasMore
return fileInfos, nil
}
// generateFileKey 生成文件key
func (s *QiNiuStorageService) generateFileKey(fileName string) string {
// 生成时间戳
timestamp := time.Now().Format("20060102_150405")
// 生成随机字符串
randomStr := fmt.Sprintf("%d", time.Now().UnixNano()%1000000)
// 获取文件扩展名
ext := filepath.Ext(fileName)
// 构建key: 日期/时间戳_随机数.扩展名
key := fmt.Sprintf("certification/%s/%s_%s%s",
time.Now().Format("20060102"), timestamp, randomStr, ext)
return key
}
// getMimeType 根据文件名获取MIME类型
func (s *QiNiuStorageService) getMimeType(fileName string) string {
ext := strings.ToLower(filepath.Ext(fileName))
switch ext {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".pdf":
return "application/pdf"
case ".gif":
return "image/gif"
case ".bmp":
return "image/bmp"
case ".webp":
return "image/webp"
default:
return "application/octet-stream"
}
}
// isPrivateBucket 判断是否为私有空间
func (s *QiNiuStorageService) isPrivateBucket() bool {
// 这里可以根据配置或域名特征判断
// 私有空间的域名通常包含特定标识
return strings.Contains(s.domain, "private") ||
strings.Contains(s.domain, "auth") ||
strings.Contains(s.domain, "secure")
}
// generateSignature 生成签名(用于私有空间访问)
func (s *QiNiuStorageService) generateSignature(data string) string {
h := hmac.New(sha1.New, []byte(s.secretKey))
h.Write([]byte(data))
return base64.URLEncoding.EncodeToString(h.Sum(nil))
}
// UploadFromReader 从Reader上传文件
func (s *QiNiuStorageService) UploadFromReader(ctx context.Context, reader io.Reader, fileName string, fileSize int64) (*sharedStorage.UploadResult, error) {
// 读取文件内容
fileBytes, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("读取文件失败: %w", err)
}
return s.UploadFile(ctx, fileBytes, fileName)
}
// DownloadFile 从七牛云下载文件
func (s *QiNiuStorageService) DownloadFile(ctx context.Context, fileURL string) ([]byte, error) {
s.logger.Info("开始从七牛云下载文件", zap.String("file_url", fileURL))
// 创建HTTP客户端超时时间设置为60秒
client := &http.Client{
Timeout: 60 * time.Second,
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
if err != nil {
return nil, fmt.Errorf("创建请求失败: %w", err)
}
// 发送请求
resp, err := client.Do(req)
if err != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := err.Error();
errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
errorMsg := "下载文件失败"
if isTimeout {
errorMsg = "下载文件超时"
}
s.logger.Error(errorMsg,
zap.String("file_url", fileURL),
zap.Error(err),
)
return nil, fmt.Errorf("%s: %w", errorMsg, err)
}
defer resp.Body.Close()
// 检查响应状态
if resp.StatusCode != http.StatusOK {
s.logger.Error("下载文件失败,状态码异常",
zap.String("file_url", fileURL),
zap.Int("status_code", resp.StatusCode),
)
return nil, fmt.Errorf("下载文件失败,状态码: %d", resp.StatusCode)
}
// 读取文件内容
fileContent, err := io.ReadAll(resp.Body)
if err != nil {
s.logger.Error("读取文件内容失败",
zap.String("file_url", fileURL),
zap.Error(err),
)
return nil, fmt.Errorf("读取文件内容失败: %w", err)
}
s.logger.Info("文件下载成功",
zap.String("file_url", fileURL),
zap.Int("file_size", len(fileContent)),
)
return fileContent, nil
}

View File

@@ -0,0 +1,183 @@
package tianyancha
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
)
var (
ErrDatasource = errors.New("数据源异常")
ErrNotFound = errors.New("查询为空")
ErrSystem = errors.New("系统异常")
ErrInvalidParam = errors.New("参数错误")
)
// APIEndpoints 天眼查 API 端点映射
var APIEndpoints = map[string]string{
"VerifyThreeElements": "/open/ic/verify/2.0", // 企业三要素验证
"InvestHistory": "/open/hi/invest/2.0", // 对外投资历史
"FinancingHistory": "/open/cd/findHistoryRongzi/2.0", // 融资历史
"PunishmentInfo": "/open/mr/punishmentInfo/3.0", // 行政处罚
"AbnormalInfo": "/open/mr/abnormal/2.0", // 经营异常
"OwnTax": "/open/mr/ownTax/2.0", // 欠税公告
"TaxContravention": "/open/mr/taxContravention/2.0", // 税收违法
"holderChange": "/open/ic/holderChange/2.0", // 股权变更
"baseinfo": "/open/ic/baseinfo/normal", // 企业基本信息
"investtree": "/v3/open/investtree", // 股权穿透
}
// TianYanChaConfig 天眼查配置
type TianYanChaConfig struct {
BaseURL string
Token string
Timeout time.Duration
}
// TianYanChaService 天眼查服务
type TianYanChaService struct {
config TianYanChaConfig
}
// APIResponse 标准API响应结构
type APIResponse struct {
Success bool `json:"success"`
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data"`
}
// TianYanChaResponse 天眼查原始响应结构
type TianYanChaResponse struct {
ErrorCode int `json:"error_code"`
Reason string `json:"reason"`
Result interface{} `json:"result"`
}
// NewTianYanChaService 创建天眼查服务实例
func NewTianYanChaService(baseURL, token string, timeout time.Duration) *TianYanChaService {
if timeout == 0 {
timeout = 60 * time.Second
}
return &TianYanChaService{
config: TianYanChaConfig{
BaseURL: baseURL,
Token: token,
Timeout: timeout,
},
}
}
// CallAPI 调用天眼查API - 通用方法,由外部处理器传入具体参数
func (t *TianYanChaService) CallAPI(ctx context.Context, apiCode string, params map[string]string) (*APIResponse, error) {
// 从映射中获取 API 端点
endpoint, exists := APIEndpoints[apiCode]
if !exists {
return nil, errors.Join(ErrInvalidParam, fmt.Errorf("未找到 API 代码对应的端点: %s", apiCode))
}
// 构建完整 URL
fullURL := strings.TrimRight(t.config.BaseURL, "/") + "/" + strings.TrimLeft(endpoint, "/")
// 检查 Token 是否配置
if t.config.Token == "" {
return nil, errors.Join(ErrSystem, fmt.Errorf("天眼查 API Token 未配置"))
}
// 构建查询参数
queryParams := url.Values{}
for key, value := range params {
queryParams.Set(key, value)
}
// 构建完整URL
requestURL := fullURL
if len(queryParams) > 0 {
requestURL += "?" + queryParams.Encode()
}
// 创建请求
req, err := http.NewRequestWithContext(ctx, "GET", requestURL, nil)
if err != nil {
return nil, errors.Join(ErrSystem, fmt.Errorf("创建请求失败: %v", err))
}
// 设置请求头
req.Header.Set("Authorization", t.config.Token)
// 发送请求
client := &http.Client{Timeout: t.config.Timeout}
resp, err := client.Do(req)
if err != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
return nil, errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
}
return nil, errors.Join(ErrDatasource, fmt.Errorf("API 请求异常: %v", err))
}
defer resp.Body.Close()
// 检查 HTTP 状态码
if resp.StatusCode != http.StatusOK {
return nil, errors.Join(ErrDatasource, fmt.Errorf("API 请求失败,状态码: %d", resp.StatusCode))
}
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Join(ErrSystem, fmt.Errorf("读取响应体失败: %v", err))
}
// 解析 JSON 响应
var tianYanChaResp TianYanChaResponse
if err := json.Unmarshal(body, &tianYanChaResp); err != nil {
return nil, errors.Join(ErrSystem, fmt.Errorf("解析响应 JSON 失败: %v", err))
}
// 检查天眼查业务状态码
if tianYanChaResp.ErrorCode != 0 {
// 特殊处理ErrorCode 300000 表示查询为空,返回成功但数据为空数组
if tianYanChaResp.ErrorCode == 300000 {
return &APIResponse{
Success: true,
Code: 0,
Message: "",
Data: []interface{}{}, // 返回空数组而不是nil
}, nil
}
return &APIResponse{
Success: false,
Code: tianYanChaResp.ErrorCode,
Message: tianYanChaResp.Reason,
Data: tianYanChaResp.Result,
}, nil
}
// 成功情况
return &APIResponse{
Success: true,
Code: 0,
Message: tianYanChaResp.Reason,
Data: tianYanChaResp.Result,
}, nil
}

View File

@@ -0,0 +1,160 @@
package westdex
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/sha1"
"encoding/base64"
"encoding/hex"
)
const (
KEY_SIZE = 16 // AES-128, 16 bytes
)
// Encrypt encrypts the given data using AES encryption in ECB mode with PKCS5 padding
func Encrypt(data, secretKey string) (string, error) {
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
ciphertext, err := aesEncrypt([]byte(data), key)
if err != nil {
return "", err
}
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt decrypts the given base64-encoded string using AES encryption in ECB mode with PKCS5 padding
func Decrypt(encodedData, secretKey string) ([]byte, error) {
ciphertext, err := base64.StdEncoding.DecodeString(encodedData)
if err != nil {
return nil, err
}
key := generateAESKey(KEY_SIZE*8, []byte(secretKey))
plaintext, err := aesDecrypt(ciphertext, key)
if err != nil {
return nil, err
}
return plaintext, nil
}
// generateAESKey generates a key for AES encryption using a SHA-1 based PRNG
func generateAESKey(length int, password []byte) []byte {
h := sha1.New()
h.Write(password)
state := h.Sum(nil)
keyBytes := make([]byte, 0, length/8)
for len(keyBytes) < length/8 {
h := sha1.New()
h.Write(state)
state = h.Sum(nil)
keyBytes = append(keyBytes, state...)
}
return keyBytes[:length/8]
}
// aesEncrypt encrypts plaintext using AES in ECB mode with PKCS5 padding
func aesEncrypt(plaintext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
paddedPlaintext := pkcs5Padding(plaintext, block.BlockSize())
ciphertext := make([]byte, len(paddedPlaintext))
mode := newECBEncrypter(block)
mode.CryptBlocks(ciphertext, paddedPlaintext)
return ciphertext, nil
}
// aesDecrypt decrypts ciphertext using AES in ECB mode with PKCS5 padding
func aesDecrypt(ciphertext, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
plaintext := make([]byte, len(ciphertext))
mode := newECBDecrypter(block)
mode.CryptBlocks(plaintext, ciphertext)
return pkcs5Unpadding(plaintext), nil
}
// pkcs5Padding pads the input to a multiple of the block size using PKCS5 padding
func pkcs5Padding(src []byte, blockSize int) []byte {
padding := blockSize - len(src)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
// pkcs5Unpadding removes PKCS5 padding from the input
func pkcs5Unpadding(src []byte) []byte {
length := len(src)
unpadding := int(src[length-1])
return src[:(length - unpadding)]
}
// ECB mode encryption/decryption
type ecb struct {
b cipher.Block
blockSize int
}
func newECB(b cipher.Block) *ecb {
return &ecb{
b: b,
blockSize: b.BlockSize(),
}
}
type ecbEncrypter ecb
func newECBEncrypter(b cipher.Block) cipher.BlockMode {
return (*ecbEncrypter)(newECB(b))
}
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Encrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
type ecbDecrypter ecb
func newECBDecrypter(b cipher.Block) cipher.BlockMode {
return (*ecbDecrypter)(newECB(b))
}
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("crypto/cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("crypto/cipher: output smaller than input")
}
for len(src) > 0 {
x.b.Decrypt(dst, src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
// Md5Encrypt 用于对传入的message进行MD5加密
func Md5Encrypt(message string) string {
hash := md5.New()
hash.Write([]byte(message)) // 将字符串转换为字节切片并写入
return hex.EncodeToString(hash.Sum(nil)) // 将哈希值转换为16进制字符串并返回
}

View File

@@ -0,0 +1,63 @@
package westdex
import (
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewWestDexServiceWithConfig 使用配置创建西部数据服务
func NewWestDexServiceWithConfig(cfg *config.Config) (*WestDexService, error) {
// 将配置类型转换为通用外部服务日志配置
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.WestDex.Logging.Enabled,
LogDir: cfg.WestDex.Logging.LogDir,
ServiceName: "westdex",
UseDaily: cfg.WestDex.Logging.UseDaily,
EnableLevelSeparation: cfg.WestDex.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
// 转换级别配置
for key, value := range cfg.WestDex.Logging.LevelConfigs {
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: value.MaxSize,
MaxBackups: value.MaxBackups,
MaxAge: value.MaxAge,
Compress: value.Compress,
}
}
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建西部数据服务
service := NewWestDexService(
cfg.WestDex.URL,
cfg.WestDex.Key,
cfg.WestDex.SecretID,
cfg.WestDex.SecretSecondID,
logger,
)
return service, nil
}
// NewWestDexServiceWithLogging 使用自定义日志配置创建西部数据服务
func NewWestDexServiceWithLogging(url, key, secretID, secretSecondID string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*WestDexService, error) {
// 设置服务名称
loggingConfig.ServiceName = "westdex"
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建西部数据服务
service := NewWestDexService(url, key, secretID, secretSecondID, logger)
return service, nil
}

View File

@@ -0,0 +1,418 @@
package westdex
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"time"
"hyapi-server/internal/shared/crypto"
"hyapi-server/internal/shared/external_logger"
)
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
ErrNotFound = errors.New("查询为空")
)
type WestResp struct {
Message string `json:"message"`
Code string `json:"code"`
Data string `json:"data"`
ID string `json:"id"`
ErrorCode *int `json:"error_code"`
Reason string `json:"reason"`
}
type G05HZ01WestResp struct {
Message string `json:"message"`
Code string `json:"code"`
Data json.RawMessage `json:"data"`
ID string `json:"id"`
ErrorCode *int `json:"error_code"`
Reason string `json:"reason"`
}
type WestConfig struct {
Url string
Key string
SecretID string
SecretSecondID string
}
type WestDexService struct {
config WestConfig
logger *external_logger.ExternalServiceLogger
}
// NewWestDexService 是一个构造函数,用于初始化 WestDexService
func NewWestDexService(url, key, secretID, secretSecondID string, logger *external_logger.ExternalServiceLogger) *WestDexService {
return &WestDexService{
config: WestConfig{
Url: url,
Key: key,
SecretID: secretID,
SecretSecondID: secretSecondID,
},
logger: logger,
}
}
// generateRequestID 生成请求ID
func (w *WestDexService) generateRequestID() string {
timestamp := time.Now().UnixNano()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, w.config.Key)))
return fmt.Sprintf("westdex_%x", hash[:8])
}
// buildRequestURL 构建请求URL
func (w *WestDexService) buildRequestURL(code string) string {
timestamp := strconv.FormatInt(time.Now().UnixNano()/int64(time.Millisecond), 10)
return fmt.Sprintf("%s/%s/%s?timestamp=%s", w.config.Url, w.config.SecretID, code, timestamp)
}
// CallAPI 调用西部数据的 API
func (w *WestDexService) CallAPI(ctx context.Context, code string, reqData map[string]interface{}) (resp []byte, err error) {
startTime := time.Now()
requestID := w.generateRequestID()
// 从ctx中获取transactionId
var transactionID string
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
transactionID = ctxTransactionID
}
// 构建请求URL
reqUrl := w.buildRequestURL(code)
// 记录请求日志
if w.logger != nil {
w.logger.LogRequest(requestID, transactionID, code, reqUrl)
}
jsonData, marshalErr := json.Marshal(reqData)
if marshalErr != nil {
err = errors.Join(ErrSystem, marshalErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
// 创建HTTP POST请求
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", reqUrl, bytes.NewBuffer(jsonData))
if newRequestErr != nil {
err = errors.Join(ErrSystem, newRequestErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
// 发送请求超时时间设置为60秒
client := &http.Client{
Timeout: 60 * time.Second,
}
httpResp, clientDoErr := client.Do(req)
if clientDoErr != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
} else {
err = errors.Join(ErrSystem, clientDoErr)
}
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
defer func(Body io.ReadCloser) {
closeErr := Body.Close()
if closeErr != nil {
// 记录关闭错误
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), reqData)
}
}
}(httpResp.Body)
// 计算请求耗时
duration := time.Since(startTime)
// 检查请求是否成功
if httpResp.StatusCode == 200 {
// 读取响应体
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
if ReadErr != nil {
err = errors.Join(ErrSystem, ReadErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
// 手动调用 json.Unmarshal 触发自定义的 UnmarshalJSON 方法
var westDexResp WestResp
UnmarshalErr := json.Unmarshal(bodyBytes, &westDexResp)
if UnmarshalErr != nil {
err = errors.Join(ErrSystem, UnmarshalErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
// 记录响应日志(不记录具体响应数据)
if w.logger != nil {
w.logger.LogResponseWithID(requestID, transactionID, code, httpResp.StatusCode, duration, westDexResp.ID)
}
if westDexResp.Code != "00000" && westDexResp.Code != "200" && westDexResp.Code != "0" {
if westDexResp.Data == "" {
err = errors.Join(ErrSystem, fmt.Errorf(westDexResp.Message))
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
}
return nil, err
}
decryptedData, DecryptErr := crypto.WestDexDecrypt(westDexResp.Data, w.config.Key)
if DecryptErr != nil {
err = errors.Join(ErrSystem, DecryptErr)
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
}
return nil, err
}
// 记录业务错误日志包含响应ID
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, errors.Join(ErrDatasource, fmt.Errorf(westDexResp.Message)), reqData, westDexResp.ID)
}
// 记录性能日志(失败)
// 注意:通用日志系统不包含性能日志功能
return decryptedData, errors.Join(ErrDatasource, fmt.Errorf(westDexResp.Message))
}
if westDexResp.Data == "" {
err = errors.Join(ErrSystem, fmt.Errorf(westDexResp.Message))
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
}
return nil, err
}
decryptedData, DecryptErr := crypto.WestDexDecrypt(westDexResp.Data, w.config.Key)
if DecryptErr != nil {
err = errors.Join(ErrSystem, DecryptErr)
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
}
return nil, err
}
// 记录性能日志(成功)
// 注意:通用日志系统不包含性能日志功能
return decryptedData, nil
}
// 记录HTTP错误
err = errors.Join(ErrSystem, fmt.Errorf("西部请求失败Code: %d", httpResp.StatusCode))
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
// 注意:通用日志系统不包含性能日志功能
}
return nil, err
}
// G05HZ01CallAPI 调用西部数据的 G05HZ01 API
func (w *WestDexService) G05HZ01CallAPI(ctx context.Context, code string, reqData map[string]interface{}) (resp []byte, err error) {
startTime := time.Now()
requestID := w.generateRequestID()
// 从ctx中获取transactionId
var transactionID string
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
transactionID = ctxTransactionID
}
// 构建请求URL
reqUrl := fmt.Sprintf("%s/%s/%s?timestamp=%d", w.config.Url, w.config.SecretSecondID, code, time.Now().UnixNano()/int64(time.Millisecond))
// 记录请求日志
if w.logger != nil {
w.logger.LogRequest(requestID, transactionID, code, reqUrl)
}
jsonData, marshalErr := json.Marshal(reqData)
if marshalErr != nil {
err = errors.Join(ErrSystem, marshalErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
// 创建HTTP POST请求
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", reqUrl, bytes.NewBuffer(jsonData))
if newRequestErr != nil {
err = errors.Join(ErrSystem, newRequestErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
// 发送请求超时时间设置为60秒
client := &http.Client{
Timeout: 60 * time.Second,
}
httpResp, clientDoErr := client.Do(req)
if clientDoErr != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
} else {
err = errors.Join(ErrSystem, clientDoErr)
}
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
defer func(Body io.ReadCloser) {
closeErr := Body.Close()
if closeErr != nil {
// 记录关闭错误
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), reqData)
}
}
}(httpResp.Body)
// 计算请求耗时
duration := time.Since(startTime)
if httpResp.StatusCode == 200 {
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
if ReadErr != nil {
err = errors.Join(ErrSystem, ReadErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
var westDexResp G05HZ01WestResp
UnmarshalErr := json.Unmarshal(bodyBytes, &westDexResp)
if UnmarshalErr != nil {
err = errors.Join(ErrSystem, UnmarshalErr)
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
}
return nil, err
}
// 记录响应日志(不记录具体响应数据)
if w.logger != nil {
w.logger.LogResponseWithID(requestID, transactionID, code, httpResp.StatusCode, duration, westDexResp.ID)
}
if westDexResp.Code != "0000" {
if westDexResp.Data == nil || westDexResp.Code == "1404" {
err = errors.Join(ErrNotFound, fmt.Errorf(westDexResp.Message))
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
}
return nil, err
} else {
// 记录业务错误日志包含响应ID
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, errors.Join(ErrSystem, fmt.Errorf(string(westDexResp.Data))), reqData, westDexResp.ID)
}
// 记录性能日志(失败)
// 注意:通用日志系统不包含性能日志功能
return westDexResp.Data, errors.Join(ErrSystem, fmt.Errorf(string(westDexResp.Data)))
}
}
if westDexResp.Data == nil {
err = errors.Join(ErrSystem, fmt.Errorf(westDexResp.Message))
if w.logger != nil {
w.logger.LogErrorWithResponseID(requestID, transactionID, code, err, reqData, westDexResp.ID)
}
return nil, err
}
// 记录性能日志(成功)
// 注意:通用日志系统不包含性能日志功能
return westDexResp.Data, nil
} else {
// 记录HTTP错误
err = errors.Join(ErrSystem, fmt.Errorf("西部请求失败Code: %d", httpResp.StatusCode))
if w.logger != nil {
w.logger.LogError(requestID, transactionID, code, err, reqData)
// 注意:通用日志系统不包含性能日志功能
}
return nil, err
}
}
func (w *WestDexService) Encrypt(data string) (string, error) {
encryptedValue, err := crypto.WestDexEncrypt(data, w.config.Key)
if err != nil {
return "", ErrSystem
}
return encryptedValue, nil
}
func (w *WestDexService) Md5Encrypt(data string) string {
result := Md5Encrypt(data)
return result
}
func (w *WestDexService) GetConfig() WestConfig {
return w.config
}

View File

@@ -0,0 +1,62 @@
package xingwei
import (
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewXingweiServiceWithConfig 使用配置创建行为数据服务
func NewXingweiServiceWithConfig(cfg *config.Config) (*XingweiService, error) {
// 将配置类型转换为通用外部服务日志配置
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.Xingwei.Logging.Enabled,
LogDir: cfg.Xingwei.Logging.LogDir,
ServiceName: "xingwei",
UseDaily: cfg.Xingwei.Logging.UseDaily,
EnableLevelSeparation: cfg.Xingwei.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
// 转换级别配置
for key, value := range cfg.Xingwei.Logging.LevelConfigs {
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: value.MaxSize,
MaxBackups: value.MaxBackups,
MaxAge: value.MaxAge,
Compress: value.Compress,
}
}
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建行为数据服务
service := NewXingweiService(
cfg.Xingwei.URL,
cfg.Xingwei.ApiID,
cfg.Xingwei.ApiKey,
logger,
)
return service, nil
}
// NewXingweiServiceWithLogging 使用自定义日志配置创建行为数据服务
func NewXingweiServiceWithLogging(url, apiID, apiKey string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*XingweiService, error) {
// 设置服务名称
loggingConfig.ServiceName = "xingwei"
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建行为数据服务
service := NewXingweiService(url, apiID, apiKey, logger)
return service, nil
}

View File

@@ -0,0 +1,296 @@
package xingwei
import (
"bytes"
"context"
"crypto/md5"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"sync/atomic"
"time"
"hyapi-server/internal/shared/external_logger"
)
// 行为数据API状态码常量
const (
CodeSuccess = 200 // 操作成功
CodeSystemError = 500 // 系统内部错误
CodeMerchantError = 3001 // 商家相关报错(商家不存在、商家被禁用、商家余额不足)
CodeAccountExpired = 3002 // 账户已过期
CodeIPWhitelistMissing = 3003 // 未添加ip白名单
CodeUnauthorized = 3004 // 未授权调用该接口
CodeProductIDError = 4001 // 产品id错误
CodeInterfaceDisabled = 4002 // 接口被停用
CodeQueryException = 5001 // 接口查询异常,请联系技术人员
CodeNotFound = 6000 // 未查询到结果
)
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
ErrNotFound = errors.New("未查询到结果")
// 请求ID计数器确保唯一性
requestIDCounter int64
)
// XingweiResponse 行为数据API响应结构
type XingweiResponse struct {
Msg string `json:"msg"`
Code int `json:"code"`
Data interface{} `json:"data"`
}
// XingweiErrorCode 行为数据错误码定义
type XingweiErrorCode struct {
Code int
Message string
}
// 行为数据错误码映射
var XingweiErrorCodes = map[int]XingweiErrorCode{
CodeSuccess: {Code: CodeSuccess, Message: "操作成功"},
CodeSystemError: {Code: CodeSystemError, Message: "系统内部错误"},
CodeMerchantError: {Code: CodeMerchantError, Message: "商家相关报错(商家不存在、商家被禁用、商家余额不足)"},
CodeAccountExpired: {Code: CodeAccountExpired, Message: "账户已过期"},
CodeIPWhitelistMissing: {Code: CodeIPWhitelistMissing, Message: "未添加ip白名单"},
CodeUnauthorized: {Code: CodeUnauthorized, Message: "未授权调用该接口"},
CodeProductIDError: {Code: CodeProductIDError, Message: "产品id错误"},
CodeInterfaceDisabled: {Code: CodeInterfaceDisabled, Message: "接口被停用"},
CodeQueryException: {Code: CodeQueryException, Message: "接口查询异常,请联系技术人员"},
CodeNotFound: {Code: CodeNotFound, Message: "未查询到结果"},
}
// GetXingweiErrorMessage 根据错误码获取错误消息
func GetXingweiErrorMessage(code int) string {
if errorCode, exists := XingweiErrorCodes[code]; exists {
return errorCode.Message
}
return fmt.Sprintf("未知错误码: %d", code)
}
type XingweiConfig struct {
URL string
ApiID string
ApiKey string
}
type XingweiService struct {
config XingweiConfig
logger *external_logger.ExternalServiceLogger
}
// NewXingweiService 是一个构造函数,用于初始化 XingweiService
func NewXingweiService(url, apiID, apiKey string, logger *external_logger.ExternalServiceLogger) *XingweiService {
return &XingweiService{
config: XingweiConfig{
URL: url,
ApiID: apiID,
ApiKey: apiKey,
},
logger: logger,
}
}
// generateRequestID 生成请求ID
func (x *XingweiService) generateRequestID() string {
timestamp := time.Now().UnixNano()
// 使用原子计数器确保唯一性
counter := atomic.AddInt64(&requestIDCounter, 1)
hash := md5.Sum([]byte(fmt.Sprintf("%d_%d_%s", timestamp, counter, x.config.ApiID)))
return fmt.Sprintf("xingwei_%x", hash[:8])
}
// createSign 创建签名使用MD5算法将apiId、timestamp、apiKey字符串拼接生成sign
// 参考Java示例DigestUtils.md5Hex(apiId + timestamp + apiKey)
func (x *XingweiService) createSign(timestamp int64) string {
signStr := x.config.ApiID + strconv.FormatInt(timestamp, 10) + x.config.ApiKey
hash := md5.Sum([]byte(signStr))
return fmt.Sprintf("%x", hash)
}
// CallAPI 调用行为数据的 API
func (x *XingweiService) CallAPI(ctx context.Context, projectID string, params map[string]interface{}) (resp []byte, err error) {
startTime := time.Now()
requestID := x.generateRequestID()
timestamp := time.Now().UnixMilli()
// 从ctx中获取transactionId
var transactionID string
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
transactionID = ctxTransactionID
}
// 记录请求日志
if x.logger != nil {
x.logger.LogRequest(requestID, transactionID, "xingwei_api", x.config.URL)
}
// 将请求参数转换为JSON
jsonData, marshalErr := json.Marshal(params)
if marshalErr != nil {
err = errors.Join(ErrSystem, marshalErr)
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
}
return nil, err
}
// 创建HTTP POST请求
req, newRequestErr := http.NewRequestWithContext(ctx, "POST", x.config.URL, bytes.NewBuffer(jsonData))
if newRequestErr != nil {
err = errors.Join(ErrSystem, newRequestErr)
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
}
return nil, err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("timestamp", strconv.FormatInt(timestamp, 10))
req.Header.Set("sign", x.createSign(timestamp))
req.Header.Set("API-ID", x.config.ApiID)
req.Header.Set("project_id", projectID)
// 创建HTTP客户端超时时间设置为60秒
client := &http.Client{
Timeout: 60 * time.Second,
}
// 发送请求
httpResp, clientDoErr := client.Do(req)
if clientDoErr != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := clientDoErr.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := clientDoErr.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", clientDoErr))
} else {
err = errors.Join(ErrSystem, clientDoErr)
}
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
}
return nil, err
}
defer func(Body io.ReadCloser) {
closeErr := Body.Close()
if closeErr != nil {
// 记录关闭错误
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", errors.Join(ErrSystem, fmt.Errorf("关闭响应体失败: %w", closeErr)), params)
}
}
}(httpResp.Body)
// 计算请求耗时
duration := time.Since(startTime)
// 读取响应体
bodyBytes, ReadErr := io.ReadAll(httpResp.Body)
if ReadErr != nil {
err = errors.Join(ErrSystem, ReadErr)
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
}
return nil, err
}
// 记录响应日志(不记录具体响应数据)
if x.logger != nil {
x.logger.LogResponse(requestID, transactionID, "xingwei_api", httpResp.StatusCode, duration)
}
// 检查HTTP状态码
if httpResp.StatusCode != http.StatusOK {
err = errors.Join(ErrSystem, fmt.Errorf("行为数据请求失败,状态码: %d", httpResp.StatusCode))
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
}
return nil, err
}
// 解析响应结构
var xingweiResp XingweiResponse
if err := json.Unmarshal(bodyBytes, &xingweiResp); err != nil {
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %w", err))
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
}
return nil, err
}
// 检查业务状态码
switch xingweiResp.Code {
case CodeSuccess:
// 成功响应返回data字段
if xingweiResp.Data == nil {
return []byte("{}"), nil
}
// 将data转换为JSON字节
dataBytes, err := json.Marshal(xingweiResp.Data)
if err != nil {
err = errors.Join(ErrSystem, fmt.Errorf("data字段序列化失败: %w", err))
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api", err, params)
}
return nil, err
}
return dataBytes, nil
case CodeNotFound:
// 未查询到结果,返回空数组
if x.logger != nil {
// 这里只记录有响应,不记录具体返回内容
x.logger.LogResponse(requestID, transactionID, "xingwei_api", httpResp.StatusCode, duration)
}
return []byte("[]"), nil
case CodeSystemError:
// 系统内部错误
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
systemErr := fmt.Errorf("行为数据系统错误[%d]: %s", xingweiResp.Code, errorMsg)
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api",
errors.Join(ErrSystem, systemErr), params)
}
return nil, errors.Join(ErrSystem, systemErr)
default:
// 其他业务错误
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
businessErr := fmt.Errorf("行为数据业务错误[%d]: %s", xingweiResp.Code, errorMsg)
if x.logger != nil {
x.logger.LogError(requestID, transactionID, "xingwei_api",
errors.Join(ErrDatasource, businessErr), params)
}
return nil, errors.Join(ErrDatasource, businessErr)
}
}
// GetConfig 获取配置信息
func (x *XingweiService) GetConfig() XingweiConfig {
return x.config
}

View File

@@ -0,0 +1,241 @@
package xingwei
import (
"context"
"encoding/json"
"testing"
)
func TestXingweiService_CreateSign(t *testing.T) {
// 创建测试配置 - 使用nil logger来避免日志问题
service := NewXingweiService(
"https://sjztyh.chengdaoji.cn/dataCenterManageApi/manage/interface/doc/api/handle",
"test_api_id",
"test_api_key",
nil, // 使用nil logger
)
// 测试签名生成
timestamp := int64(1743474772049)
sign := service.createSign(timestamp)
// 验证签名不为空
if sign == "" {
t.Error("签名不能为空")
}
// 验证签名长度MD5应该是32位十六进制字符串
if len(sign) != 32 {
t.Errorf("签名长度应该是32位实际是%d位", len(sign))
}
t.Logf("生成的签名: %s", sign)
}
func TestXingweiService_CallAPI(t *testing.T) {
// 创建测试配置 - 使用nil logger来避免日志问题
service := NewXingweiService(
"https://sjztyh.chengdaoji.cn/dataCenterManageApi/manage/interface/doc/api/handle",
"test_api_id",
"test_api_key",
nil, // 使用nil logger
)
// 创建测试上下文
ctx := context.Background()
// 测试参数
projectID := "test_project_id"
params := map[string]interface{}{
"test_param": "test_value",
}
// 注意这个测试会实际发送HTTP请求所以可能会失败
// 在实际使用中应该使用mock或者测试服务器
resp, err := service.CallAPI(ctx, projectID, params)
// 由于这是真实的外部API调用我们主要测试错误处理
if err != nil {
t.Logf("预期的错误真实API调用: %v", err)
} else {
t.Logf("API调用成功响应长度: %d", len(resp))
}
}
func TestXingweiService_GenerateRequestID(t *testing.T) {
// 创建测试配置 - 使用nil logger来避免日志问题
service := NewXingweiService(
"https://sjztyh.chengdaoji.cn/dataCenterManageApi/manage/interface/doc/api/handle",
"test_api_id",
"test_api_key",
nil, // 使用nil logger
)
// 测试请求ID生成
requestID1 := service.generateRequestID()
requestID2 := service.generateRequestID()
// 验证请求ID不为空
if requestID1 == "" || requestID2 == "" {
t.Error("请求ID不能为空")
}
// 验证请求ID应该以xingwei_开头
if len(requestID1) < 8 || requestID1[:8] != "xingwei_" {
t.Error("请求ID应该以xingwei_开头")
}
// 验证两次生成的请求ID应该不同
if requestID1 == requestID2 {
t.Error("两次生成的请求ID应该不同")
}
t.Logf("请求ID1: %s", requestID1)
t.Logf("请求ID2: %s", requestID2)
}
func TestGetXingweiErrorMessage(t *testing.T) {
// 测试已知错误码(使用常量)
testCases := []struct {
code int
expected string
}{
{CodeSuccess, "操作成功"},
{CodeSystemError, "系统内部错误"},
{CodeMerchantError, "商家相关报错(商家不存在、商家被禁用、商家余额不足)"},
{CodeAccountExpired, "账户已过期"},
{CodeIPWhitelistMissing, "未添加ip白名单"},
{CodeUnauthorized, "未授权调用该接口"},
{CodeProductIDError, "产品id错误"},
{CodeInterfaceDisabled, "接口被停用"},
{CodeQueryException, "接口查询异常,请联系技术人员"},
{CodeNotFound, "未查询到结果"},
{9999, "未知错误码: 9999"}, // 测试未知错误码
}
for _, tc := range testCases {
result := GetXingweiErrorMessage(tc.code)
if result != tc.expected {
t.Errorf("错误码 %d 的消息不正确,期望: %s, 实际: %s", tc.code, tc.expected, result)
}
}
}
func TestXingweiResponseParsing(t *testing.T) {
// 测试响应结构解析
testCases := []struct {
name string
response string
expectedCode int
}{
{
name: "成功响应",
response: `{"msg": "操作成功", "code": 200, "data": {"result": "test"}}`,
expectedCode: CodeSuccess,
},
{
name: "商家错误",
response: `{"msg": "商家相关报错", "code": 3001, "data": null}`,
expectedCode: CodeMerchantError,
},
{
name: "未查询到结果",
response: `{"msg": "未查询到结果", "code": 6000, "data": null}`,
expectedCode: CodeNotFound,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var resp XingweiResponse
err := json.Unmarshal([]byte(tc.response), &resp)
if err != nil {
t.Errorf("解析响应失败: %v", err)
return
}
if resp.Code != tc.expectedCode {
t.Errorf("错误码不匹配,期望: %d, 实际: %d", tc.expectedCode, resp.Code)
}
// 测试错误消息获取
errorMsg := GetXingweiErrorMessage(resp.Code)
if errorMsg == "" {
t.Errorf("无法获取错误码 %d 的消息", resp.Code)
}
t.Logf("响应: %+v, 错误消息: %s", resp, errorMsg)
})
}
}
// TestXingweiErrorHandling 测试错误处理逻辑
func TestXingweiErrorHandling(t *testing.T) {
// 注意:这个测试主要验证常量定义和错误消息,不需要实际的服务实例
// 测试查空错误
t.Run("NotFound错误", func(t *testing.T) {
// 模拟返回查空响应
response := `{"msg": "未查询到结果", "code": 6000, "data": null}`
var xingweiResp XingweiResponse
err := json.Unmarshal([]byte(response), &xingweiResp)
if err != nil {
t.Fatalf("解析响应失败: %v", err)
}
// 验证状态码
if xingweiResp.Code != CodeNotFound {
t.Errorf("期望状态码 %d, 实际 %d", CodeNotFound, xingweiResp.Code)
}
// 验证错误消息
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
if errorMsg != "未查询到结果" {
t.Errorf("期望错误消息 '未查询到结果', 实际 '%s'", errorMsg)
}
t.Logf("查空错误测试通过: 状态码=%d, 消息=%s", xingweiResp.Code, errorMsg)
})
// 测试系统错误
t.Run("SystemError错误", func(t *testing.T) {
response := `{"msg": "系统内部错误", "code": 500, "data": null}`
var xingweiResp XingweiResponse
err := json.Unmarshal([]byte(response), &xingweiResp)
if err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if xingweiResp.Code != CodeSystemError {
t.Errorf("期望状态码 %d, 实际 %d", CodeSystemError, xingweiResp.Code)
}
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
if errorMsg != "系统内部错误" {
t.Errorf("期望错误消息 '系统内部错误', 实际 '%s'", errorMsg)
}
t.Logf("系统错误测试通过: 状态码=%d, 消息=%s", xingweiResp.Code, errorMsg)
})
// 测试成功响应
t.Run("Success响应", func(t *testing.T) {
response := `{"msg": "操作成功", "code": 200, "data": {"result": "test"}}`
var xingweiResp XingweiResponse
err := json.Unmarshal([]byte(response), &xingweiResp)
if err != nil {
t.Fatalf("解析响应失败: %v", err)
}
if xingweiResp.Code != CodeSuccess {
t.Errorf("期望状态码 %d, 实际 %d", CodeSuccess, xingweiResp.Code)
}
errorMsg := GetXingweiErrorMessage(xingweiResp.Code)
if errorMsg != "操作成功" {
t.Errorf("期望错误消息 '操作成功', 实际 '%s'", errorMsg)
}
t.Logf("成功响应测试通过: 状态码=%d, 消息=%s", xingweiResp.Code, errorMsg)
})
}

View File

@@ -0,0 +1,67 @@
package yushan
import (
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewYushanServiceWithConfig 使用配置创建羽山服务
func NewYushanServiceWithConfig(cfg *config.Config) (*YushanService, error) {
// 将配置类型转换为通用外部服务日志配置
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.Yushan.Logging.Enabled,
LogDir: cfg.Yushan.Logging.LogDir,
ServiceName: "yushan",
UseDaily: cfg.Yushan.Logging.UseDaily,
EnableLevelSeparation: cfg.Yushan.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
// 转换级别配置
for key, value := range cfg.Yushan.Logging.LevelConfigs {
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: value.MaxSize,
MaxBackups: value.MaxBackups,
MaxAge: value.MaxAge,
Compress: value.Compress,
}
}
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建羽山服务
service := NewYushanService(
cfg.Yushan.URL,
cfg.Yushan.APIKey,
cfg.Yushan.AcctID,
logger,
)
return service, nil
}
// NewYushanServiceWithLogging 使用自定义日志配置创建羽山服务
func NewYushanServiceWithLogging(url, apiKey, acctID string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*YushanService, error) {
// 设置服务名称
loggingConfig.ServiceName = "yushan"
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建羽山服务
service := NewYushanService(url, apiKey, acctID, logger)
return service, nil
}
// NewYushanServiceSimple 创建简单的羽山服务(无日志)
func NewYushanServiceSimple(url, apiKey, acctID string) *YushanService {
return NewYushanService(url, apiKey, acctID, nil)
}

View File

@@ -0,0 +1,287 @@
package yushan
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"crypto/rand"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
"hyapi-server/internal/shared/external_logger"
"github.com/tidwall/gjson"
)
var (
ErrDatasource = errors.New("数据源异常")
ErrNotFound = errors.New("查询为空")
ErrSystem = errors.New("系统异常")
)
type YushanConfig struct {
URL string
ApiKey string
AcctID string
}
type YushanService struct {
config YushanConfig
logger *external_logger.ExternalServiceLogger
}
// NewYushanService 是一个构造函数,用于初始化 YushanService
func NewYushanService(url, apiKey, acctID string, logger *external_logger.ExternalServiceLogger) *YushanService {
return &YushanService{
config: YushanConfig{
URL: url,
ApiKey: apiKey,
AcctID: acctID,
},
logger: logger,
}
}
// CallAPI 调用羽山数据的 API
func (y *YushanService) CallAPI(ctx context.Context, code string, params map[string]interface{}) (respBytes []byte, err error) {
startTime := time.Now()
requestID := y.generateRequestID()
// 从ctx中获取transactionId
var transactionID string
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
transactionID = ctxTransactionID
}
// 记录请求日志
if y.logger != nil {
y.logger.LogRequest(requestID, transactionID, code, y.config.URL)
}
// 获取当前时间戳
unixMilliseconds := time.Now().UnixNano() / int64(time.Millisecond)
// 生成请求序列号
requestSN, _ := y.GenerateRandomString()
// 构建请求数据
reqData := map[string]interface{}{
"prod_id": code,
"req_time": unixMilliseconds,
"request_sn": requestSN,
"req_data": params,
}
// 将请求数据转换为 JSON 字节数组
messageBytes, err := json.Marshal(reqData)
if err != nil {
err = errors.Join(ErrSystem, err)
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
// 获取 API 密钥
key, err := hex.DecodeString(y.config.ApiKey)
if err != nil {
err = errors.Join(ErrSystem, err)
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
// 使用 AES CBC 加密请求数据
cipherText := y.AES_CBC_Encrypt(messageBytes, key)
// 将加密后的数据编码为 Base64 字符串
content := base64.StdEncoding.EncodeToString(cipherText)
// 发起 HTTP 请求超时时间设置为60秒
client := &http.Client{
Timeout: 60 * time.Second,
}
req, err := http.NewRequestWithContext(ctx, "POST", y.config.URL, strings.NewReader(content))
if err != nil {
err = errors.Join(ErrSystem, err)
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("ACCT_ID", y.config.AcctID)
// 执行请求
resp, err := client.Do(req)
if err != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
isTimeout = true
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
} else {
err = errors.Join(ErrSystem, err)
}
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
defer resp.Body.Close()
// 读取响应体
body, err := io.ReadAll(resp.Body)
if err != nil {
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
var respData []byte
if IsJSON(string(body)) {
respData = body
} else {
sDec, err := base64.StdEncoding.DecodeString(string(body))
if err != nil {
err = errors.Join(ErrSystem, err)
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
respData = y.AES_CBC_Decrypt(sDec, key)
}
retCode := gjson.GetBytes(respData, "retcode").String()
// 记录响应日志(不记录具体响应数据)
if y.logger != nil {
duration := time.Since(startTime)
y.logger.LogResponse(requestID, transactionID, code, resp.StatusCode, duration)
}
if retCode == "100000" {
// retcode 为 100000表示查询为空
return nil, ErrNotFound
} else if retCode == "000000" {
// retcode 为 000000表示有数据返回 retdata
retData := gjson.GetBytes(respData, "retdata")
if !retData.Exists() {
err = errors.Join(ErrDatasource, fmt.Errorf("羽山请求retdata为空"))
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
return []byte(retData.Raw), nil
} else {
err = errors.Join(ErrDatasource, fmt.Errorf("羽山请求未知的状态码"))
if y.logger != nil {
y.logger.LogError(requestID, transactionID, code, err, params)
}
return nil, err
}
}
// generateRequestID 生成请求ID
func (y *YushanService) generateRequestID() string {
timestamp := time.Now().UnixNano()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, y.config.ApiKey)))
return fmt.Sprintf("yushan_%x", hash[:8])
}
// GenerateRandomString 生成一个32位的随机字符串订单号
func (y *YushanService) GenerateRandomString() (string, error) {
// 创建一个16字节的数组
bytes := make([]byte, 16)
// 读取随机字节到数组中
if _, err := rand.Read(bytes); err != nil {
return "", err
}
// 将字节数组编码为16进制字符串
return hex.EncodeToString(bytes), nil
}
// AEC加密CBC模式
func (y *YushanService) AES_CBC_Encrypt(plainText []byte, key []byte) []byte {
//指定加密算法返回一个AES算法的Block接口对象
block, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
//进行填充
plainText = Padding(plainText, block.BlockSize())
//指定初始向量vi,长度和block的块尺寸一致
iv := []byte("0000000000000000")
//指定分组模式返回一个BlockMode接口对象
blockMode := cipher.NewCBCEncrypter(block, iv)
//加密连续数据库
cipherText := make([]byte, len(plainText))
blockMode.CryptBlocks(cipherText, plainText)
//返回base64密文
return cipherText
}
// AEC解密CBC模式
func (y *YushanService) AES_CBC_Decrypt(cipherText []byte, key []byte) []byte {
//指定解密算法返回一个AES算法的Block接口对象
block, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
//指定初始化向量IV,和加密的一致
iv := []byte("0000000000000000")
//指定分组模式返回一个BlockMode接口对象
blockMode := cipher.NewCBCDecrypter(block, iv)
//解密
plainText := make([]byte, len(cipherText))
blockMode.CryptBlocks(plainText, cipherText)
//删除填充
plainText = UnPadding(plainText)
return plainText
} // 对明文进行填充
func Padding(plainText []byte, blockSize int) []byte {
//计算要填充的长度
n := blockSize - len(plainText)%blockSize
//对原来的明文填充n个n
temp := bytes.Repeat([]byte{byte(n)}, n)
plainText = append(plainText, temp...)
return plainText
}
// 对密文删除填充
func UnPadding(cipherText []byte) []byte {
//取出密文最后一个字节end
end := cipherText[len(cipherText)-1]
//删除填充
cipherText = cipherText[:len(cipherText)-int(end)]
return cipherText
}
// 判断字符串是否为 JSON 格式
func IsJSON(s string) bool {
var js interface{}
return json.Unmarshal([]byte(s), &js) == nil
}

View File

@@ -0,0 +1,83 @@
package yushan
import (
"testing"
"time"
)
func TestGenerateRequestID(t *testing.T) {
service := &YushanService{
config: YushanConfig{
ApiKey: "test_api_key_123",
},
}
id1 := service.generateRequestID()
// 等待一小段时间确保时间戳不同
time.Sleep(time.Millisecond)
id2 := service.generateRequestID()
if id1 == "" || id2 == "" {
t.Error("请求ID生成失败")
}
if id1 == id2 {
t.Error("不同时间生成的请求ID应该不同")
}
// 验证ID格式
if len(id1) < 20 { // yushan_ + 8位十六进制 + 其他
t.Errorf("请求ID长度不足实际: %s", id1)
}
}
func TestGenerateRandomString(t *testing.T) {
service := &YushanService{}
str1, err := service.GenerateRandomString()
if err != nil {
t.Fatalf("生成随机字符串失败: %v", err)
}
str2, err := service.GenerateRandomString()
if err != nil {
t.Fatalf("生成随机字符串失败: %v", err)
}
if str1 == "" || str2 == "" {
t.Error("随机字符串为空")
}
if str1 == str2 {
t.Error("两次生成的随机字符串应该不同")
}
// 验证长度16字节 = 32位十六进制字符
if len(str1) != 32 || len(str2) != 32 {
t.Error("随机字符串长度应该是32位")
}
}
func TestIsJSON(t *testing.T) {
testCases := []struct {
input string
expected bool
}{
{"{}", true},
{"[]", true},
{"{\"key\": \"value\"}", true},
{"[1, 2, 3]", true},
{"invalid json", false},
{"", false},
{"{invalid}", false},
}
for _, tc := range testCases {
result := IsJSON(tc.input)
if result != tc.expected {
t.Errorf("输入: %s, 期望: %v, 实际: %v", tc.input, tc.expected, result)
}
}
}

View File

@@ -0,0 +1,121 @@
package zhicha
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/base64"
"encoding/hex"
"fmt"
)
const (
KEY_SIZE = 16 // AES-128, 16 bytes
)
// Encrypt 使用AES-128-CBC加密数据
// 对应Python示例中的encrypt函数
func Encrypt(data, key string) (string, error) {
// 将十六进制密钥转换为字节
binKey, err := hex.DecodeString(key)
if err != nil {
return "", fmt.Errorf("密钥格式错误: %w", err)
}
if len(binKey) < KEY_SIZE {
return "", fmt.Errorf("密钥长度不足,需要至少%d字节", KEY_SIZE)
}
// 从密钥前16个字符生成IV
iv := []byte(key[:KEY_SIZE])
// 创建AES加密器
block, err := aes.NewCipher(binKey)
if err != nil {
return "", fmt.Errorf("创建AES加密器失败: %w", err)
}
// 对数据进行PKCS7填充
paddedData := pkcs7Padding([]byte(data), aes.BlockSize)
// 创建CBC模式加密器
mode := cipher.NewCBCEncrypter(block, iv)
// 加密
ciphertext := make([]byte, len(paddedData))
mode.CryptBlocks(ciphertext, paddedData)
// 返回Base64编码结果
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt 使用AES-128-CBC解密数据
// 对应Python示例中的decrypt函数
func Decrypt(encryptedData, key string) (string, error) {
// 将十六进制密钥转换为字节
binKey, err := hex.DecodeString(key)
if err != nil {
return "", fmt.Errorf("密钥格式错误: %w", err)
}
if len(binKey) < KEY_SIZE {
return "", fmt.Errorf("密钥长度不足,需要至少%d字节", KEY_SIZE)
}
// 从密钥前16个字符生成IV
iv := []byte(key[:KEY_SIZE])
// 解码Base64数据
decodedData, err := base64.StdEncoding.DecodeString(encryptedData)
if err != nil {
return "", fmt.Errorf("Base64解码失败: %w", err)
}
// 检查数据长度是否为AES块大小的倍数
if len(decodedData) == 0 || len(decodedData)%aes.BlockSize != 0 {
return "", fmt.Errorf("加密数据长度无效,必须是%d字节的倍数", aes.BlockSize)
}
// 创建AES解密器
block, err := aes.NewCipher(binKey)
if err != nil {
return "", fmt.Errorf("创建AES解密器失败: %w", err)
}
// 创建CBC模式解密器
mode := cipher.NewCBCDecrypter(block, iv)
// 解密
plaintext := make([]byte, len(decodedData))
mode.CryptBlocks(plaintext, decodedData)
// 移除PKCS7填充
unpadded, err := pkcs7Unpadding(plaintext)
if err != nil {
return "", fmt.Errorf("移除填充失败: %w", err)
}
return string(unpadded), nil
}
// pkcs7Padding 使用PKCS7填充数据
func pkcs7Padding(src []byte, blockSize int) []byte {
padding := blockSize - len(src)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
// pkcs7Unpadding 移除PKCS7填充
func pkcs7Unpadding(src []byte) ([]byte, error) {
length := len(src)
if length == 0 {
return nil, fmt.Errorf("数据为空")
}
unpadding := int(src[length-1])
if unpadding > length {
return nil, fmt.Errorf("填充长度无效")
}
return src[:length-unpadding], nil
}

View File

@@ -0,0 +1,170 @@
package zhicha
import (
"fmt"
)
// ZhichaError 智查金控服务错误
type ZhichaError struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Error 实现error接口
func (e *ZhichaError) Error() string {
return fmt.Sprintf("智查金控错误 [%s]: %s", e.Code, e.Message)
}
// IsSuccess 检查是否成功
func (e *ZhichaError) IsSuccess() bool {
return e.Code == "200"
}
// IsNoRecord 检查是否查询无记录
func (e *ZhichaError) IsNoRecord() bool {
return e.Code == "201"
}
// IsBusinessError 检查是否是业务错误(非系统错误)
func (e *ZhichaError) IsBusinessError() bool {
return e.Code >= "302" && e.Code <= "320"
}
// IsSystemError 检查是否是系统错误
func (e *ZhichaError) IsSystemError() bool {
return e.Code == "500"
}
// IsAuthError 检查是否是认证相关错误
func (e *ZhichaError) IsAuthError() bool {
return e.Code == "304" || e.Code == "318" || e.Code == "319" || e.Code == "320"
}
// IsParamError 检查是否是参数相关错误
func (e *ZhichaError) IsParamError() bool {
return e.Code == "302" || e.Code == "303" || e.Code == "305" || e.Code == "306" || e.Code == "307" || e.Code == "316" || e.Code == "317"
}
// IsServiceError 检查是否是服务相关错误
func (e *ZhichaError) IsServiceError() bool {
return e.Code == "308" || e.Code == "309" || e.Code == "310" || e.Code == "311"
}
// IsUserError 检查是否是用户相关错误
func (e *ZhichaError) IsUserError() bool {
return e.Code == "312" || e.Code == "313" || e.Code == "314" || e.Code == "315"
}
// 预定义错误常量
var (
// 成功状态
ErrSuccess = &ZhichaError{Code: "200", Message: "请求成功"}
ErrNoRecord = &ZhichaError{Code: "201", Message: "查询无记录"}
// 业务参数错误
ErrBusinessParamMissing = &ZhichaError{Code: "302", Message: "业务参数缺失"}
ErrParamError = &ZhichaError{Code: "303", Message: "参数错误"}
ErrHeaderParamMissing = &ZhichaError{Code: "304", Message: "请求头参数缺失"}
ErrNameError = &ZhichaError{Code: "305", Message: "姓名错误"}
ErrPhoneError = &ZhichaError{Code: "306", Message: "手机号错误"}
ErrIDCardError = &ZhichaError{Code: "307", Message: "身份证号错误"}
// 服务相关错误
ErrServiceNotExist = &ZhichaError{Code: "308", Message: "服务不存在"}
ErrServiceNotEnabled = &ZhichaError{Code: "309", Message: "服务未开通"}
ErrInsufficientBalance = &ZhichaError{Code: "310", Message: "余额不足"}
ErrRemoteDataError = &ZhichaError{Code: "311", Message: "调用远程数据异常"}
// 用户相关错误
ErrUserNotExist = &ZhichaError{Code: "312", Message: "用户不存在"}
ErrUserStatusError = &ZhichaError{Code: "313", Message: "用户状态异常"}
ErrUserUnauthorized = &ZhichaError{Code: "314", Message: "用户未授权"}
ErrWhitelistError = &ZhichaError{Code: "315", Message: "白名单错误"}
// 时间戳和认证错误
ErrTimestampInvalid = &ZhichaError{Code: "316", Message: "timestamp不合法"}
ErrTimestampExpired = &ZhichaError{Code: "317", Message: "timestamp已过期"}
ErrSignVerifyFailed = &ZhichaError{Code: "318", Message: "验签失败"}
ErrDecryptFailed = &ZhichaError{Code: "319", Message: "解密失败"}
ErrUnauthorized = &ZhichaError{Code: "320", Message: "未授权"}
// 系统错误
ErrSystemError = &ZhichaError{Code: "500", Message: "系统异常,请联系管理员"}
)
// NewZhichaError 创建新的智查金控错误
func NewZhichaError(code, message string) *ZhichaError {
return &ZhichaError{
Code: code,
Message: message,
}
}
// NewZhichaErrorFromCode 根据状态码创建错误
func NewZhichaErrorFromCode(code string) *ZhichaError {
switch code {
case "200":
return ErrSuccess
case "201":
return ErrNoRecord
case "302":
return ErrBusinessParamMissing
case "303":
return ErrParamError
case "304":
return ErrHeaderParamMissing
case "305":
return ErrNameError
case "306":
return ErrPhoneError
case "307":
return ErrIDCardError
case "308":
return ErrServiceNotExist
case "309":
return ErrServiceNotEnabled
case "310":
return ErrInsufficientBalance
case "311":
return ErrRemoteDataError
case "312":
return ErrUserNotExist
case "313":
return ErrUserStatusError
case "314":
return ErrUserUnauthorized
case "315":
return ErrWhitelistError
case "316":
return ErrTimestampInvalid
case "317":
return ErrTimestampExpired
case "318":
return ErrSignVerifyFailed
case "319":
return ErrDecryptFailed
case "320":
return ErrUnauthorized
case "500":
return ErrSystemError
default:
return &ZhichaError{
Code: code,
Message: "未知错误",
}
}
}
// IsZhichaError 检查是否是智查金控错误
func IsZhichaError(err error) bool {
_, ok := err.(*ZhichaError)
return ok
}
// GetZhichaError 获取智查金控错误
func GetZhichaError(err error) *ZhichaError {
if zhichaErr, ok := err.(*ZhichaError); ok {
return zhichaErr
}
return nil
}

View File

@@ -0,0 +1,68 @@
package zhicha
import (
"hyapi-server/internal/config"
"hyapi-server/internal/shared/external_logger"
)
// NewZhichaServiceWithConfig 使用配置创建智查金控服务
func NewZhichaServiceWithConfig(cfg *config.Config) (*ZhichaService, error) {
// 将配置类型转换为通用外部服务日志配置
loggingConfig := external_logger.ExternalServiceLoggingConfig{
Enabled: cfg.Zhicha.Logging.Enabled,
LogDir: cfg.Zhicha.Logging.LogDir,
ServiceName: "zhicha",
UseDaily: cfg.Zhicha.Logging.UseDaily,
EnableLevelSeparation: cfg.Zhicha.Logging.EnableLevelSeparation,
LevelConfigs: make(map[string]external_logger.ExternalServiceLevelFileConfig),
}
// 转换级别配置
for key, value := range cfg.Zhicha.Logging.LevelConfigs {
loggingConfig.LevelConfigs[key] = external_logger.ExternalServiceLevelFileConfig{
MaxSize: value.MaxSize,
MaxBackups: value.MaxBackups,
MaxAge: value.MaxAge,
Compress: value.Compress,
}
}
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建智查金控服务
service := NewZhichaService(
cfg.Zhicha.URL,
cfg.Zhicha.AppID,
cfg.Zhicha.AppSecret,
cfg.Zhicha.EncryptKey,
logger,
)
return service, nil
}
// NewZhichaServiceWithLogging 使用自定义日志配置创建智查金控服务
func NewZhichaServiceWithLogging(url, appID, appSecret, encryptKey string, loggingConfig external_logger.ExternalServiceLoggingConfig) (*ZhichaService, error) {
// 设置服务名称
loggingConfig.ServiceName = "zhicha"
// 创建通用外部服务日志器
logger, err := external_logger.NewExternalServiceLogger(loggingConfig)
if err != nil {
return nil, err
}
// 创建智查金控服务
service := NewZhichaService(url, appID, appSecret, encryptKey, logger)
return service, nil
}
// NewZhichaServiceSimple 创建简单的智查金控服务(无日志)
func NewZhichaServiceSimple(url, appID, appSecret, encryptKey string) *ZhichaService {
return NewZhichaService(url, appID, appSecret, encryptKey, nil)
}

View File

@@ -0,0 +1,338 @@
package zhicha
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/md5"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"time"
"hyapi-server/internal/shared/external_logger"
)
var (
ErrDatasource = errors.New("数据源异常")
ErrSystem = errors.New("系统异常")
)
// contextKey 用于在 context 中存储不跳过 201 错误检查的标志
type contextKey string
const dontSkipCode201CheckKey contextKey = "dont_skip_code_201_check"
type ZhichaResp struct {
Code string `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data"`
Success bool `json:"success"`
}
type ZhichaConfig struct {
URL string
AppID string
AppSecret string
EncryptKey string
}
type ZhichaService struct {
config ZhichaConfig
logger *external_logger.ExternalServiceLogger
}
// NewZhichaService 是一个构造函数,用于初始化 ZhichaService
func NewZhichaService(url, appID, appSecret, encryptKey string, logger *external_logger.ExternalServiceLogger) *ZhichaService {
return &ZhichaService{
config: ZhichaConfig{
URL: url,
AppID: appID,
AppSecret: appSecret,
EncryptKey: encryptKey,
},
logger: logger,
}
}
// generateRequestID 生成请求ID
func (z *ZhichaService) generateRequestID() string {
timestamp := time.Now().UnixNano()
hash := md5.Sum([]byte(fmt.Sprintf("%d_%s", timestamp, z.config.AppID)))
return fmt.Sprintf("zhicha_%x", hash[:8])
}
// generateSign 生成签名
func (z *ZhichaService) generateSign(timestamp int64) string {
// 第一步对app_secret进行MD5加密
encryptedSecret := fmt.Sprintf("%x", md5.Sum([]byte(z.config.AppSecret)))
// 第二步将加密后的密钥和时间戳拼接再次MD5加密
signStr := encryptedSecret + strconv.FormatInt(timestamp, 10)
sign := fmt.Sprintf("%x", md5.Sum([]byte(signStr)))
return sign
}
// CallAPI 调用智查金控的 API
func (z *ZhichaService) CallAPI(ctx context.Context, proID string, params map[string]interface{}) (data interface{}, err error) {
startTime := time.Now()
requestID := z.generateRequestID()
timestamp := time.Now().Unix()
// 从ctx中获取transactionId
var transactionID string
if ctxTransactionID, ok := ctx.Value("transaction_id").(string); ok {
transactionID = ctxTransactionID
}
// 记录请求日志
if z.logger != nil {
z.logger.LogRequest(requestID, transactionID, proID, z.config.URL)
}
jsonData, marshalErr := json.Marshal(params)
if marshalErr != nil {
err = errors.Join(ErrSystem, marshalErr)
if z.logger != nil {
z.logger.LogError(requestID, transactionID, proID, err, params)
}
return nil, err
}
// 创建HTTP POST请求
req, err := http.NewRequestWithContext(ctx, "POST", z.config.URL, bytes.NewBuffer(jsonData))
if err != nil {
err = errors.Join(ErrSystem, err)
if z.logger != nil {
z.logger.LogError(requestID, transactionID, proID, err, params)
}
return nil, err
}
// 设置请求头
req.Header.Set("Content-Type", "application/json")
req.Header.Set("appId", z.config.AppID)
req.Header.Set("proId", proID)
req.Header.Set("timestamp", strconv.FormatInt(timestamp, 10))
req.Header.Set("sign", z.generateSign(timestamp))
// 创建HTTP客户端超时时间设置为60秒
client := &http.Client{
Timeout: 60 * time.Second,
}
// 发送请求
response, err := client.Do(req)
if err != nil {
// 检查是否是超时错误
isTimeout := false
if ctx.Err() == context.DeadlineExceeded {
isTimeout = true
} else if netErr, ok := err.(interface{ Timeout() bool }); ok && netErr.Timeout() {
// 检查是否是网络超时错误
isTimeout = true
} else if errStr := err.Error(); errStr == "context deadline exceeded" ||
errStr == "timeout" ||
errStr == "Client.Timeout exceeded" ||
errStr == "net/http: request canceled" {
isTimeout = true
}
if isTimeout {
// 超时错误应该返回数据源异常,而不是系统异常
err = errors.Join(ErrDatasource, fmt.Errorf("API请求超时: %v", err))
} else {
err = errors.Join(ErrSystem, err)
}
if z.logger != nil {
z.logger.LogError(requestID, transactionID, proID, err, params)
}
return nil, err
}
defer response.Body.Close()
// 读取响应
respBody, err := io.ReadAll(response.Body)
if err != nil {
err = errors.Join(ErrSystem, err)
if z.logger != nil {
z.logger.LogError(requestID, transactionID, proID, err, params)
}
return nil, err
}
// 记录响应日志(不记录具体响应数据)
if z.logger != nil {
duration := time.Since(startTime)
z.logger.LogResponse(requestID, transactionID, proID, response.StatusCode, duration)
}
// 检查HTTP状态码
if response.StatusCode != http.StatusOK {
err = errors.Join(ErrDatasource, fmt.Errorf("HTTP状态码 %d", response.StatusCode))
if z.logger != nil {
z.logger.LogError(requestID, transactionID, proID, err, params)
}
return nil, err
}
// 解析响应
var zhichaResp ZhichaResp
if err := json.Unmarshal(respBody, &zhichaResp); err != nil {
err = errors.Join(ErrSystem, fmt.Errorf("响应解析失败: %s", err.Error()))
if z.logger != nil {
z.logger.LogError(requestID, transactionID, proID, err, params)
}
return nil, err
}
// 检查业务状态码
if zhichaResp.Code != "200" && zhichaResp.Code != "201" {
// 创建智查金控错误用于日志记录
zhichaErr := NewZhichaErrorFromCode(zhichaResp.Code)
if zhichaErr.Code == "未知错误" {
zhichaErr.Message = zhichaResp.Message
}
// 记录智查金控的详细错误信息到日志
if z.logger != nil {
z.logger.LogError(requestID, transactionID, proID, zhichaErr, params)
}
// 对外统一返回数据源异常错误
return nil, ErrDatasource
}
// 201 表示查询为空兼容其它情况如果data也为空则返回空对象
if zhichaResp.Code == "201" {
// 先做类型断言
dataMap, ok := zhichaResp.Data.(map[string]interface{})
if ok && len(dataMap) > 0 {
return dataMap, nil
}
return map[string]interface{}{}, nil
}
// 返回data字段
return zhichaResp.Data, nil
}
// Encrypt 使用配置的加密密钥对数据进行AES-128-CBC加密
func (z *ZhichaService) Encrypt(data string) (string, error) {
if z.config.EncryptKey == "" {
return "", fmt.Errorf("加密密钥未配置")
}
// 将十六进制密钥转换为字节
binKey, err := hex.DecodeString(z.config.EncryptKey)
if err != nil {
return "", fmt.Errorf("密钥格式错误: %w", err)
}
if len(binKey) < 16 { // AES-128, 16 bytes
return "", fmt.Errorf("密钥长度不足需要至少16字节")
}
// 从密钥前16个字符生成IV
iv := []byte(z.config.EncryptKey[:16])
// 创建AES加密器
block, err := aes.NewCipher(binKey)
if err != nil {
return "", fmt.Errorf("创建AES加密器失败: %w", err)
}
// 对数据进行PKCS7填充
paddedData := z.pkcs7Padding([]byte(data), aes.BlockSize)
// 创建CBC模式加密器
mode := cipher.NewCBCEncrypter(block, iv)
// 加密
ciphertext := make([]byte, len(paddedData))
mode.CryptBlocks(ciphertext, paddedData)
// 返回Base64编码结果
return base64.StdEncoding.EncodeToString(ciphertext), nil
}
// Decrypt 使用配置的加密密钥对数据进行AES-128-CBC解密
func (z *ZhichaService) Decrypt(encryptedData string) (string, error) {
if z.config.EncryptKey == "" {
return "", fmt.Errorf("加密密钥未配置")
}
// 将十六进制密钥转换为字节
binKey, err := hex.DecodeString(z.config.EncryptKey)
if err != nil {
return "", fmt.Errorf("密钥格式错误: %w", err)
}
if len(binKey) < 16 { // AES-128, 16 bytes
return "", fmt.Errorf("密钥长度不足需要至少16字节")
}
// 从密钥前16个字符生成IV
iv := []byte(z.config.EncryptKey[:16])
// 解码Base64数据
decodedData, err := base64.StdEncoding.DecodeString(encryptedData)
if err != nil {
return "", fmt.Errorf("Base64解码失败: %w", err)
}
// 检查数据长度是否为AES块大小的倍数
if len(decodedData) == 0 || len(decodedData)%aes.BlockSize != 0 {
return "", fmt.Errorf("加密数据长度无效,必须是%d字节的倍数", aes.BlockSize)
}
// 创建AES解密器
block, err := aes.NewCipher(binKey)
if err != nil {
return "", fmt.Errorf("创建AES解密器失败: %w", err)
}
// 创建CBC模式解密器
mode := cipher.NewCBCDecrypter(block, iv)
// 解密
plaintext := make([]byte, len(decodedData))
mode.CryptBlocks(plaintext, decodedData)
// 移除PKCS7填充
unpadded, err := z.pkcs7Unpadding(plaintext)
if err != nil {
return "", fmt.Errorf("移除填充失败: %w", err)
}
return string(unpadded), nil
}
// pkcs7Padding 使用PKCS7填充数据
func (z *ZhichaService) pkcs7Padding(src []byte, blockSize int) []byte {
padding := blockSize - len(src)%blockSize
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(src, padtext...)
}
// pkcs7Unpadding 移除PKCS7填充
func (z *ZhichaService) pkcs7Unpadding(src []byte) ([]byte, error) {
length := len(src)
if length == 0 {
return nil, fmt.Errorf("数据为空")
}
unpadding := int(src[length-1])
if unpadding > length {
return nil, fmt.Errorf("填充长度无效")
}
return src[:length-unpadding], nil
}

View File

@@ -0,0 +1,703 @@
package zhicha
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
"time"
)
func TestGenerateSign(t *testing.T) {
service := &ZhichaService{
config: ZhichaConfig{
AppSecret: "test_secret_123",
},
}
timestamp := int64(1640995200) // 2022-01-01 00:00:00
sign := service.generateSign(timestamp)
if sign == "" {
t.Error("签名生成失败,签名为空")
}
// 验证签名长度MD5是32位十六进制
if len(sign) != 32 {
t.Errorf("签名长度错误期望32位实际%d位", len(sign))
}
// 验证相同参数生成相同签名
sign2 := service.generateSign(timestamp)
if sign != sign2 {
t.Error("相同参数生成的签名不一致")
}
}
func TestEncryptDecrypt(t *testing.T) {
// 测试密钥32位十六进制
key := "1234567890abcdef1234567890abcdef"
// 测试数据
testData := "这是一个测试数据包含中文和English"
// 加密
encrypted, err := Encrypt(testData, key)
if err != nil {
t.Fatalf("加密失败: %v", err)
}
if encrypted == "" {
t.Error("加密结果为空")
}
// 解密
decrypted, err := Decrypt(encrypted, key)
if err != nil {
t.Fatalf("解密失败: %v", err)
}
if decrypted != testData {
t.Errorf("解密结果不匹配,期望: %s, 实际: %s", testData, decrypted)
}
}
func TestEncryptWithInvalidKey(t *testing.T) {
// 测试无效密钥
invalidKeys := []string{
"", // 空密钥
"123", // 太短
"invalid_key_string", // 非十六进制
"1234567890abcdef", // 16位不足32位
}
testData := "test data"
for _, key := range invalidKeys {
_, err := Encrypt(testData, key)
if err == nil {
t.Errorf("使用无效密钥 %s 应该返回错误", key)
}
}
}
func TestDecryptWithInvalidData(t *testing.T) {
key := "af4ca0098e6a202a5c08c413ebd9fd62"
// 测试无效的加密数据
invalidData := []string{
"", // 空数据
"invalid_base64", // 无效的Base64
"dGVzdA==", // 有效的Base64但不是AES加密数据
"i96w+SDjwENjuvsokMFbLw==",
"oaihmICgEcszWMk0gXoB12E/ygF4g78x0/sC3/KHnBk=",
"5bx+WvXvdNRVVOp9UuNFHg==",
}
for _, data := range invalidData {
decrypted, err := Decrypt(data, key)
if err == nil {
t.Errorf("使用无效数据 %s 应该返回错误", data)
}
fmt.Println("data: ", data)
fmt.Println("decrypted: ", decrypted)
}
}
func TestPKCS7Padding(t *testing.T) {
testCases := []struct {
input string
blockSize int
expected int
}{
{"", 16, 16},
{"a", 16, 16},
{"ab", 16, 16},
{"abc", 16, 16},
{"abcd", 16, 16},
{"abcde", 16, 16},
{"abcdef", 16, 16},
{"abcdefg", 16, 16},
{"abcdefgh", 16, 16},
{"abcdefghi", 16, 16},
{"abcdefghij", 16, 16},
{"abcdefghijk", 16, 16},
{"abcdefghijkl", 16, 16},
{"abcdefghijklm", 16, 16},
{"abcdefghijklmn", 16, 16},
{"abcdefghijklmno", 16, 16},
{"abcdefghijklmnop", 16, 16},
}
for _, tc := range testCases {
padded := pkcs7Padding([]byte(tc.input), tc.blockSize)
if len(padded)%tc.blockSize != 0 {
t.Errorf("输入: %s, 期望块大小倍数,实际: %d", tc.input, len(padded))
}
// 测试移除填充
unpadded, err := pkcs7Unpadding(padded)
if err != nil {
t.Errorf("移除填充失败: %v", err)
}
if string(unpadded) != tc.input {
t.Errorf("输入: %s, 期望: %s, 实际: %s", tc.input, tc.input, string(unpadded))
}
}
}
func TestGenerateRequestID(t *testing.T) {
service := &ZhichaService{
config: ZhichaConfig{
AppID: "test_app_id",
},
}
id1 := service.generateRequestID()
// 等待一小段时间确保时间戳不同
time.Sleep(time.Millisecond)
id2 := service.generateRequestID()
if id1 == "" || id2 == "" {
t.Error("请求ID生成失败")
}
if id1 == id2 {
t.Error("不同时间生成的请求ID应该不同")
}
// 验证ID格式
if len(id1) < 20 { // zhicha_ + 8位十六进制 + 其他
t.Errorf("请求ID长度不足实际: %s", id1)
}
}
func TestCallAPISuccess(t *testing.T) {
// 创建测试服务
service := &ZhichaService{
config: ZhichaConfig{
URL: "http://proxy.haiyudata.com/dataMiddle/api/handle",
AppID: "4b78fff61ab8426f",
AppSecret: "1128f01b94124ae899c2e9f2b1f37681",
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
},
logger: nil, // 测试时不使用日志
}
// 测试参数
idCardEncrypted, err := service.Encrypt("45212220000827423X")
if err != nil {
t.Fatalf("加密身份证号失败: %v", err)
}
nameEncrypted, err := service.Encrypt("张荣宏")
if err != nil {
t.Fatalf("加密姓名失败: %v", err)
}
params := map[string]interface{}{
"idCard": idCardEncrypted,
"name": nameEncrypted,
"authorized": "1",
}
// 创建带超时的context
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 调用API
data, err := service.CallAPI(ctx, "ZCI001", params)
// 注意这是真实API调用可能会因为网络、认证等原因失败
// 我们主要测试方法调用是否正常不强制要求API返回成功
if err != nil {
// 如果是网络错误或认证错误,这是正常的
t.Logf("API调用返回错误: %v", err)
return
}
// 如果成功,验证响应
if data == nil {
t.Error("响应数据为空")
return
}
// 将data转换为字符串进行显示
var dataStr string
if str, ok := data.(string); ok {
dataStr = str
} else {
// 如果不是字符串尝试JSON序列化
if dataBytes, err := json.Marshal(data); err == nil {
dataStr = string(dataBytes)
} else {
dataStr = fmt.Sprintf("%v", data)
}
}
t.Logf("API调用成功响应内容: %s", dataStr)
}
func TestCallAPIWithInvalidURL(t *testing.T) {
// 创建使用无效URL的服务
service := &ZhichaService{
config: ZhichaConfig{
URL: "https://invalid-url-that-does-not-exist.com/api",
AppID: "test_app_id",
AppSecret: "test_app_secret",
EncryptKey: "test_encrypt_key",
},
logger: nil,
}
params := map[string]interface{}{
"test": "data",
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
// 应该返回错误
_, err := service.CallAPI(ctx, "test_pro_id", params)
if err == nil {
t.Error("使用无效URL应该返回错误")
}
t.Logf("预期的错误: %v", err)
}
func TestCallAPIWithContextCancellation(t *testing.T) {
service := &ZhichaService{
config: ZhichaConfig{
URL: "https://www.zhichajinkong.com/dataMiddle/api/handle",
AppID: "4b78fff61ab8426f",
AppSecret: "1128f01b94124ae899c2e9f2b1f37681",
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
},
logger: nil,
}
params := map[string]interface{}{
"test": "data",
}
// 创建可取消的context
ctx, cancel := context.WithCancel(context.Background())
// 立即取消
cancel()
// 应该返回context取消错误
_, err := service.CallAPI(ctx, "test_pro_id", params)
if err == nil {
t.Error("context取消后应该返回错误")
}
// 检查是否是context取消错误
if err != context.Canceled && !strings.Contains(err.Error(), "context") {
t.Errorf("期望context相关错误实际: %v", err)
}
t.Logf("Context取消错误: %v", err)
}
func TestCallAPIWithTimeout(t *testing.T) {
service := &ZhichaService{
config: ZhichaConfig{
URL: "https://www.zhichajinkong.com/dataMiddle/api/handle",
AppID: "4b78fff61ab8426f",
AppSecret: "1128f01b94124ae899c2e9f2b1f37681",
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
},
logger: nil,
}
params := map[string]interface{}{
"test": "data",
}
// 创建很短的超时
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
// 应该因为超时而失败
_, err := service.CallAPI(ctx, "test_pro_id", params)
if err == nil {
t.Error("超时后应该返回错误")
}
// 检查是否是超时错误
if err != context.DeadlineExceeded && !strings.Contains(err.Error(), "timeout") && !strings.Contains(err.Error(), "deadline") {
t.Errorf("期望超时相关错误,实际: %v", err)
}
t.Logf("超时错误: %v", err)
}
func TestCallAPIRequestHeaders(t *testing.T) {
// 这个测试验证请求头是否正确设置
// 由于我们不能直接访问HTTP请求我们通过日志或其他方式来验证
service := &ZhichaService{
config: ZhichaConfig{
URL: "https://www.zhichajinkong.com/dataMiddle/api/handle",
AppID: "test_app_id",
AppSecret: "test_app_secret",
EncryptKey: "test_encrypt_key",
},
logger: nil,
}
params := map[string]interface{}{
"test": "headers",
}
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 调用API可能会失败但我们主要测试请求头设置
_, err := service.CallAPI(ctx, "test_pro_id", params)
// 验证签名生成是否正确
timestamp := time.Now().Unix()
sign := service.generateSign(timestamp)
if sign == "" {
t.Error("签名生成失败")
}
if len(sign) != 32 {
t.Errorf("签名长度错误期望32位实际%d位", len(sign))
}
t.Logf("签名生成成功: %s", sign)
t.Logf("API调用结果: %v", err)
}
func TestZhichaErrorHandling(t *testing.T) {
// 测试核心错误类型
testCases := []struct {
name string
code string
message string
expectedErr *ZhichaError
}{
{
name: "成功状态",
code: "200",
message: "请求成功",
expectedErr: ErrSuccess,
},
{
name: "查询无记录",
code: "201",
message: "查询无记录",
expectedErr: ErrNoRecord,
},
{
name: "手机号错误",
code: "306",
message: "手机号错误",
expectedErr: ErrPhoneError,
},
{
name: "姓名错误",
code: "305",
message: "姓名错误",
expectedErr: ErrNameError,
},
{
name: "身份证号错误",
code: "307",
message: "身份证号错误",
expectedErr: ErrIDCardError,
},
{
name: "余额不足",
code: "310",
message: "余额不足",
expectedErr: ErrInsufficientBalance,
},
{
name: "用户不存在",
code: "312",
message: "用户不存在",
expectedErr: ErrUserNotExist,
},
{
name: "系统异常",
code: "500",
message: "系统异常,请联系管理员",
expectedErr: ErrSystemError,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 测试从状态码创建错误
err := NewZhichaErrorFromCode(tc.code)
if err.Code != tc.expectedErr.Code {
t.Errorf("期望错误码 %s实际 %s", tc.expectedErr.Code, err.Code)
}
if err.Message != tc.expectedErr.Message {
t.Errorf("期望错误消息 %s实际 %s", tc.expectedErr.Message, err.Message)
}
})
}
}
func TestZhichaErrorHelpers(t *testing.T) {
// 测试错误类型判断函数
err := NewZhichaError("302", "业务参数缺失")
// 测试IsZhichaError
if !IsZhichaError(err) {
t.Error("IsZhichaError应该返回true")
}
// 测试GetZhichaError
zhichaErr := GetZhichaError(err)
if zhichaErr == nil {
t.Error("GetZhichaError应该返回非nil值")
}
if zhichaErr.Code != "302" {
t.Errorf("期望错误码302实际%s", zhichaErr.Code)
}
// 测试普通错误
normalErr := fmt.Errorf("普通错误")
if IsZhichaError(normalErr) {
t.Error("普通错误不应该被识别为智查金控错误")
}
if GetZhichaError(normalErr) != nil {
t.Error("普通错误的GetZhichaError应该返回nil")
}
}
func TestZhichaErrorString(t *testing.T) {
// 测试错误字符串格式
err := NewZhichaError("304", "请求头参数缺失")
expectedStr := "智查金控错误 [304]: 请求头参数缺失"
if err.Error() != expectedStr {
t.Errorf("期望错误字符串 %s实际 %s", expectedStr, err.Error())
}
}
func TestErrorsIsFunctionality(t *testing.T) {
// 测试 errors.Is() 功能是否正常工作
// 创建各种错误
testCases := []struct {
name string
err error
expected error
shouldMatch bool
}{
{
name: "手机号错误匹配",
err: ErrPhoneError,
expected: ErrPhoneError,
shouldMatch: true,
},
{
name: "姓名错误匹配",
err: ErrNameError,
expected: ErrNameError,
shouldMatch: true,
},
{
name: "身份证号错误匹配",
err: ErrIDCardError,
expected: ErrIDCardError,
shouldMatch: true,
},
{
name: "余额不足错误匹配",
err: ErrInsufficientBalance,
expected: ErrInsufficientBalance,
shouldMatch: true,
},
{
name: "用户不存在错误匹配",
err: ErrUserNotExist,
expected: ErrUserNotExist,
shouldMatch: true,
},
{
name: "系统错误匹配",
err: ErrSystemError,
expected: ErrSystemError,
shouldMatch: true,
},
{
name: "不同错误不匹配",
err: ErrPhoneError,
expected: ErrNameError,
shouldMatch: false,
},
{
name: "手机号错误与身份证号错误不匹配",
err: ErrPhoneError,
expected: ErrIDCardError,
shouldMatch: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// 使用 errors.Is() 进行判断
if errors.Is(tc.err, tc.expected) != tc.shouldMatch {
if tc.shouldMatch {
t.Errorf("期望 errors.Is(%v, %v) 返回 true", tc.err, tc.expected)
} else {
t.Errorf("期望 errors.Is(%v, %v) 返回 false", tc.err, tc.expected)
}
}
})
}
}
func TestErrorsIsInSwitch(t *testing.T) {
// 测试在 switch 语句中使用 errors.Is()
// 模拟API调用返回手机号错误
err := ErrPhoneError
// 使用 switch 语句进行错误判断
var result string
switch {
case errors.Is(err, ErrSuccess):
result = "请求成功"
case errors.Is(err, ErrNoRecord):
result = "查询无记录"
case errors.Is(err, ErrPhoneError):
result = "手机号格式错误"
case errors.Is(err, ErrNameError):
result = "姓名格式错误"
case errors.Is(err, ErrIDCardError):
result = "身份证号格式错误"
case errors.Is(err, ErrHeaderParamMissing):
result = "请求头参数缺失"
case errors.Is(err, ErrInsufficientBalance):
result = "余额不足"
case errors.Is(err, ErrUserNotExist):
result = "用户不存在"
case errors.Is(err, ErrUserUnauthorized):
result = "用户未授权"
case errors.Is(err, ErrSystemError):
result = "系统异常"
default:
result = "未知错误"
}
// 验证结果
expected := "手机号格式错误"
if result != expected {
t.Errorf("期望结果 %s实际 %s", expected, result)
}
t.Logf("Switch语句错误判断结果: %s", result)
}
func TestServiceEncryptDecrypt(t *testing.T) {
// 创建测试服务
service := &ZhichaService{
config: ZhichaConfig{
URL: "https://test.com",
AppID: "test_app_id",
AppSecret: "test_app_secret",
EncryptKey: "af4ca0098e6a202a5c08c413ebd9fd62",
},
logger: nil,
}
// 测试数据
testData := "Hello, 智查金控!"
// 测试加密
encrypted, err := service.Encrypt(testData)
if err != nil {
t.Fatalf("加密失败: %v", err)
}
if encrypted == "" {
t.Error("加密结果为空")
}
if encrypted == testData {
t.Error("加密结果与原文相同")
}
t.Logf("原文: %s", testData)
t.Logf("加密后: %s", encrypted)
// 测试解密
decrypted, err := service.Decrypt(encrypted)
if err != nil {
t.Fatalf("解密失败: %v", err)
}
if decrypted != testData {
t.Errorf("解密结果不匹配,期望: %s实际: %s", testData, decrypted)
}
t.Logf("解密后: %s", decrypted)
}
func TestEncryptWithoutKey(t *testing.T) {
// 创建没有加密密钥的服务
service := &ZhichaService{
config: ZhichaConfig{
URL: "https://test.com",
AppID: "test_app_id",
AppSecret: "test_app_secret",
// 没有设置 EncryptKey
},
logger: nil,
}
// 应该返回错误
_, err := service.Encrypt("test data")
if err == nil {
t.Error("没有加密密钥时应该返回错误")
}
if !strings.Contains(err.Error(), "加密密钥未配置") {
t.Errorf("期望错误包含'加密密钥未配置',实际: %v", err)
}
t.Logf("预期的错误: %v", err)
}
func TestDecryptWithoutKey(t *testing.T) {
// 创建没有加密密钥的服务
service := &ZhichaService{
config: ZhichaConfig{
URL: "https://test.com",
AppID: "test_app_id",
AppSecret: "test_app_secret",
// 没有设置 EncryptKey
},
logger: nil,
}
// 应该返回错误
_, err := service.Decrypt("test encrypted data")
if err == nil {
t.Error("没有加密密钥时应该返回错误")
}
if !strings.Contains(err.Error(), "加密密钥未配置") {
t.Errorf("期望错误包含'加密密钥未配置',实际: %v", err)
}
t.Logf("预期的错误: %v", err)
}

View File

@@ -0,0 +1,168 @@
package handlers
import (
"strconv"
"strings"
"time"
securityEntities "hyapi-server/internal/domains/security/entities"
"hyapi-server/internal/shared/interfaces"
"hyapi-server/internal/shared/ipgeo"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"gorm.io/gorm"
)
// AdminSecurityHandler 管理员安全数据处理器
type AdminSecurityHandler struct {
db *gorm.DB
responseBuilder interfaces.ResponseBuilder
logger *zap.Logger
ipLocator *ipgeo.Locator
}
func NewAdminSecurityHandler(
db *gorm.DB,
responseBuilder interfaces.ResponseBuilder,
logger *zap.Logger,
ipLocator *ipgeo.Locator,
) *AdminSecurityHandler {
return &AdminSecurityHandler{
db: db,
responseBuilder: responseBuilder,
logger: logger,
ipLocator: ipLocator,
}
}
func (h *AdminSecurityHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
if value := c.Query(key); value != "" {
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
return intValue
}
}
return defaultValue
}
func (h *AdminSecurityHandler) parseRange(c *gin.Context) (time.Time, time.Time, bool) {
startTime := time.Now().Add(-24 * time.Hour)
endTime := time.Now()
if start := strings.TrimSpace(c.Query("start_time")); start != "" {
t, err := time.Parse("2006-01-02 15:04:05", start)
if err != nil {
h.responseBuilder.BadRequest(c, "start_time格式错误示例2026-03-19 10:00:00")
return time.Time{}, time.Time{}, false
}
startTime = t
}
if end := strings.TrimSpace(c.Query("end_time")); end != "" {
t, err := time.Parse("2006-01-02 15:04:05", end)
if err != nil {
h.responseBuilder.BadRequest(c, "end_time格式错误示例2026-03-19 12:00:00")
return time.Time{}, time.Time{}, false
}
endTime = t
}
return startTime, endTime, true
}
// ListSuspiciousIPs 获取可疑IP列表
func (h *AdminSecurityHandler) ListSuspiciousIPs(c *gin.Context) {
page := h.getIntQuery(c, "page", 1)
pageSize := h.getIntQuery(c, "page_size", 20)
if pageSize > 100 {
pageSize = 100
}
startTime, endTime, ok := h.parseRange(c)
if !ok {
return
}
ip := strings.TrimSpace(c.Query("ip"))
path := strings.TrimSpace(c.Query("path"))
query := h.db.Model(&securityEntities.SuspiciousIPRecord{}).
Where("created_at >= ? AND created_at <= ?", startTime, endTime)
if ip != "" {
query = query.Where("ip = ?", ip)
}
if path != "" {
query = query.Where("path LIKE ?", "%"+path+"%")
}
var total int64
if err := query.Count(&total).Error; err != nil {
h.logger.Error("查询可疑IP总数失败", zap.Error(err))
h.responseBuilder.InternalError(c, "查询失败")
return
}
var items []securityEntities.SuspiciousIPRecord
if err := query.Order("created_at DESC").Offset((page - 1) * pageSize).Limit(pageSize).Find(&items).Error; err != nil {
h.logger.Error("查询可疑IP列表失败", zap.Error(err))
h.responseBuilder.InternalError(c, "查询失败")
return
}
h.responseBuilder.Success(c, gin.H{
"items": items,
"total": total,
}, "获取成功")
}
type geoStreamRow struct {
IP string `json:"ip"`
Path string `json:"path"`
Count int `json:"count"`
}
// GetSuspiciousIPGeoStream 获取地球请求流数据
func (h *AdminSecurityHandler) GetSuspiciousIPGeoStream(c *gin.Context) {
startTime, endTime, ok := h.parseRange(c)
if !ok {
return
}
topN := h.getIntQuery(c, "top_n", 200)
if topN > 1000 {
topN = 1000
}
var rows []geoStreamRow
err := h.db.Model(&securityEntities.SuspiciousIPRecord{}).
Select("ip, path, COUNT(1) as count").
Where("created_at >= ? AND created_at <= ?", startTime, endTime).
Group("ip, path").
Order("count DESC").
Limit(topN).
Scan(&rows).Error
if err != nil {
h.logger.Error("查询地球请求流失败", zap.Error(err))
h.responseBuilder.InternalError(c, "查询失败")
return
}
// 目标固定服务器点位(上海)
const serverName = "HYAPI-Server"
const serverLng = 121.4737
const serverLat = 31.2304
result := make([]gin.H, 0, len(rows))
for _, row := range rows {
record := securityEntities.SuspiciousIPRecord{IP: row.IP}
fromName, fromLng, fromLat := h.ipLocator.ToGeoPoint(record)
result = append(result, gin.H{
"from_name": fromName,
"from_lng": fromLng,
"from_lat": fromLat,
"to_name": serverName,
"to_lng": serverLng,
"to_lat": serverLat,
"value": row.Count,
"path": row.Path,
"ip": row.IP,
})
}
h.responseBuilder.Success(c, result, "获取成功")
}

View File

@@ -0,0 +1,411 @@
package handlers
import (
"hyapi-server/internal/application/article"
"hyapi-server/internal/application/article/dto/commands"
appQueries "hyapi-server/internal/application/article/dto/queries"
"hyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// AnnouncementHandler 公告HTTP处理器
type AnnouncementHandler struct {
appService article.AnnouncementApplicationService
responseBuilder interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
}
// NewAnnouncementHandler 创建公告HTTP处理器
func NewAnnouncementHandler(
appService article.AnnouncementApplicationService,
responseBuilder interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
) *AnnouncementHandler {
return &AnnouncementHandler{
appService: appService,
responseBuilder: responseBuilder,
validator: validator,
logger: logger,
}
}
// CreateAnnouncement 创建公告
// @Summary 创建公告
// @Description 创建新的公告
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.CreateAnnouncementCommand true "创建公告请求"
// @Success 201 {object} map[string]interface{} "公告创建成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements [post]
func (h *AnnouncementHandler) CreateAnnouncement(c *gin.Context) {
var cmd commands.CreateAnnouncementCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
// 验证用户是否已登录
if _, exists := c.Get("user_id"); !exists {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
if err := h.appService.CreateAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("创建公告失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Created(c, nil, "公告创建成功")
}
// GetAnnouncementByID 获取公告详情
// @Summary 获取公告详情
// @Description 根据ID获取公告详情
// @Tags 公告管理-用户端
// @Accept json
// @Produce json
// @Param id path string true "公告ID"
// @Success 200 {object} responses.AnnouncementInfoResponse "获取公告详情成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 404 {object} map[string]interface{} "公告不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/announcements/{id} [get]
func (h *AnnouncementHandler) GetAnnouncementByID(c *gin.Context) {
var query appQueries.GetAnnouncementQuery
// 绑定URI参数公告ID
if err := h.validator.ValidateParam(c, &query); err != nil {
return
}
response, err := h.appService.GetAnnouncementByID(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取公告详情失败", zap.Error(err))
h.responseBuilder.NotFound(c, "公告不存在")
return
}
h.responseBuilder.Success(c, response, "获取公告详情成功")
}
// ListAnnouncements 获取公告列表
// @Summary 获取公告列表
// @Description 分页获取公告列表,支持多种筛选条件
// @Tags 公告管理-用户端
// @Accept json
// @Produce json
// @Param page query int false "页码" default(1)
// @Param page_size query int false "每页数量" default(10)
// @Param status query string false "公告状态"
// @Param title query string false "标题关键词"
// @Param order_by query string false "排序字段"
// @Param order_dir query string false "排序方向"
// @Success 200 {object} responses.AnnouncementListResponse "获取公告列表成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/announcements [get]
func (h *AnnouncementHandler) ListAnnouncements(c *gin.Context) {
var query appQueries.ListAnnouncementQuery
if err := h.validator.ValidateQuery(c, &query); err != nil {
return
}
// 设置默认值
if query.Page <= 0 {
query.Page = 1
}
if query.PageSize <= 0 {
query.PageSize = 10
}
if query.PageSize > 100 {
query.PageSize = 100
}
response, err := h.appService.ListAnnouncements(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取公告列表失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取公告列表失败")
return
}
h.responseBuilder.Success(c, response, "获取公告列表成功")
}
// PublishAnnouncement 发布公告
// @Summary 发布公告
// @Description 发布指定的公告
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Success 200 {object} map[string]interface{} "发布成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id}/publish [post]
func (h *AnnouncementHandler) PublishAnnouncement(c *gin.Context) {
var cmd commands.PublishAnnouncementCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.PublishAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("发布公告失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "发布成功")
}
// WithdrawAnnouncement 撤回公告
// @Summary 撤回公告
// @Description 撤回已发布的公告
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Success 200 {object} map[string]interface{} "撤回成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id}/withdraw [post]
func (h *AnnouncementHandler) WithdrawAnnouncement(c *gin.Context) {
var cmd commands.WithdrawAnnouncementCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.WithdrawAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("撤回公告失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "撤回成功")
}
// ArchiveAnnouncement 归档公告
// @Summary 归档公告
// @Description 归档指定的公告
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Success 200 {object} map[string]interface{} "归档成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id}/archive [post]
func (h *AnnouncementHandler) ArchiveAnnouncement(c *gin.Context) {
var cmd commands.ArchiveAnnouncementCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.ArchiveAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("归档公告失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "归档成功")
}
// UpdateAnnouncement 更新公告
// @Summary 更新公告
// @Description 更新指定的公告
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Param request body commands.UpdateAnnouncementCommand true "更新公告请求"
// @Success 200 {object} map[string]interface{} "更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id} [put]
func (h *AnnouncementHandler) UpdateAnnouncement(c *gin.Context) {
var cmd commands.UpdateAnnouncementCommand
// 先绑定URI参数公告ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体公告信息
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.UpdateAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("更新公告失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "更新成功")
}
// DeleteAnnouncement 删除公告
// @Summary 删除公告
// @Description 删除指定的公告
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Success 200 {object} map[string]interface{} "删除成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id} [delete]
func (h *AnnouncementHandler) DeleteAnnouncement(c *gin.Context) {
var cmd commands.DeleteAnnouncementCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.DeleteAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("删除公告失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "删除成功")
}
// SchedulePublishAnnouncement 定时发布公告
// @Summary 定时发布公告
// @Description 设置公告的定时发布时间
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Param request body commands.SchedulePublishAnnouncementCommand true "定时发布请求"
// @Success 200 {object} map[string]interface{} "设置成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id}/schedule-publish [post]
func (h *AnnouncementHandler) SchedulePublishAnnouncement(c *gin.Context) {
var cmd commands.SchedulePublishAnnouncementCommand
// 先绑定URI参数公告ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体定时发布时间
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.SchedulePublishAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("设置定时发布失败", zap.String("id", cmd.ID), zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "设置成功")
}
// UpdateSchedulePublishAnnouncement 更新定时发布公告
// @Summary 更新定时发布公告
// @Description 修改公告的定时发布时间
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Param request body commands.UpdateSchedulePublishAnnouncementCommand true "更新定时发布请求"
// @Success 200 {object} map[string]interface{} "更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id}/update-schedule-publish [post]
func (h *AnnouncementHandler) UpdateSchedulePublishAnnouncement(c *gin.Context) {
var cmd commands.UpdateSchedulePublishAnnouncementCommand
// 先绑定URI参数公告ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体定时发布时间
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.UpdateSchedulePublishAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("更新定时发布时间失败", zap.String("id", cmd.ID), zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "更新成功")
}
// CancelSchedulePublishAnnouncement 取消定时发布公告
// @Summary 取消定时发布公告
// @Description 取消公告的定时发布
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "公告ID"
// @Success 200 {object} map[string]interface{} "取消成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/{id}/cancel-schedule [post]
func (h *AnnouncementHandler) CancelSchedulePublishAnnouncement(c *gin.Context) {
var cmd commands.CancelSchedulePublishAnnouncementCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.CancelSchedulePublishAnnouncement(c.Request.Context(), &cmd); err != nil {
h.logger.Error("取消定时发布失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "取消成功")
}
// GetAnnouncementStats 获取公告统计信息
// @Summary 获取公告统计信息
// @Description 获取公告的统计数据
// @Tags 公告管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.AnnouncementStatsResponse "获取统计信息成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/announcements/stats [get]
func (h *AnnouncementHandler) GetAnnouncementStats(c *gin.Context) {
response, err := h.appService.GetAnnouncementStats(c.Request.Context())
if err != nil {
h.logger.Error("获取公告统计信息失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取统计信息失败")
return
}
h.responseBuilder.Success(c, response, "获取统计信息成功")
}

View File

@@ -0,0 +1,666 @@
package handlers
import (
"strconv"
"time"
"hyapi-server/internal/application/api"
"hyapi-server/internal/application/api/commands"
"hyapi-server/internal/application/api/dto"
"hyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ApiHandler API调用HTTP处理器
type ApiHandler struct {
appService api.ApiApplicationService
responseBuilder interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
}
// NewApiHandler 创建API调用HTTP处理器
func NewApiHandler(
appService api.ApiApplicationService,
responseBuilder interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
) *ApiHandler {
return &ApiHandler{
appService: appService,
responseBuilder: responseBuilder,
validator: validator,
logger: logger,
}
}
// HandleApiCall 统一API调用入口
// @Summary API调用
// @Description 统一API调用入口参数加密传输
// @Tags API调用
// @Accept json
// @Produce json
// @Param request body commands.ApiCallCommand true "API调用请求"
// @Success 200 {object} dto.ApiCallResponse "调用成功"
// @Failure 400 {object} dto.ApiCallResponse "请求参数错误"
// @Failure 401 {object} dto.ApiCallResponse "未授权"
// @Failure 429 {object} dto.ApiCallResponse "请求过于频繁"
// @Failure 500 {object} dto.ApiCallResponse "服务器内部错误"
// @Router /api/v1/:api_name [post]
func (h *ApiHandler) HandleApiCall(c *gin.Context) {
// 1. 基础参数校验
accessId := c.GetHeader("Access-Id")
if accessId == "" {
response := dto.NewErrorResponse(1005, "缺少Access-Id", "")
c.JSON(200, response)
return
}
// 2. 绑定和校验请求参数
var cmd commands.ApiCallCommand
cmd.ClientIP = c.ClientIP()
cmd.AccessId = accessId
cmd.ApiName = c.Param("api_name")
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
response := dto.NewErrorResponse(1003, "请求参数结构不正确", "")
c.JSON(200, response)
return
}
// 3. 调用应用服务
transactionId, encryptedResp, err := h.appService.CallApi(c.Request.Context(), &cmd)
if err != nil {
// 根据错误类型返回对应的错误码和预定义错误消息
errorCode := api.GetErrorCode(err)
errorMessage := api.GetErrorMessage(err)
response := dto.NewErrorResponse(errorCode, errorMessage, transactionId)
c.JSON(200, response) // API调用接口统一返回200状态码
return
}
// 4. 返回成功响应
response := dto.NewSuccessResponse(transactionId, encryptedResp)
c.JSON(200, response)
}
// GetUserApiKeys 获取用户API密钥
func (h *ApiHandler) GetUserApiKeys(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
result, err := h.appService.GetUserApiKeys(c.Request.Context(), userID)
if err != nil {
h.logger.Error("获取用户API密钥失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "获取API密钥成功")
}
// GetUserWhiteList 获取用户白名单列表
func (h *ApiHandler) GetUserWhiteList(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
// 获取查询参数
remarkKeyword := c.Query("remark") // 备注模糊查询关键词
result, err := h.appService.GetUserWhiteList(c.Request.Context(), userID, remarkKeyword)
if err != nil {
h.logger.Error("获取用户白名单失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, result, "获取白名单成功")
}
// AddWhiteListIP 添加白名单IP
func (h *ApiHandler) AddWhiteListIP(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
var req dto.WhiteListRequest
if err := h.validator.BindAndValidate(c, &req); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
err := h.appService.AddWhiteListIP(c.Request.Context(), userID, req.IPAddress, req.Remark)
if err != nil {
h.logger.Error("添加白名单IP失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "添加白名单IP成功")
}
// DeleteWhiteListIP 删除白名单IP
// @Summary 删除白名单IP
// @Description 从当前用户的白名单中删除指定IP地址
// @Tags API管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param ip path string true "IP地址"
// @Success 200 {object} map[string]interface{} "删除白名单IP成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/my/whitelist/{ip} [delete]
func (h *ApiHandler) DeleteWhiteListIP(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
ipAddress := c.Param("ip")
if ipAddress == "" {
h.responseBuilder.BadRequest(c, "IP地址不能为空")
return
}
err := h.appService.DeleteWhiteListIP(c.Request.Context(), userID, ipAddress)
if err != nil {
h.logger.Error("删除白名单IP失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "删除白名单IP成功")
}
// EncryptParams 加密参数接口(用于前端调试)
// @Summary 加密参数
// @Description 用于前端调试时加密API调用参数
// @Tags API调试
// @Accept json
// @Produce json
// @Param request body commands.EncryptCommand true "加密请求"
// @Success 200 {object} dto.EncryptResponse "加密成功"
// @Failure 400 {object} dto.EncryptResponse "请求参数错误"
// @Failure 401 {object} dto.EncryptResponse "未授权"
// @Router /api/v1/encrypt [post]
func (h *ApiHandler) EncryptParams(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
var cmd commands.EncryptCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
// 调用应用服务层进行加密
encryptedData, err := h.appService.EncryptParams(c.Request.Context(), userID, &cmd)
if err != nil {
h.logger.Error("加密参数失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "加密参数失败")
return
}
response := dto.EncryptResponse{
EncryptedData: encryptedData,
}
h.responseBuilder.Success(c, response, "加密成功")
}
// DecryptParams 解密参数
// @Summary 解密参数
// @Description 使用密钥解密加密的数据
// @Tags API调试
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.DecryptCommand true "解密请求"
// @Success 200 {object} map[string]interface{} "解密成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "解密失败"
// @Router /api/v1/decrypt [post]
func (h *ApiHandler) DecryptParams(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
var cmd commands.DecryptCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误")
return
}
// 调用应用服务层进行解密
decryptedData, err := h.appService.DecryptParams(c.Request.Context(), userID, &cmd)
if err != nil {
h.logger.Error("解密参数失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "解密参数失败")
return
}
h.responseBuilder.Success(c, decryptedData, "解密成功")
}
// GetFormConfig 获取指定API的表单配置
// @Summary 获取表单配置
// @Description 获取指定API的表单配置用于前端动态生成表单
// @Tags API调试
// @Accept json
// @Produce json
// @Security Bearer
// @Param api_code path string true "API代码"
// @Success 200 {object} map[string]interface{} "获取成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 404 {object} map[string]interface{} "API接口不存在"
// @Router /api/v1/form-config/{api_code} [get]
func (h *ApiHandler) GetFormConfig(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
apiCode := c.Param("api_code")
if apiCode == "" {
h.responseBuilder.BadRequest(c, "API代码不能为空")
return
}
h.logger.Info("获取表单配置", zap.String("api_code", apiCode), zap.String("user_id", userID))
// 获取表单配置
config, err := h.appService.GetFormConfig(c.Request.Context(), apiCode)
if err != nil {
h.logger.Error("获取表单配置失败", zap.String("api_code", apiCode), zap.String("user_id", userID), zap.Error(err))
h.responseBuilder.BadRequest(c, "获取表单配置失败")
return
}
if config == nil {
h.responseBuilder.BadRequest(c, "API接口不存在")
return
}
h.logger.Info("获取表单配置成功", zap.String("api_code", apiCode), zap.String("user_id", userID), zap.Int("field_count", len(config.Fields)))
h.responseBuilder.Success(c, config, "获取表单配置成功")
}
// getCurrentUserID 获取当前用户ID
func (h *ApiHandler) getCurrentUserID(c *gin.Context) string {
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}
// GetUserApiCalls 获取用户API调用记录
// @Summary 获取用户API调用记录
// @Description 获取当前用户的API调用记录列表支持分页和筛选
// @Tags API管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param page query int false "页码" default(1)
// @Param page_size query int false "每页数量" default(10)
// @Param start_time query string false "开始时间 (格式: 2006-01-02 15:04:05)"
// @Param end_time query string false "结束时间 (格式: 2006-01-02 15:04:05)"
// @Param transaction_id query string false "交易ID"
// @Param product_name query string false "产品名称"
// @Param status query string false "状态 (pending/success/failed)"
// @Success 200 {object} dto.ApiCallListResponse "获取成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/my/api-calls [get]
func (h *ApiHandler) GetUserApiCalls(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
// 解析查询参数
page := h.getIntQuery(c, "page", 1)
pageSize := h.getIntQuery(c, "page_size", 10)
// 构建筛选条件
filters := make(map[string]interface{})
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 交易ID筛选
if transactionId := c.Query("transaction_id"); transactionId != "" {
filters["transaction_id"] = transactionId
}
// 产品名称筛选
if productName := c.Query("product_name"); productName != "" {
filters["product_name"] = productName
}
// 状态筛选
if status := c.Query("status"); status != "" {
filters["status"] = status
}
// 构建分页选项
options := interfaces.ListOptions{
Page: page,
PageSize: pageSize,
Sort: "created_at",
Order: "desc",
}
result, err := h.appService.GetUserApiCalls(c.Request.Context(), userID, filters, options)
if err != nil {
h.logger.Error("获取用户API调用记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "获取API调用记录失败")
return
}
h.responseBuilder.Success(c, result, "获取API调用记录成功")
}
// GetAdminApiCalls 获取管理端API调用记录
// @Summary 获取管理端API调用记录
// @Description 管理员获取API调用记录支持筛选和分页
// @Tags API管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param page query int false "页码" default(1)
// @Param page_size query int false "每页数量" default(10)
// @Param user_id query string false "用户ID"
// @Param transaction_id query string false "交易ID"
// @Param product_name query string false "产品名称"
// @Param status query string false "状态"
// @Param start_time query string false "开始时间" format(date-time)
// @Param end_time query string false "结束时间" format(date-time)
// @Param sort_by query string false "排序字段"
// @Param sort_order query string false "排序方向" Enums(asc, desc)
// @Success 200 {object} dto.ApiCallListResponse "获取API调用记录成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/api-calls [get]
func (h *ApiHandler) GetAdminApiCalls(c *gin.Context) {
// 解析查询参数
page := h.getIntQuery(c, "page", 1)
pageSize := h.getIntQuery(c, "page_size", 10)
// 构建筛选条件
filters := make(map[string]interface{})
// 用户ID筛选
if userId := c.Query("user_id"); userId != "" {
filters["user_id"] = userId
}
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 交易ID筛选
if transactionId := c.Query("transaction_id"); transactionId != "" {
filters["transaction_id"] = transactionId
}
// 产品名称筛选
if productName := c.Query("product_name"); productName != "" {
filters["product_name"] = productName
}
// 状态筛选
if status := c.Query("status"); status != "" {
filters["status"] = status
}
// 构建分页选项
options := interfaces.ListOptions{
Page: page,
PageSize: pageSize,
Sort: "created_at",
Order: "desc",
}
result, err := h.appService.GetAdminApiCalls(c.Request.Context(), filters, options)
if err != nil {
h.logger.Error("获取管理端API调用记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "获取API调用记录失败")
return
}
h.responseBuilder.Success(c, result, "获取API调用记录成功")
}
// ExportAdminApiCalls 导出管理端API调用记录
// @Summary 导出管理端API调用记录
// @Description 管理员导出API调用记录支持Excel和CSV格式
// @Tags API调用管理
// @Accept json
// @Produce application/vnd.openxmlformats-officedocument.spreadsheetml.sheet,text/csv
// @Security Bearer
// @Param user_ids query string false "用户ID列表逗号分隔"
// @Param product_ids query string false "产品ID列表逗号分隔"
// @Param start_time query string false "开始时间" format(date-time)
// @Param end_time query string false "结束时间" format(date-time)
// @Param format query string false "导出格式" Enums(excel, csv) default(excel)
// @Success 200 {file} file "导出文件"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/api-calls/export [get]
func (h *ApiHandler) ExportAdminApiCalls(c *gin.Context) {
// 解析查询参数
filters := make(map[string]interface{})
// 用户ID筛选
if userIds := c.Query("user_ids"); userIds != "" {
filters["user_ids"] = userIds
}
// 产品ID筛选
if productIds := c.Query("product_ids"); productIds != "" {
filters["product_ids"] = productIds
}
// 时间范围筛选
if startTime := c.Query("start_time"); startTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", startTime); err == nil {
filters["start_time"] = t
}
}
if endTime := c.Query("end_time"); endTime != "" {
if t, err := time.Parse("2006-01-02 15:04:05", endTime); err == nil {
filters["end_time"] = t
}
}
// 获取导出格式默认为excel
format := c.DefaultQuery("format", "excel")
if format != "excel" && format != "csv" {
h.responseBuilder.BadRequest(c, "不支持的导出格式")
return
}
// 调用应用服务导出数据
fileData, err := h.appService.ExportAdminApiCalls(c.Request.Context(), filters, format)
if err != nil {
h.logger.Error("导出API调用记录失败", zap.Error(err))
h.responseBuilder.BadRequest(c, "导出API调用记录失败")
return
}
// 设置响应头
contentType := "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
filename := "API调用记录.xlsx"
if format == "csv" {
contentType = "text/csv;charset=utf-8"
filename = "API调用记录.csv"
}
c.Header("Content-Type", contentType)
c.Header("Content-Disposition", "attachment; filename="+filename)
c.Data(200, contentType, fileData)
}
// getIntQuery 获取整数查询参数
func (h *ApiHandler) getIntQuery(c *gin.Context, key string, defaultValue int) int {
if value := c.Query(key); value != "" {
if intValue, err := strconv.Atoi(value); err == nil && intValue > 0 {
return intValue
}
}
return defaultValue
}
// GetUserBalanceAlertSettings 获取用户余额预警设置
// @Summary 获取用户余额预警设置
// @Description 获取当前用户的余额预警配置
// @Tags 用户设置
// @Accept json
// @Produce json
// @Success 200 {object} map[string]interface{} "获取成功"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/user/balance-alert/settings [get]
func (h *ApiHandler) GetUserBalanceAlertSettings(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
settings, err := h.appService.GetUserBalanceAlertSettings(c.Request.Context(), userID)
if err != nil {
h.logger.Error("获取用户余额预警设置失败",
zap.String("user_id", userID),
zap.Error(err))
h.responseBuilder.InternalError(c, "获取预警设置失败")
return
}
h.responseBuilder.Success(c, settings, "获取成功")
}
// UpdateUserBalanceAlertSettings 更新用户余额预警设置
// @Summary 更新用户余额预警设置
// @Description 更新当前用户的余额预警配置
// @Tags 用户设置
// @Accept json
// @Produce json
// @Param request body map[string]interface{} true "预警设置"
// @Success 200 {object} map[string]interface{} "更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/user/balance-alert/settings [put]
func (h *ApiHandler) UpdateUserBalanceAlertSettings(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
var request struct {
Enabled bool `json:"enabled" binding:"required"`
Threshold float64 `json:"threshold" binding:"required,min=0"`
AlertPhone string `json:"alert_phone" binding:"required"`
}
if err := c.ShouldBindJSON(&request); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误: "+err.Error())
return
}
err := h.appService.UpdateUserBalanceAlertSettings(c.Request.Context(), userID, request.Enabled, request.Threshold, request.AlertPhone)
if err != nil {
h.logger.Error("更新用户余额预警设置失败",
zap.String("user_id", userID),
zap.Error(err))
h.responseBuilder.InternalError(c, "更新预警设置失败")
return
}
h.responseBuilder.Success(c, gin.H{}, "更新成功")
}
// TestBalanceAlertSms 测试余额预警短信
// @Summary 测试余额预警短信
// @Description 发送测试预警短信到指定手机号
// @Tags 用户设置
// @Accept json
// @Produce json
// @Param request body map[string]interface{} true "测试参数"
// @Success 200 {object} map[string]interface{} "发送成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未授权"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/user/balance-alert/test-sms [post]
func (h *ApiHandler) TestBalanceAlertSms(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
var request struct {
Phone string `json:"phone" binding:"required,len=11"`
Balance float64 `json:"balance" binding:"required"`
AlertType string `json:"alert_type" binding:"required,oneof=low_balance arrears"`
}
if err := c.ShouldBindJSON(&request); err != nil {
h.responseBuilder.BadRequest(c, "请求参数错误: "+err.Error())
return
}
err := h.appService.TestBalanceAlertSms(c.Request.Context(), userID, request.Phone, request.Balance, request.AlertType)
if err != nil {
h.logger.Error("发送测试预警短信失败",
zap.String("user_id", userID),
zap.Error(err))
h.responseBuilder.InternalError(c, "发送测试短信失败")
return
}
h.responseBuilder.Success(c, gin.H{}, "测试短信发送成功")
}

View File

@@ -0,0 +1,776 @@
//nolint:unused
package handlers
import (
"hyapi-server/internal/application/article"
"hyapi-server/internal/application/article/dto/commands"
appQueries "hyapi-server/internal/application/article/dto/queries"
_ "hyapi-server/internal/application/article/dto/responses"
"hyapi-server/internal/shared/interfaces"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
// ArticleHandler 文章HTTP处理器
type ArticleHandler struct {
appService article.ArticleApplicationService
responseBuilder interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
}
// NewArticleHandler 创建文章HTTP处理器
func NewArticleHandler(
appService article.ArticleApplicationService,
responseBuilder interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
) *ArticleHandler {
return &ArticleHandler{
appService: appService,
responseBuilder: responseBuilder,
validator: validator,
logger: logger,
}
}
// CreateArticle 创建文章
// @Summary 创建文章
// @Description 创建新的文章
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.CreateArticleCommand true "创建文章请求"
// @Success 201 {object} map[string]interface{} "文章创建成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles [post]
func (h *ArticleHandler) CreateArticle(c *gin.Context) {
var cmd commands.CreateArticleCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
// 验证用户是否已登录
if _, exists := c.Get("user_id"); !exists {
h.responseBuilder.Unauthorized(c, "用户未登录")
return
}
if err := h.appService.CreateArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("创建文章失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Created(c, nil, "文章创建成功")
}
// GetArticleByID 获取文章详情
// @Summary 获取文章详情
// @Description 根据ID获取文章详情
// @Tags 文章管理-用户端
// @Accept json
// @Produce json
// @Param id path string true "文章ID"
// @Success 200 {object} responses.ArticleInfoResponse "获取文章详情成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/articles/{id} [get]
func (h *ArticleHandler) GetArticleByID(c *gin.Context) {
var query appQueries.GetArticleQuery
// 绑定URI参数文章ID
if err := h.validator.ValidateParam(c, &query); err != nil {
return
}
response, err := h.appService.GetArticleByID(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取文章详情失败", zap.Error(err))
h.responseBuilder.NotFound(c, "文章不存在")
return
}
h.responseBuilder.Success(c, response, "获取文章详情成功")
}
// ListArticles 获取文章列表
// @Summary 获取文章列表
// @Description 分页获取文章列表,支持多种筛选条件
// @Tags 文章管理-用户端
// @Accept json
// @Produce json
// @Param page query int false "页码" default(1)
// @Param page_size query int false "每页数量" default(10)
// @Param status query string false "文章状态"
// @Param category_id query string false "分类ID"
// @Param tag_id query string false "标签ID"
// @Param title query string false "标题关键词"
// @Param summary query string false "摘要关键词"
// @Param is_featured query bool false "是否推荐"
// @Param order_by query string false "排序字段"
// @Param order_dir query string false "排序方向"
// @Success 200 {object} responses.ArticleListResponse "获取文章列表成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/articles [get]
func (h *ArticleHandler) ListArticles(c *gin.Context) {
var query appQueries.ListArticleQuery
if err := h.validator.ValidateQuery(c, &query); err != nil {
return
}
// 设置默认值
if query.Page <= 0 {
query.Page = 1
}
if query.PageSize <= 0 {
query.PageSize = 10
}
if query.PageSize > 100 {
query.PageSize = 100
}
response, err := h.appService.ListArticles(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取文章列表失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取文章列表失败")
return
}
h.responseBuilder.Success(c, response, "获取文章列表成功")
}
// ListArticlesForAdmin 获取文章列表(管理员端)
// @Summary 获取文章列表(管理员端)
// @Description 分页获取文章列表,支持多种筛选条件,包含所有状态的文章
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param page query int false "页码" default(1)
// @Param page_size query int false "每页数量" default(10)
// @Param status query string false "文章状态"
// @Param category_id query string false "分类ID"
// @Param tag_id query string false "标签ID"
// @Param title query string false "标题关键词"
// @Param summary query string false "摘要关键词"
// @Param is_featured query bool false "是否推荐"
// @Param order_by query string false "排序字段"
// @Param order_dir query string false "排序方向"
// @Success 200 {object} responses.ArticleListResponse "获取文章列表成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles [get]
func (h *ArticleHandler) ListArticlesForAdmin(c *gin.Context) {
var query appQueries.ListArticleQuery
if err := h.validator.ValidateQuery(c, &query); err != nil {
return
}
// 设置默认值
if query.Page <= 0 {
query.Page = 1
}
if query.PageSize <= 0 {
query.PageSize = 10
}
if query.PageSize > 100 {
query.PageSize = 100
}
response, err := h.appService.ListArticlesForAdmin(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取文章列表失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取文章列表失败")
return
}
h.responseBuilder.Success(c, response, "获取文章列表成功")
}
// UpdateArticle 更新文章
// @Summary 更新文章
// @Description 更新文章信息
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Param request body commands.UpdateArticleCommand true "更新文章请求"
// @Success 200 {object} map[string]interface{} "文章更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id} [put]
func (h *ArticleHandler) UpdateArticle(c *gin.Context) {
var cmd commands.UpdateArticleCommand
// 先绑定URI参数文章ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体文章信息
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.UpdateArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("更新文章失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "文章更新成功")
}
// DeleteArticle 删除文章
// @Summary 删除文章
// @Description 删除指定文章
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Success 200 {object} map[string]interface{} "文章删除成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id} [delete]
func (h *ArticleHandler) DeleteArticle(c *gin.Context) {
var cmd commands.DeleteArticleCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.DeleteArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("删除文章失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "文章删除成功")
}
// PublishArticle 发布文章
// @Summary 发布文章
// @Description 将草稿文章发布
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Success 200 {object} map[string]interface{} "文章发布成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id}/publish [post]
func (h *ArticleHandler) PublishArticle(c *gin.Context) {
var cmd commands.PublishArticleCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.PublishArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("发布文章失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "文章发布成功")
}
// SchedulePublishArticle 定时发布文章
// @Summary 定时发布文章
// @Description 设置文章的定时发布时间支持格式YYYY-MM-DD HH:mm:ss
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Param request body commands.SchedulePublishCommand true "定时发布请求"
// @Success 200 {object} map[string]interface{} "定时发布设置成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id}/schedule-publish [post]
func (h *ArticleHandler) SchedulePublishArticle(c *gin.Context) {
var cmd commands.SchedulePublishCommand
// 先绑定URI参数文章ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体定时发布时间
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.SchedulePublishArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("设置定时发布失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "定时发布设置成功")
}
// CancelSchedulePublishArticle 取消定时发布文章
// @Summary 取消定时发布文章
// @Description 取消文章的定时发布设置
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Success 200 {object} map[string]interface{} "取消定时发布成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id}/cancel-schedule [post]
func (h *ArticleHandler) CancelSchedulePublishArticle(c *gin.Context) {
var cmd commands.CancelScheduleCommand
// 绑定URI参数文章ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.CancelSchedulePublishArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("取消定时发布失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "取消定时发布成功")
}
// ArchiveArticle 归档文章
// @Summary 归档文章
// @Description 将已发布文章归档
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Success 200 {object} map[string]interface{} "文章归档成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id}/archive [post]
func (h *ArticleHandler) ArchiveArticle(c *gin.Context) {
var cmd commands.ArchiveArticleCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.ArchiveArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("归档文章失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "文章归档成功")
}
// SetFeatured 设置推荐状态
// @Summary 设置推荐状态
// @Description 设置文章的推荐状态
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Param request body commands.SetFeaturedCommand true "设置推荐状态请求"
// @Success 200 {object} map[string]interface{} "设置推荐状态成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id}/featured [put]
func (h *ArticleHandler) SetFeatured(c *gin.Context) {
var cmd commands.SetFeaturedCommand
// 先绑定URI参数文章ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体推荐状态
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.SetFeatured(c.Request.Context(), &cmd); err != nil {
h.logger.Error("设置推荐状态失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "设置推荐状态成功")
}
// GetArticleStats 获取文章统计
// @Summary 获取文章统计
// @Description 获取文章相关统计数据
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.ArticleStatsResponse "获取统计成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/stats [get]
func (h *ArticleHandler) GetArticleStats(c *gin.Context) {
response, err := h.appService.GetArticleStats(c.Request.Context())
if err != nil {
h.logger.Error("获取文章统计失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取文章统计失败")
return
}
h.responseBuilder.Success(c, response, "获取统计成功")
}
// UpdateSchedulePublishArticle 修改定时发布时间
// @Summary 修改定时发布时间
// @Description 修改文章的定时发布时间
// @Tags 文章管理-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "文章ID"
// @Param request body commands.SchedulePublishCommand true "修改定时发布请求"
// @Success 200 {object} map[string]interface{} "修改定时发布时间成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "文章不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/articles/{id}/update-schedule-publish [post]
func (h *ArticleHandler) UpdateSchedulePublishArticle(c *gin.Context) {
var cmd commands.SchedulePublishCommand
// 先绑定URI参数文章ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体定时发布时间
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.UpdateSchedulePublishArticle(c.Request.Context(), &cmd); err != nil {
h.logger.Error("修改定时发布时间失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "修改定时发布时间成功")
}
// ==================== 分类相关方法 ====================
// ListCategories 获取分类列表
// @Summary 获取分类列表
// @Description 获取所有文章分类
// @Tags 文章分类-用户端
// @Accept json
// @Produce json
// @Success 200 {object} responses.CategoryListResponse "获取分类列表成功"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/article-categories [get]
func (h *ArticleHandler) ListCategories(c *gin.Context) {
response, err := h.appService.ListCategories(c.Request.Context())
if err != nil {
h.logger.Error("获取分类列表失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取分类列表失败")
return
}
h.responseBuilder.Success(c, response, "获取分类列表成功")
}
// GetCategoryByID 获取分类详情
// @Summary 获取分类详情
// @Description 根据ID获取分类详情
// @Tags 文章分类-用户端
// @Accept json
// @Produce json
// @Param id path string true "分类ID"
// @Success 200 {object} responses.CategoryInfoResponse "获取分类详情成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 404 {object} map[string]interface{} "分类不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/article-categories/{id} [get]
func (h *ArticleHandler) GetCategoryByID(c *gin.Context) {
var query appQueries.GetCategoryQuery
// 绑定URI参数分类ID
if err := h.validator.ValidateParam(c, &query); err != nil {
return
}
response, err := h.appService.GetCategoryByID(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取分类详情失败", zap.Error(err))
h.responseBuilder.NotFound(c, "分类不存在")
return
}
h.responseBuilder.Success(c, response, "获取分类详情成功")
}
// CreateCategory 创建分类
// @Summary 创建分类
// @Description 创建新的文章分类
// @Tags 文章分类-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.CreateCategoryCommand true "创建分类请求"
// @Success 201 {object} map[string]interface{} "分类创建成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/article-categories [post]
func (h *ArticleHandler) CreateCategory(c *gin.Context) {
var cmd commands.CreateCategoryCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.CreateCategory(c.Request.Context(), &cmd); err != nil {
h.logger.Error("创建分类失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Created(c, nil, "分类创建成功")
}
// UpdateCategory 更新分类
// @Summary 更新分类
// @Description 更新分类信息
// @Tags 文章分类-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "分类ID"
// @Param request body commands.UpdateCategoryCommand true "更新分类请求"
// @Success 200 {object} map[string]interface{} "分类更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "分类不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/article-categories/{id} [put]
func (h *ArticleHandler) UpdateCategory(c *gin.Context) {
var cmd commands.UpdateCategoryCommand
// 先绑定URI参数分类ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体分类信息
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.UpdateCategory(c.Request.Context(), &cmd); err != nil {
h.logger.Error("更新分类失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "分类更新成功")
}
// DeleteCategory 删除分类
// @Summary 删除分类
// @Description 删除指定分类
// @Tags 文章分类-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "分类ID"
// @Success 200 {object} map[string]interface{} "分类删除成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "分类不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/article-categories/{id} [delete]
func (h *ArticleHandler) DeleteCategory(c *gin.Context) {
var cmd commands.DeleteCategoryCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.DeleteCategory(c.Request.Context(), &cmd); err != nil {
h.logger.Error("删除分类失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "分类删除成功")
}
// ==================== 标签相关方法 ====================
// ListTags 获取标签列表
// @Summary 获取标签列表
// @Description 获取所有文章标签
// @Tags 文章标签-用户端
// @Accept json
// @Produce json
// @Success 200 {object} responses.TagListResponse "获取标签列表成功"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/article-tags [get]
func (h *ArticleHandler) ListTags(c *gin.Context) {
response, err := h.appService.ListTags(c.Request.Context())
if err != nil {
h.logger.Error("获取标签列表失败", zap.Error(err))
h.responseBuilder.InternalError(c, "获取标签列表失败")
return
}
h.responseBuilder.Success(c, response, "获取标签列表成功")
}
// GetTagByID 获取标签详情
// @Summary 获取标签详情
// @Description 根据ID获取标签详情
// @Tags 文章标签-用户端
// @Accept json
// @Produce json
// @Param id path string true "标签ID"
// @Success 200 {object} responses.TagInfoResponse "获取标签详情成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 404 {object} map[string]interface{} "标签不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/article-tags/{id} [get]
func (h *ArticleHandler) GetTagByID(c *gin.Context) {
var query appQueries.GetTagQuery
// 绑定URI参数标签ID
if err := h.validator.ValidateParam(c, &query); err != nil {
return
}
response, err := h.appService.GetTagByID(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取标签详情失败", zap.Error(err))
h.responseBuilder.NotFound(c, "标签不存在")
return
}
h.responseBuilder.Success(c, response, "获取标签详情成功")
}
// CreateTag 创建标签
// @Summary 创建标签
// @Description 创建新的文章标签
// @Tags 文章标签-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.CreateTagCommand true "创建标签请求"
// @Success 201 {object} map[string]interface{} "标签创建成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/article-tags [post]
func (h *ArticleHandler) CreateTag(c *gin.Context) {
var cmd commands.CreateTagCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.CreateTag(c.Request.Context(), &cmd); err != nil {
h.logger.Error("创建标签失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Created(c, nil, "标签创建成功")
}
// UpdateTag 更新标签
// @Summary 更新标签
// @Description 更新标签信息
// @Tags 文章标签-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "标签ID"
// @Param request body commands.UpdateTagCommand true "更新标签请求"
// @Success 200 {object} map[string]interface{} "标签更新成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "标签不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/article-tags/{id} [put]
func (h *ArticleHandler) UpdateTag(c *gin.Context) {
var cmd commands.UpdateTagCommand
// 先绑定URI参数标签ID
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
// 再绑定JSON请求体标签信息
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
if err := h.appService.UpdateTag(c.Request.Context(), &cmd); err != nil {
h.logger.Error("更新标签失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "标签更新成功")
}
// DeleteTag 删除标签
// @Summary 删除标签
// @Description 删除指定标签
// @Tags 文章标签-管理端
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "标签ID"
// @Success 200 {object} map[string]interface{} "标签删除成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "标签不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/admin/article-tags/{id} [delete]
func (h *ArticleHandler) DeleteTag(c *gin.Context) {
var cmd commands.DeleteTagCommand
if err := h.validator.ValidateParam(c, &cmd); err != nil {
return
}
if err := h.appService.DeleteTag(c.Request.Context(), &cmd); err != nil {
h.logger.Error("删除标签失败", zap.Error(err))
h.responseBuilder.BadRequest(c, err.Error())
return
}
h.responseBuilder.Success(c, nil, "标签删除成功")
}

View File

@@ -0,0 +1,92 @@
package handlers
import (
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"hyapi-server/internal/config"
"hyapi-server/internal/infrastructure/external/captcha"
"hyapi-server/internal/shared/interfaces"
)
// CaptchaHandler 验证码滑块HTTP 处理器
type CaptchaHandler struct {
captchaService *captcha.CaptchaService
response interfaces.ResponseBuilder
config *config.Config
logger *zap.Logger
}
// NewCaptchaHandler 创建验证码处理器
func NewCaptchaHandler(
captchaService *captcha.CaptchaService,
response interfaces.ResponseBuilder,
cfg *config.Config,
logger *zap.Logger,
) *CaptchaHandler {
return &CaptchaHandler{
captchaService: captchaService,
response: response,
config: cfg,
logger: logger,
}
}
// EncryptedSceneIdReq 获取加密场景 ID 的请求(可选参数)
type EncryptedSceneIdReq struct {
ExpireSeconds *int `form:"expire_seconds" json:"expire_seconds"` // 有效期秒数186400默认 3600
}
// GetEncryptedSceneId 获取加密场景 ID供前端加密模式初始化阿里云验证码
// @Summary 获取验证码加密场景ID
// @Description 用于加密模式下发 EncryptedSceneId前端用此初始化滑块验证码
// @Tags 验证码
// @Accept json
// @Produce json
// @Param body body EncryptedSceneIdReq false "可选expire_seconds 有效期(1-86400)默认3600"
// @Success 200 {object} map[string]interface{} "encryptedSceneId"
// @Failure 400 {object} map[string]interface{} "配置未启用或参数错误"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/captcha/encryptedSceneId [post]
func (h *CaptchaHandler) GetEncryptedSceneId(c *gin.Context) {
expireSec := 3600
if c.Request.ContentLength > 0 {
var req EncryptedSceneIdReq
if err := c.ShouldBindJSON(&req); err == nil && req.ExpireSeconds != nil {
expireSec = *req.ExpireSeconds
}
}
if expireSec <= 0 || expireSec > 86400 {
h.response.BadRequest(c, "expire_seconds 必须在 186400 之间")
return
}
encrypted, err := h.captchaService.GetEncryptedSceneId(expireSec)
if err != nil {
if err == captcha.ErrCaptchaEncryptMissing || err == captcha.ErrCaptchaConfig {
h.logger.Warn("验证码加密场景ID生成失败", zap.Error(err))
h.response.BadRequest(c, "验证码加密模式未配置或配置错误")
return
}
h.logger.Error("验证码加密场景ID生成失败", zap.Error(err))
h.response.InternalError(c, "生成失败,请稍后重试")
return
}
h.response.Success(c, map[string]string{"encryptedSceneId": encrypted}, "ok")
}
// GetConfig 获取验证码前端配置是否启用、场景ID等便于前端决定是否展示滑块
// @Summary 获取验证码配置
// @Description 返回是否启用滑块、场景ID非加密模式用
// @Tags 验证码
// @Produce json
// @Success 200 {object} map[string]interface{} "captchaEnabled, sceneId"
// @Router /api/v1/captcha/config [get]
func (h *CaptchaHandler) GetConfig(c *gin.Context) {
data := map[string]interface{}{
"captchaEnabled": h.config.SMS.CaptchaEnabled,
"sceneId": h.config.SMS.SceneID,
}
h.response.Success(c, data, "ok")
}

View File

@@ -0,0 +1,729 @@
//nolint:unused
package handlers
import (
"bytes"
"encoding/json"
"io"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"hyapi-server/internal/application/certification"
"hyapi-server/internal/application/certification/dto/commands"
"hyapi-server/internal/application/certification/dto/queries"
_ "hyapi-server/internal/application/certification/dto/responses"
"hyapi-server/internal/infrastructure/external/storage"
"hyapi-server/internal/shared/interfaces"
"hyapi-server/internal/shared/middleware"
)
// CertificationHandler 认证HTTP处理器
type CertificationHandler struct {
appService certification.CertificationApplicationService
response interfaces.ResponseBuilder
validator interfaces.RequestValidator
logger *zap.Logger
jwtAuth *middleware.JWTAuthMiddleware
storageService *storage.QiNiuStorageService
}
// NewCertificationHandler 创建认证处理器
func NewCertificationHandler(
appService certification.CertificationApplicationService,
response interfaces.ResponseBuilder,
validator interfaces.RequestValidator,
logger *zap.Logger,
jwtAuth *middleware.JWTAuthMiddleware,
storageService *storage.QiNiuStorageService,
) *CertificationHandler {
return &CertificationHandler{
appService: appService,
response: response,
validator: validator,
logger: logger,
jwtAuth: jwtAuth,
storageService: storageService,
}
}
// ================ 认证申请管理 ================
// GetCertification 获取认证详情
// @Summary 获取认证详情
// @Description 根据认证ID获取认证详情
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Success 200 {object} responses.CertificationResponse "获取认证详情成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/details [get]
func (h *CertificationHandler) GetCertification(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
query := &queries.GetCertificationQuery{
UserID: userID,
}
result, err := h.appService.GetCertification(c.Request.Context(), query)
if err != nil {
h.logger.Error("获取认证详情失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取认证详情成功")
}
// ================ 企业信息管理 ================
// SubmitEnterpriseInfo 提交企业信息
// @Summary 提交企业信息
// @Description 提交企业认证所需的企业信息
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.SubmitEnterpriseInfoCommand true "提交企业信息请求"
// @Success 200 {object} responses.CertificationResponse "企业信息提交成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/enterprise-info [post]
func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
var cmd commands.SubmitEnterpriseInfoCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
cmd.UserID = userID
result, err := h.appService.SubmitEnterpriseInfo(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("提交企业信息失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "企业信息提交成功")
}
// ConfirmAuth 前端确认是否完成认证
// @Summary 前端确认认证状态
// @Description 前端轮询确认企业认证是否完成
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body queries.ConfirmAuthCommand true "确认状态请求"
// @Success 200 {object} responses.ConfirmAuthResponse "状态确认成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/confirm-auth [post]
func (h *CertificationHandler) ConfirmAuth(c *gin.Context) {
var cmd queries.ConfirmAuthCommand
cmd.UserID = h.getCurrentUserID(c)
if cmd.UserID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
result, err := h.appService.ConfirmAuth(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("确认认证/签署状态失败", zap.Error(err), zap.String("user_id", cmd.UserID))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "状态确认成功")
}
// ConfirmSign 前端确认是否完成签署
// @Summary 前端确认签署状态
// @Description 前端轮询确认合同签署是否完成
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body queries.ConfirmSignCommand true "确认状态请求"
// @Success 200 {object} responses.ConfirmSignResponse "状态确认成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/confirm-sign [post]
func (h *CertificationHandler) ConfirmSign(c *gin.Context) {
var cmd queries.ConfirmSignCommand
cmd.UserID = h.getCurrentUserID(c)
if cmd.UserID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
result, err := h.appService.ConfirmSign(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("确认认证/签署状态失败", zap.Error(err), zap.String("user_id", cmd.UserID))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "状态确认成功")
}
// ================ 合同管理 ================
// ApplyContract 申请合同签署
// @Summary 申请合同签署
// @Description 申请企业认证合同签署
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.ApplyContractCommand true "申请合同请求"
// @Success 200 {object} responses.ContractSignUrlResponse "合同申请成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 404 {object} map[string]interface{} "认证记录不存在"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/apply-contract [post]
func (h *CertificationHandler) ApplyContract(c *gin.Context) {
var cmd commands.ApplyContractCommand
cmd.UserID = h.getCurrentUserID(c)
if cmd.UserID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
result, err := h.appService.ApplyContract(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("申请合同失败", zap.Error(err), zap.String("user_id", cmd.UserID))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "合同申请成功")
}
// RecognizeBusinessLicense OCR识别营业执照
// @Summary OCR识别营业执照
// @Description 上传营业执照图片进行OCR识别自动填充企业信息
// @Tags 认证管理
// @Accept multipart/form-data
// @Produce json
// @Security Bearer
// @Param image formData file true "营业执照图片文件"
// @Success 200 {object} responses.BusinessLicenseResult "营业执照识别成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/ocr/business-license [post]
func (h *CertificationHandler) RecognizeBusinessLicense(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
// 获取上传的文件
file, err := c.FormFile("image")
if err != nil {
h.logger.Error("获取上传文件失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "请选择要上传的营业执照图片")
return
}
// 验证文件类型
allowedTypes := map[string]bool{
"image/jpeg": true,
"image/jpg": true,
"image/png": true,
"image/webp": true,
}
if !allowedTypes[file.Header.Get("Content-Type")] {
h.response.BadRequest(c, "只支持JPG、PNG、WEBP格式的图片")
return
}
// 验证文件大小限制为5MB
if file.Size > 5*1024*1024 {
h.response.BadRequest(c, "图片大小不能超过5MB")
return
}
// 打开文件
src, err := file.Open()
if err != nil {
h.logger.Error("打开上传文件失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "文件读取失败")
return
}
defer src.Close()
// 读取文件内容
imageBytes, err := io.ReadAll(src)
if err != nil {
h.logger.Error("读取文件内容失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "文件读取失败")
return
}
// 调用OCR服务识别营业执照
result, err := h.appService.RecognizeBusinessLicense(c.Request.Context(), imageBytes)
if err != nil {
h.logger.Error("营业执照OCR识别失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "营业执照识别失败:"+err.Error())
return
}
h.logger.Info("营业执照OCR识别成功",
zap.String("user_id", userID),
zap.String("company_name", result.CompanyName),
zap.Float64("confidence", result.Confidence),
)
h.response.Success(c, result, "营业执照识别成功")
}
// UploadCertificationFile 上传认证相关图片到七牛云(企业信息中的营业执照、办公场地、场景附件、授权代表身份证等)
// @Summary 上传认证图片
// @Description 上传企业信息中使用的图片到七牛云,返回可访问的 URL
// @Tags 认证管理
// @Accept multipart/form-data
// @Produce json
// @Security Bearer
// @Param file formData file true "图片文件"
// @Success 200 {object} map[string]string "上传成功,返回 url 与 key"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/upload [post]
func (h *CertificationHandler) UploadCertificationFile(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
file, err := c.FormFile("file")
if err != nil {
h.logger.Error("获取上传文件失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "请选择要上传的图片文件")
return
}
allowedTypes := map[string]bool{
"image/jpeg": true,
"image/jpg": true,
"image/png": true,
"image/webp": true,
}
contentType := file.Header.Get("Content-Type")
if !allowedTypes[contentType] {
h.response.BadRequest(c, "只支持 JPG、PNG、WEBP 格式的图片")
return
}
if file.Size > 5*1024*1024 {
h.response.BadRequest(c, "图片大小不能超过 5MB")
return
}
src, err := file.Open()
if err != nil {
h.logger.Error("打开上传文件失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "文件读取失败")
return
}
defer src.Close()
fileBytes, err := io.ReadAll(src)
if err != nil {
h.logger.Error("读取文件内容失败", zap.Error(err), zap.String("user_id", userID))
h.response.BadRequest(c, "文件读取失败")
return
}
uploadResult, err := h.storageService.UploadFile(c.Request.Context(), fileBytes, file.Filename)
if err != nil {
h.logger.Error("上传文件到七牛云失败", zap.Error(err), zap.String("user_id", userID), zap.String("file_name", file.Filename))
h.response.BadRequest(c, "图片上传失败,请稍后重试")
return
}
h.response.Success(c, map[string]string{
"url": uploadResult.URL,
"key": uploadResult.Key,
}, "上传成功")
}
// ListCertifications 获取认证列表(管理员)
// @Summary 获取认证列表
// @Description 管理员获取认证申请列表
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param page query int false "页码" default(1)
// @Param page_size query int false "每页数量" default(10)
// @Param sort_by query string false "排序字段"
// @Param sort_order query string false "排序方向" Enums(asc, desc)
// @Param status query string false "认证状态"
// @Param user_id query string false "用户ID"
// @Param company_name query string false "公司名称"
// @Param legal_person_name query string false "法人姓名"
// @Param search_keyword query string false "搜索关键词"
// @Success 200 {object} responses.CertificationListResponse "获取认证列表成功"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 403 {object} map[string]interface{} "权限不足"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications [get]
func (h *CertificationHandler) ListCertifications(c *gin.Context) {
userID := h.getCurrentUserID(c)
if userID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
var query queries.ListCertificationsQuery
if err := h.validator.BindAndValidate(c, &query); err != nil {
return
}
result, err := h.appService.ListCertifications(c.Request.Context(), &query)
if err != nil {
h.logger.Error("获取认证列表失败", zap.Error(err))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取认证列表成功")
}
// AdminCompleteCertificationWithoutContract 管理员代用户完成认证(暂不关联合同)
// @Summary 管理员代用户完成认证(暂不关联合同)
// @Description 后台补充企业信息并直接完成认证,暂时不要求上传合同
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.AdminCompleteCertificationCommand true "管理员代用户完成认证请求"
// @Success 200 {object} responses.CertificationResponse "代用户完成认证成功"
// @Failure 400 {object} map[string]interface{} "请求参数错误"
// @Failure 401 {object} map[string]interface{} "未认证"
// @Failure 403 {object} map[string]interface{} "权限不足"
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
// @Router /api/v1/certifications/admin/complete-without-contract [post]
func (h *CertificationHandler) AdminCompleteCertificationWithoutContract(c *gin.Context) {
adminID := h.getCurrentUserID(c)
if adminID == "" {
h.response.Unauthorized(c, "用户未登录")
return
}
var cmd commands.AdminCompleteCertificationCommand
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
return
}
cmd.AdminID = adminID
result, err := h.appService.AdminCompleteCertificationWithoutContract(c.Request.Context(), &cmd)
if err != nil {
h.logger.Error("管理员代用户完成认证失败", zap.Error(err), zap.String("admin_id", adminID), zap.String("user_id", cmd.UserID))
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "代用户完成认证成功")
}
// AdminListSubmitRecords 管理端分页查询企业信息提交记录
// @Summary 管理端企业审核列表
// @Tags 认证管理
// @Produce json
// @Security Bearer
// @Param page query int false "页码"
// @Param page_size query int false "每页条数"
// @Param certification_status query string false "按状态机筛选info_pending_review/info_submitted/info_rejected空为全部"
// @Success 200 {object} responses.AdminSubmitRecordsListResponse
// @Router /api/v1/certifications/admin/submit-records [get]
func (h *CertificationHandler) AdminListSubmitRecords(c *gin.Context) {
query := &queries.AdminListSubmitRecordsQuery{}
if err := c.ShouldBindQuery(query); err != nil {
h.response.BadRequest(c, "参数错误")
return
}
result, err := h.appService.AdminListSubmitRecords(c.Request.Context(), query)
if err != nil {
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取成功")
}
// AdminGetSubmitRecordByID 管理端获取单条提交记录详情
// @Summary 管理端企业审核详情
// @Tags 认证管理
// @Produce json
// @Security Bearer
// @Param id path string true "记录ID"
// @Success 200 {object} responses.AdminSubmitRecordDetail
// @Router /api/v1/certifications/admin/submit-records/{id} [get]
func (h *CertificationHandler) AdminGetSubmitRecordByID(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "记录ID不能为空")
return
}
result, err := h.appService.AdminGetSubmitRecordByID(c.Request.Context(), id)
if err != nil {
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, result, "获取成功")
}
// AdminApproveSubmitRecord 管理端审核通过
// @Summary 管理端企业审核通过
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "记录ID"
// @Param request body object true "可选 remark"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/certifications/admin/submit-records/{id}/approve [post]
func (h *CertificationHandler) AdminApproveSubmitRecord(c *gin.Context) {
adminID := h.getCurrentUserID(c)
if adminID == "" {
h.response.Unauthorized(c, "未登录")
return
}
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "记录ID不能为空")
return
}
var body struct {
Remark string `json:"remark"`
}
_ = c.ShouldBindJSON(&body)
if err := h.appService.AdminApproveSubmitRecord(c.Request.Context(), id, adminID, body.Remark); err != nil {
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, nil, "审核通过")
}
// AdminRejectSubmitRecord 管理端审核拒绝
// @Summary 管理端企业审核拒绝
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param id path string true "记录ID"
// @Param request body object true "remark 必填"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/certifications/admin/submit-records/{id}/reject [post]
func (h *CertificationHandler) AdminRejectSubmitRecord(c *gin.Context) {
adminID := h.getCurrentUserID(c)
if adminID == "" {
h.response.Unauthorized(c, "未登录")
return
}
id := c.Param("id")
if id == "" {
h.response.BadRequest(c, "记录ID不能为空")
return
}
var body struct {
Remark string `json:"remark" binding:"required"`
}
if err := c.ShouldBindJSON(&body); err != nil {
h.response.BadRequest(c, "请填写拒绝原因(remark)")
return
}
if err := h.appService.AdminRejectSubmitRecord(c.Request.Context(), id, adminID, body.Remark); err != nil {
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, nil, "已拒绝")
}
// AdminTransitionCertificationStatus 管理端按用户变更认证状态(以状态机为准)
// @Summary 管理端变更认证状态
// @Tags 认证管理
// @Accept json
// @Produce json
// @Security Bearer
// @Param request body commands.AdminTransitionCertificationStatusCommand true "user_id, target_status(info_submitted/info_rejected), remark"
// @Success 200 {object} map[string]interface{}
// @Router /api/v1/certifications/admin/transition-status [post]
func (h *CertificationHandler) AdminTransitionCertificationStatus(c *gin.Context) {
adminID := h.getCurrentUserID(c)
if adminID == "" {
h.response.Unauthorized(c, "未登录")
return
}
var cmd commands.AdminTransitionCertificationStatusCommand
if err := c.ShouldBindJSON(&cmd); err != nil {
h.response.BadRequest(c, "参数错误")
return
}
cmd.AdminID = adminID
if err := h.appService.AdminTransitionCertificationStatus(c.Request.Context(), &cmd); err != nil {
h.response.BadRequest(c, err.Error())
return
}
h.response.Success(c, nil, "状态已更新")
}
// ================ 回调处理 ================
// HandleEsignCallback 处理e签宝回调
// @Summary 处理e签宝回调
// @Description 处理e签宝的异步回调通知
// @Tags 认证管理
// @Accept application/json
// @Produce text/plain
// @Success 200 {string} string "success"
// @Failure 400 {string} string "fail"
// @Router /api/v1/certifications/esign/callback [post]
func (h *CertificationHandler) HandleEsignCallback(c *gin.Context) {
// 记录请求基本信息
h.logger.Info("收到e签宝回调请求",
zap.String("method", c.Request.Method),
zap.String("url", c.Request.URL.String()),
zap.String("remote_addr", c.ClientIP()),
zap.String("user_agent", c.GetHeader("User-Agent")),
)
// 记录所有请求头
headers := make(map[string]string)
for key, values := range c.Request.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
h.logger.Info("回调请求头信息", zap.Any("headers", headers))
// 记录URL查询参数
queryParams := make(map[string]string)
for key, values := range c.Request.URL.Query() {
if len(values) > 0 {
queryParams[key] = values[0]
}
}
if len(queryParams) > 0 {
h.logger.Info("回调URL查询参数", zap.Any("query_params", queryParams))
}
// 读取并记录请求体
var callbackData *commands.EsignCallbackData
if c.Request.Body != nil {
bodyBytes, err := c.GetRawData()
if err != nil {
h.logger.Error("读取回调请求体失败", zap.Error(err))
h.response.BadRequest(c, "读取请求体失败")
return
}
if err := json.Unmarshal(bodyBytes, &callbackData); err != nil {
h.logger.Error("回调请求体不是有效的JSON格式", zap.Error(err))
h.response.BadRequest(c, "请求体格式错误")
return
}
h.logger.Info("回调请求体内容", zap.Any("body", callbackData))
// 如果后续还需要用 c.Request.Body
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
}
// 记录Content-Type
contentType := c.GetHeader("Content-Type")
h.logger.Info("回调请求Content-Type", zap.String("content_type", contentType))
// 记录Content-Length
contentLength := c.GetHeader("Content-Length")
if contentLength != "" {
h.logger.Info("回调请求Content-Length", zap.String("content_length", contentLength))
}
// 记录时间戳
h.logger.Info("回调请求时间",
zap.Time("request_time", time.Now()),
zap.String("request_id", c.GetHeader("X-Request-ID")),
)
// 记录完整的请求信息摘要
h.logger.Info("e签宝回调完整信息摘要",
zap.String("method", c.Request.Method),
zap.String("url", c.Request.URL.String()),
zap.String("client_ip", c.ClientIP()),
zap.String("content_type", contentType),
zap.Any("headers", headers),
zap.Any("query_params", queryParams),
zap.Any("body", callbackData),
)
// 处理回调数据
if callbackData != nil {
// 构建请求头映射
headers := make(map[string]string)
for key, values := range c.Request.Header {
if len(values) > 0 {
headers[key] = values[0]
}
}
// 构建查询参数映射
queryParams := make(map[string]string)
for key, values := range c.Request.URL.Query() {
if len(values) > 0 {
queryParams[key] = values[0]
}
}
if err := h.appService.HandleEsignCallback(c.Request.Context(), &commands.EsignCallbackCommand{
Data: callbackData,
Headers: headers,
QueryParams: queryParams,
}); err != nil {
h.logger.Error("处理e签宝回调失败", zap.Error(err))
h.response.BadRequest(c, "回调处理失败: "+err.Error())
return
}
}
// 返回成功响应
c.JSON(200, map[string]interface{}{
"code": "200",
"msg": "success",
})
}
// ================ 辅助方法 ================
// getCurrentUserID 获取当前用户ID
func (h *CertificationHandler) getCurrentUserID(c *gin.Context) string {
if userID, exists := c.Get("user_id"); exists {
if id, ok := userID.(string); ok {
return id
}
}
return ""
}

View File

@@ -0,0 +1,471 @@
package handlers
import (
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/shopspring/decimal"
"go.uber.org/zap"
"hyapi-server/internal/application/product"
"hyapi-server/internal/config"
financeRepositories "hyapi-server/internal/domains/finance/repositories"
)
// ComponentReportOrderHandler 组件报告订单处理器
type ComponentReportOrderHandler struct {
service *product.ComponentReportOrderService
purchaseOrderRepo financeRepositories.PurchaseOrderRepository
config *config.Config
logger *zap.Logger
}
// NewComponentReportOrderHandler 创建组件报告订单处理器
func NewComponentReportOrderHandler(
service *product.ComponentReportOrderService,
purchaseOrderRepo financeRepositories.PurchaseOrderRepository,
config *config.Config,
logger *zap.Logger,
) *ComponentReportOrderHandler {
return &ComponentReportOrderHandler{
service: service,
purchaseOrderRepo: purchaseOrderRepo,
config: config,
logger: logger,
}
}
// CheckDownloadAvailability 检查下载可用性
// GET /api/v1/products/:id/component-report/check
func (h *ComponentReportOrderHandler) CheckDownloadAvailability(c *gin.Context) {
h.logger.Info("开始检查下载可用性")
productID := c.Param("id")
h.logger.Info("获取产品ID", zap.String("product_id", productID))
if productID == "" {
h.logger.Error("产品ID不能为空")
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "产品ID不能为空",
})
return
}
userID := c.GetString("user_id")
h.logger.Info("获取用户ID", zap.String("user_id", userID))
if userID == "" {
h.logger.Error("用户未登录")
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "用户未登录",
})
return
}
// 调用服务获取订单信息,检查是否可以下载
orderInfo, err := h.service.GetOrderInfo(c.Request.Context(), userID, productID)
if err != nil {
h.logger.Error("获取订单信息失败", zap.Error(err), zap.String("product_id", productID), zap.String("user_id", userID))
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
"message": "获取订单信息失败",
"error": err.Error(),
})
return
}
h.logger.Info("获取订单信息成功", zap.Bool("can_download", orderInfo.CanDownload), zap.Bool("is_package", orderInfo.IsPackage))
// 返回检查结果
message := "需要购买"
if orderInfo.CanDownload {
message = "可以下载"
}
c.JSON(http.StatusOK, gin.H{
"code": 200,
"data": gin.H{
"can_download": orderInfo.CanDownload,
"is_package": orderInfo.IsPackage,
"message": message,
},
})
}
// GetDownloadInfo 获取下载信息和价格计算
// GET /api/v1/products/:id/component-report/info
func (h *ComponentReportOrderHandler) GetDownloadInfo(c *gin.Context) {
h.logger.Info("开始获取下载信息和价格计算")
productID := c.Param("id")
h.logger.Info("获取产品ID", zap.String("product_id", productID))
if productID == "" {
h.logger.Error("产品ID不能为空")
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "产品ID不能为空",
})
return
}
userID := c.GetString("user_id")
h.logger.Info("获取用户ID", zap.String("user_id", userID))
if userID == "" {
h.logger.Error("用户未登录")
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "用户未登录",
})
return
}
orderInfo, err := h.service.GetOrderInfo(c.Request.Context(), userID, productID)
if err != nil {
h.logger.Error("获取订单信息失败", zap.Error(err), zap.String("product_id", productID), zap.String("user_id", userID))
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
"message": "获取订单信息失败",
"error": err.Error(),
})
return
}
// 记录详细的订单信息
h.logger.Info("获取订单信息成功",
zap.String("product_id", orderInfo.ProductID),
zap.String("product_code", orderInfo.ProductCode),
zap.String("product_name", orderInfo.ProductName),
zap.Bool("is_package", orderInfo.IsPackage),
zap.Int("sub_products_count", len(orderInfo.SubProducts)),
zap.String("price", orderInfo.Price),
zap.Strings("purchased_product_codes", orderInfo.PurchasedProductCodes),
zap.Bool("can_download", orderInfo.CanDownload),
)
// 记录子产品详情
for i, subProduct := range orderInfo.SubProducts {
h.logger.Info("子产品信息",
zap.Int("index", i),
zap.String("sub_product_id", subProduct.ProductID),
zap.String("sub_product_code", subProduct.ProductCode),
zap.String("sub_product_name", subProduct.ProductName),
zap.String("price", subProduct.Price),
zap.Bool("is_purchased", subProduct.IsPurchased),
)
}
c.JSON(http.StatusOK, gin.H{
"code": 200,
"data": orderInfo,
})
}
// CreatePaymentOrder 创建支付订单
// POST /api/v1/products/:id/component-report/create-order
func (h *ComponentReportOrderHandler) CreatePaymentOrder(c *gin.Context) {
h.logger.Info("开始创建支付订单")
productID := c.Param("id")
h.logger.Info("获取产品ID", zap.String("product_id", productID))
if productID == "" {
h.logger.Error("产品ID不能为空")
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "产品ID不能为空",
})
return
}
userID := c.GetString("user_id")
h.logger.Info("获取用户ID", zap.String("user_id", userID))
if userID == "" {
h.logger.Error("用户未登录")
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "用户未登录",
})
return
}
var req product.CreatePaymentOrderRequest
if err := c.ShouldBindJSON(&req); err != nil {
h.logger.Error("请求参数错误", zap.Error(err))
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "请求参数错误",
"error": err.Error(),
})
return
}
// 记录请求参数
h.logger.Info("支付订单请求参数",
zap.String("user_id", userID),
zap.String("product_id", productID),
zap.String("payment_type", req.PaymentType),
zap.String("platform", req.Platform),
zap.Strings("sub_product_codes", req.SubProductCodes),
)
// 设置用户ID和产品ID
req.UserID = userID
req.ProductID = productID
// 如果未指定支付平台根据User-Agent判断
if req.Platform == "" {
userAgent := c.GetHeader("User-Agent")
req.Platform = h.detectPlatform(userAgent)
h.logger.Info("根据User-Agent检测平台", zap.String("user_agent", userAgent), zap.String("detected_platform", req.Platform))
}
response, err := h.service.CreatePaymentOrder(c.Request.Context(), &req)
if err != nil {
h.logger.Error("创建支付订单失败", zap.Error(err),
zap.String("product_id", productID),
zap.String("user_id", userID),
zap.String("payment_type", req.PaymentType))
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
"message": "创建支付订单失败",
"error": err.Error(),
})
return
}
// 记录创建订单成功响应
h.logger.Info("创建支付订单成功",
zap.String("order_id", response.OrderID),
zap.String("order_no", response.OrderNo),
zap.String("payment_type", response.PaymentType),
zap.String("amount", response.Amount),
zap.String("code_url", response.CodeURL),
zap.String("pay_url", response.PayURL),
)
// 开发环境下,自动将订单状态设置为已支付
if h.config != nil && h.config.App.IsDevelopment() {
h.logger.Info("开发环境:自动设置订单为已支付状态",
zap.String("order_id", response.OrderID),
zap.String("order_no", response.OrderNo))
// 获取订单信息
purchaseOrder, err := h.purchaseOrderRepo.GetByID(c.Request.Context(), response.OrderID)
if err != nil {
h.logger.Error("开发环境:获取订单信息失败", zap.Error(err), zap.String("order_id", response.OrderID))
} else {
// 解析金额
amount, err := decimal.NewFromString(response.Amount)
if err != nil {
h.logger.Error("开发环境:解析订单金额失败", zap.Error(err), zap.String("amount", response.Amount))
} else {
// 标记为已支付(使用开发环境的模拟交易号)
tradeNo := "DEV_" + response.OrderNo
purchaseOrder.MarkPaid(tradeNo, "", "", amount, amount)
// 更新订单状态
err = h.purchaseOrderRepo.Update(c.Request.Context(), purchaseOrder)
if err != nil {
h.logger.Error("开发环境:更新订单状态失败", zap.Error(err), zap.String("order_id", response.OrderID))
} else {
h.logger.Info("开发环境:订单状态已自动设置为已支付",
zap.String("order_id", response.OrderID),
zap.String("order_no", response.OrderNo),
zap.String("trade_no", tradeNo))
}
}
}
}
c.JSON(http.StatusOK, gin.H{
"code": 200,
"data": response,
})
}
// CheckPaymentStatus 检查支付状态
// GET /api/v1/component-report/check-payment/:orderId
func (h *ComponentReportOrderHandler) CheckPaymentStatus(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "用户未登录",
})
return
}
orderID := c.Param("orderId")
if orderID == "" {
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "订单ID不能为空",
})
return
}
response, err := h.service.CheckPaymentStatus(c.Request.Context(), orderID)
if err != nil {
h.logger.Error("检查支付状态失败", zap.Error(err), zap.String("order_id", orderID))
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
"message": "检查支付状态失败",
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"data": response,
"message": "查询支付状态成功",
})
}
// DownloadFile 下载文件
// GET /api/v1/component-report/download/:orderId
func (h *ComponentReportOrderHandler) DownloadFile(c *gin.Context) {
h.logger.Info("开始处理文件下载请求")
userID := c.GetString("user_id")
if userID == "" {
h.logger.Error("用户未登录")
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "用户未登录",
})
return
}
h.logger.Info("获取用户ID", zap.String("user_id", userID))
orderID := c.Param("orderId")
if orderID == "" {
h.logger.Error("订单ID不能为空")
c.JSON(http.StatusBadRequest, gin.H{
"code": 400,
"message": "订单ID不能为空",
})
return
}
h.logger.Info("获取订单ID", zap.String("order_id", orderID))
filePath, err := h.service.DownloadFile(c.Request.Context(), orderID)
if err != nil {
h.logger.Error("下载文件失败", zap.Error(err), zap.String("order_id", orderID), zap.String("user_id", userID))
// 根据错误类型返回不同的状态码和消息
errorMessage := err.Error()
statusCode := http.StatusInternalServerError
// 根据错误消息判断具体错误类型
if strings.Contains(errorMessage, "购买订单不存在") {
statusCode = http.StatusNotFound
} else if strings.Contains(errorMessage, "订单未支付") || strings.Contains(errorMessage, "已过期") {
statusCode = http.StatusForbidden
} else if strings.Contains(errorMessage, "生成报告文件失败") {
statusCode = http.StatusInternalServerError
}
c.JSON(statusCode, gin.H{
"code": statusCode,
"message": "下载文件失败",
"error": errorMessage,
})
return
}
h.logger.Info("成功获取文件路径",
zap.String("order_id", orderID),
zap.String("user_id", userID),
zap.String("file_path", filePath))
// 设置响应头
c.Header("Content-Type", "application/zip")
c.Header("Content-Disposition", "attachment; filename=component_report.zip")
// 发送文件
h.logger.Info("开始发送文件", zap.String("file_path", filePath))
c.File(filePath)
h.logger.Info("文件发送成功", zap.String("file_path", filePath))
}
// GetUserOrders 获取用户订单列表
// GET /api/v1/component-report/orders
func (h *ComponentReportOrderHandler) GetUserOrders(c *gin.Context) {
userID := c.GetString("user_id")
if userID == "" {
c.JSON(http.StatusUnauthorized, gin.H{
"code": 401,
"message": "用户未登录",
})
return
}
// 解析分页参数
page, _ := strconv.Atoi(c.DefaultQuery("page", "1"))
pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10"))
if page < 1 {
page = 1
}
if pageSize < 1 || pageSize > 100 {
pageSize = 10
}
offset := (page - 1) * pageSize
orders, total, err := h.service.GetUserOrders(c.Request.Context(), userID, pageSize, offset)
if err != nil {
h.logger.Error("获取用户订单列表失败", zap.Error(err), zap.String("user_id", userID))
c.JSON(http.StatusInternalServerError, gin.H{
"code": 500,
"message": "获取用户订单列表失败",
"error": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"code": 200,
"data": gin.H{
"list": orders,
"total": total,
"page": page,
"page_size": pageSize,
},
})
}
// detectPlatform 根据 User-Agent 检测支付平台类型
func (h *ComponentReportOrderHandler) detectPlatform(userAgent string) string {
if userAgent == "" {
return "h5" // 默认 H5
}
ua := strings.ToLower(userAgent)
// 检测移动设备
if strings.Contains(ua, "mobile") || strings.Contains(ua, "android") ||
strings.Contains(ua, "iphone") || strings.Contains(ua, "ipad") {
// 检测是否是支付宝或微信内置浏览器
if strings.Contains(ua, "alipay") {
return "app" // 支付宝 APP
}
if strings.Contains(ua, "micromessenger") {
return "h5" // 微信 H5
}
return "h5" // 移动端默认 H5
}
// PC 端
return "pc"
}

View File

@@ -0,0 +1,92 @@
package handlers
import (
"strings"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
"hyapi-server/internal/application/product"
"hyapi-server/internal/shared/interfaces"
)
// FileDownloadHandler 文件下载处理器
type FileDownloadHandler struct {
uiComponentAppService product.UIComponentApplicationService
responseBuilder interfaces.ResponseBuilder
logger *zap.Logger
}
// NewFileDownloadHandler 创建文件下载处理器
func NewFileDownloadHandler(
uiComponentAppService product.UIComponentApplicationService,
responseBuilder interfaces.ResponseBuilder,
logger *zap.Logger,
) *FileDownloadHandler {
return &FileDownloadHandler{
uiComponentAppService: uiComponentAppService,
responseBuilder: responseBuilder,
logger: logger,
}
}
// DownloadUIComponentFile 下载UI组件文件
// @Summary 下载UI组件文件
// @Description 下载UI组件文件
// @Tags 文件下载
// @Accept json
// @Produce application/octet-stream
// @Param id path string true "UI组件ID"
// @Success 200 {file} file "文件内容"
// @Failure 400 {object} interfaces.Response "请求参数错误"
// @Failure 404 {object} interfaces.Response "UI组件不存在或文件不存在"
// @Failure 500 {object} interfaces.Response "服务器内部错误"
// @Router /api/v1/ui-components/{id}/download [get]
func (h *FileDownloadHandler) DownloadUIComponentFile(c *gin.Context) {
id := c.Param("id")
if id == "" {
h.responseBuilder.BadRequest(c, "UI组件ID不能为空")
return
}
// 获取UI组件信息
component, err := h.uiComponentAppService.GetUIComponentByID(c.Request.Context(), id)
if err != nil {
h.logger.Error("获取UI组件失败", zap.Error(err), zap.String("id", id))
h.responseBuilder.InternalError(c, "获取UI组件失败")
return
}
if component == nil {
h.responseBuilder.NotFound(c, "UI组件不存在")
return
}
if component.FilePath == nil {
h.responseBuilder.NotFound(c, "UI组件文件不存在")
return
}
// 获取文件路径
filePath, err := h.uiComponentAppService.DownloadUIComponentFile(c.Request.Context(), id)
if err != nil {
h.logger.Error("获取UI组件文件路径失败", zap.Error(err), zap.String("id", id))
h.responseBuilder.InternalError(c, "获取UI组件文件路径失败")
return
}
// 设置下载文件名
fileName := component.ComponentName
if !strings.HasSuffix(strings.ToLower(fileName), ".zip") {
fileName += ".zip"
}
// 设置响应头
c.Header("Content-Description", "File Transfer")
c.Header("Content-Transfer-Encoding", "binary")
c.Header("Content-Disposition", "attachment; filename="+fileName)
c.Header("Content-Type", "application/octet-stream")
// 发送文件
c.File(filePath)
}

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More