基础架构
This commit is contained in:
284
internal/infrastructure/cache/redis_cache.go
vendored
Normal file
284
internal/infrastructure/cache/redis_cache.go
vendored
Normal file
@@ -0,0 +1,284 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// RedisCache Redis缓存实现
|
||||
type RedisCache struct {
|
||||
client *redis.Client
|
||||
logger *zap.Logger
|
||||
prefix string
|
||||
|
||||
// 统计信息
|
||||
hits int64
|
||||
misses int64
|
||||
}
|
||||
|
||||
// NewRedisCache 创建Redis缓存实例
|
||||
func NewRedisCache(client *redis.Client, logger *zap.Logger, prefix string) *RedisCache {
|
||||
return &RedisCache{
|
||||
client: client,
|
||||
logger: logger,
|
||||
prefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
// Name 返回服务名称
|
||||
func (r *RedisCache) Name() string {
|
||||
return "redis-cache"
|
||||
}
|
||||
|
||||
// Initialize 初始化服务
|
||||
func (r *RedisCache) Initialize(ctx context.Context) error {
|
||||
// 测试连接
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to connect to Redis", zap.Error(err))
|
||||
return fmt.Errorf("redis connection failed: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("Redis cache service initialized")
|
||||
return nil
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (r *RedisCache) HealthCheck(ctx context.Context) error {
|
||||
_, err := r.client.Ping(ctx).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Shutdown 关闭服务
|
||||
func (r *RedisCache) Shutdown(ctx context.Context) error {
|
||||
return r.client.Close()
|
||||
}
|
||||
|
||||
// Get 获取缓存值
|
||||
func (r *RedisCache) Get(ctx context.Context, key string, dest interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
val, err := r.client.Get(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
r.misses++
|
||||
return fmt.Errorf("cache miss: key %s not found", key)
|
||||
}
|
||||
r.logger.Error("Failed to get cache", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.hits++
|
||||
return json.Unmarshal([]byte(val), dest)
|
||||
}
|
||||
|
||||
// Set 设置缓存值
|
||||
func (r *RedisCache) Set(ctx context.Context, key string, value interface{}, ttl ...interface{}) error {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal value: %w", err)
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
case string:
|
||||
expiration, _ = time.ParseDuration(v)
|
||||
default:
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour // 默认24小时
|
||||
}
|
||||
|
||||
err = r.client.Set(ctx, fullKey, data, expiration).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to set cache", zap.String("key", key), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除缓存
|
||||
func (r *RedisCache) Delete(ctx context.Context, keys ...string) error {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
err := r.client.Del(ctx, fullKeys...).Err()
|
||||
if err != nil {
|
||||
r.logger.Error("Failed to delete cache", zap.Strings("keys", keys), zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Exists 检查键是否存在
|
||||
func (r *RedisCache) Exists(ctx context.Context, key string) (bool, error) {
|
||||
fullKey := r.getFullKey(key)
|
||||
|
||||
count, err := r.client.Exists(ctx, fullKey).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetMultiple 批量获取
|
||||
func (r *RedisCache) GetMultiple(ctx context.Context, keys []string) (map[string]interface{}, error) {
|
||||
if len(keys) == 0 {
|
||||
return make(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
fullKeys := make([]string, len(keys))
|
||||
for i, key := range keys {
|
||||
fullKeys[i] = r.getFullKey(key)
|
||||
}
|
||||
|
||||
values, err := r.client.MGet(ctx, fullKeys...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make(map[string]interface{})
|
||||
for i, val := range values {
|
||||
if val != nil {
|
||||
var data interface{}
|
||||
if err := json.Unmarshal([]byte(val.(string)), &data); err == nil {
|
||||
result[keys[i]] = data
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SetMultiple 批量设置
|
||||
func (r *RedisCache) SetMultiple(ctx context.Context, data map[string]interface{}, ttl ...interface{}) error {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var expiration time.Duration
|
||||
if len(ttl) > 0 {
|
||||
switch v := ttl[0].(type) {
|
||||
case time.Duration:
|
||||
expiration = v
|
||||
case int:
|
||||
expiration = time.Duration(v) * time.Second
|
||||
default:
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
} else {
|
||||
expiration = 24 * time.Hour
|
||||
}
|
||||
|
||||
pipe := r.client.Pipeline()
|
||||
for key, value := range data {
|
||||
fullKey := r.getFullKey(key)
|
||||
jsonData, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
pipe.Set(ctx, fullKey, jsonData, expiration)
|
||||
}
|
||||
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeletePattern 按模式删除
|
||||
func (r *RedisCache) DeletePattern(ctx context.Context, pattern string) error {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(keys) > 0 {
|
||||
return r.client.Del(ctx, keys...).Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Keys 获取匹配的键
|
||||
func (r *RedisCache) Keys(ctx context.Context, pattern string) ([]string, error) {
|
||||
fullPattern := r.getFullKey(pattern)
|
||||
|
||||
keys, err := r.client.Keys(ctx, fullPattern).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 移除前缀
|
||||
result := make([]string, len(keys))
|
||||
prefixLen := len(r.prefix) + 1 // +1 for ":"
|
||||
for i, key := range keys {
|
||||
if len(key) > prefixLen {
|
||||
result[i] = key[prefixLen:]
|
||||
} else {
|
||||
result[i] = key
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Stats 获取缓存统计
|
||||
func (r *RedisCache) Stats(ctx context.Context) (interfaces.CacheStats, error) {
|
||||
dbSize, _ := r.client.DBSize(ctx).Result()
|
||||
|
||||
return interfaces.CacheStats{
|
||||
Hits: r.hits,
|
||||
Misses: r.misses,
|
||||
Keys: dbSize,
|
||||
Memory: 0, // 暂时设为0,后续可解析Redis info
|
||||
Connections: 0, // 暂时设为0,后续可解析Redis info
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getFullKey 获取完整键名
|
||||
func (r *RedisCache) getFullKey(key string) string {
|
||||
if r.prefix == "" {
|
||||
return key
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", r.prefix, key)
|
||||
}
|
||||
|
||||
// Flush 清空所有缓存
|
||||
func (r *RedisCache) Flush(ctx context.Context) error {
|
||||
if r.prefix == "" {
|
||||
return r.client.FlushDB(ctx).Err()
|
||||
}
|
||||
|
||||
// 只删除带前缀的键
|
||||
return r.DeletePattern(ctx, "*")
|
||||
}
|
||||
|
||||
// GetClient 获取原始Redis客户端
|
||||
func (r *RedisCache) GetClient() *redis.Client {
|
||||
return r.client
|
||||
}
|
||||
199
internal/infrastructure/database/database.go
Normal file
199
internal/infrastructure/database/database.go
Normal file
@@ -0,0 +1,199 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Config 数据库配置
|
||||
type Config struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
Name string
|
||||
SSLMode string
|
||||
Timezone string
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
}
|
||||
|
||||
// DB 数据库包装器
|
||||
type DB struct {
|
||||
*gorm.DB
|
||||
config Config
|
||||
}
|
||||
|
||||
// NewConnection 创建新的数据库连接
|
||||
func NewConnection(config Config) (*DB, error) {
|
||||
// 构建DSN
|
||||
dsn := buildDSN(config)
|
||||
|
||||
// 配置GORM
|
||||
gormConfig := &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Info),
|
||||
NamingStrategy: schema.NamingStrategy{
|
||||
SingularTable: true, // 使用单数表名
|
||||
},
|
||||
DisableForeignKeyConstraintWhenMigrating: true,
|
||||
NowFunc: func() time.Time {
|
||||
return time.Now().In(time.FixedZone("CST", 8*3600)) // 强制使用北京时间
|
||||
},
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
db, err := gorm.Open(postgres.Open(dsn), gormConfig)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("连接数据库失败: %w", err)
|
||||
}
|
||||
|
||||
// 获取底层sql.DB
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取数据库实例失败: %w", err)
|
||||
}
|
||||
|
||||
// 配置连接池
|
||||
sqlDB.SetMaxOpenConns(config.MaxOpenConns)
|
||||
sqlDB.SetMaxIdleConns(config.MaxIdleConns)
|
||||
sqlDB.SetConnMaxLifetime(config.ConnMaxLifetime)
|
||||
|
||||
// 测试连接
|
||||
if err := sqlDB.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("数据库连接测试失败: %w", err)
|
||||
}
|
||||
|
||||
return &DB{
|
||||
DB: db,
|
||||
config: config,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildDSN 构建数据库连接字符串
|
||||
func buildDSN(config Config) string {
|
||||
return fmt.Sprintf(
|
||||
"host=%s user=%s password=%s dbname=%s port=%s sslmode=%s TimeZone=%s options='-c timezone=%s'",
|
||||
config.Host,
|
||||
config.User,
|
||||
config.Password,
|
||||
config.Name,
|
||||
config.Port,
|
||||
config.SSLMode,
|
||||
config.Timezone,
|
||||
config.Timezone,
|
||||
)
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (db *DB) Close() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Close()
|
||||
}
|
||||
|
||||
// Ping 检查数据库连接
|
||||
func (db *DB) Ping() error {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sqlDB.Ping()
|
||||
}
|
||||
|
||||
// GetStats 获取连接池统计信息
|
||||
func (db *DB) GetStats() (map[string]interface{}, error) {
|
||||
sqlDB, err := db.DB.DB()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats := sqlDB.Stats()
|
||||
return map[string]interface{}{
|
||||
"max_open_connections": stats.MaxOpenConnections,
|
||||
"open_connections": stats.OpenConnections,
|
||||
"in_use": stats.InUse,
|
||||
"idle": stats.Idle,
|
||||
"wait_count": stats.WaitCount,
|
||||
"wait_duration": stats.WaitDuration,
|
||||
"max_idle_closed": stats.MaxIdleClosed,
|
||||
"max_idle_time_closed": stats.MaxIdleTimeClosed,
|
||||
"max_lifetime_closed": stats.MaxLifetimeClosed,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// BeginTx 开始事务
|
||||
func (db *DB) BeginTx() *gorm.DB {
|
||||
return db.DB.Begin()
|
||||
}
|
||||
|
||||
// Migrate 执行数据库迁移
|
||||
func (db *DB) Migrate(models ...interface{}) error {
|
||||
return db.DB.AutoMigrate(models...)
|
||||
}
|
||||
|
||||
// IsHealthy 检查数据库健康状态
|
||||
func (db *DB) IsHealthy() bool {
|
||||
return db.Ping() == nil
|
||||
}
|
||||
|
||||
// WithContext 返回带上下文的数据库实例
|
||||
func (db *DB) WithContext(ctx interface{}) *gorm.DB {
|
||||
if c, ok := ctx.(context.Context); ok {
|
||||
return db.DB.WithContext(c)
|
||||
}
|
||||
return db.DB
|
||||
}
|
||||
|
||||
// 事务包装器
|
||||
type TxWrapper struct {
|
||||
tx *gorm.DB
|
||||
}
|
||||
|
||||
// NewTxWrapper 创建事务包装器
|
||||
func (db *DB) NewTxWrapper() *TxWrapper {
|
||||
return &TxWrapper{
|
||||
tx: db.BeginTx(),
|
||||
}
|
||||
}
|
||||
|
||||
// Commit 提交事务
|
||||
func (tx *TxWrapper) Commit() error {
|
||||
return tx.tx.Commit().Error
|
||||
}
|
||||
|
||||
// Rollback 回滚事务
|
||||
func (tx *TxWrapper) Rollback() error {
|
||||
return tx.tx.Rollback().Error
|
||||
}
|
||||
|
||||
// GetDB 获取事务数据库实例
|
||||
func (tx *TxWrapper) GetDB() *gorm.DB {
|
||||
return tx.tx
|
||||
}
|
||||
|
||||
// WithTx 在事务中执行函数
|
||||
func (db *DB) WithTx(fn func(*gorm.DB) error) error {
|
||||
tx := db.BeginTx()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
panic(r)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit().Error
|
||||
}
|
||||
@@ -0,0 +1,225 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/admin/entities"
|
||||
"tyapi-server/internal/domains/admin/repositories"
|
||||
"tyapi-server/internal/domains/admin/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormAdminLoginLogRepository 管理员登录日志GORM仓储实现
|
||||
type GormAdminLoginLogRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.AdminLoginLogRepository = (*GormAdminLoginLogRepository)(nil)
|
||||
|
||||
// NewGormAdminLoginLogRepository 创建管理员登录日志GORM仓储
|
||||
func NewGormAdminLoginLogRepository(db *gorm.DB, logger *zap.Logger) repositories.AdminLoginLogRepository {
|
||||
return &GormAdminLoginLogRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建登录日志
|
||||
func (r *GormAdminLoginLogRepository) Create(ctx context.Context, log entities.AdminLoginLog) (entities.AdminLoginLog, error) {
|
||||
r.logger.Info("创建管理员登录日志", zap.String("admin_id", log.AdminID))
|
||||
err := r.db.WithContext(ctx).Create(&log).Error
|
||||
return log, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取登录日志
|
||||
func (r *GormAdminLoginLogRepository) GetByID(ctx context.Context, id string) (entities.AdminLoginLog, error) {
|
||||
var log entities.AdminLoginLog
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&log).Error
|
||||
return log, err
|
||||
}
|
||||
|
||||
// Update 更新登录日志
|
||||
func (r *GormAdminLoginLogRepository) Update(ctx context.Context, log entities.AdminLoginLog) error {
|
||||
r.logger.Info("更新管理员登录日志", zap.String("id", log.ID))
|
||||
return r.db.WithContext(ctx).Save(&log).Error
|
||||
}
|
||||
|
||||
// Delete 删除登录日志
|
||||
func (r *GormAdminLoginLogRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除管理员登录日志", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminLoginLog{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除登录日志
|
||||
func (r *GormAdminLoginLogRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除管理员登录日志", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminLoginLog{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复登录日志
|
||||
func (r *GormAdminLoginLogRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复管理员登录日志", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.AdminLoginLog{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计登录日志数量
|
||||
func (r *GormAdminLoginLogRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("admin_id LIKE ? OR ip_address LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查登录日志是否存在
|
||||
func (r *GormAdminLoginLogRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建登录日志
|
||||
func (r *GormAdminLoginLogRepository) CreateBatch(ctx context.Context, logs []entities.AdminLoginLog) error {
|
||||
r.logger.Info("批量创建管理员登录日志", zap.Int("count", len(logs)))
|
||||
return r.db.WithContext(ctx).Create(&logs).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取登录日志
|
||||
func (r *GormAdminLoginLogRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.AdminLoginLog, error) {
|
||||
var logs []entities.AdminLoginLog
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新登录日志
|
||||
func (r *GormAdminLoginLogRepository) UpdateBatch(ctx context.Context, logs []entities.AdminLoginLog) error {
|
||||
r.logger.Info("批量更新管理员登录日志", zap.Int("count", len(logs)))
|
||||
return r.db.WithContext(ctx).Save(&logs).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除登录日志
|
||||
func (r *GormAdminLoginLogRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除管理员登录日志", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminLoginLog{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取登录日志列表
|
||||
func (r *GormAdminLoginLogRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.AdminLoginLog, error) {
|
||||
var logs []entities.AdminLoginLog
|
||||
query := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("admin_id LIKE ? OR ip_address LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return logs, query.Find(&logs).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormAdminLoginLogRepository) WithTx(tx interface{}) interfaces.Repository[entities.AdminLoginLog] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormAdminLoginLogRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// ListLogs 获取登录日志列表(带分页和筛选)
|
||||
func (r *GormAdminLoginLogRepository) ListLogs(ctx context.Context, query *queries.ListAdminLoginLogQuery) ([]*entities.AdminLoginLog, int64, error) {
|
||||
var logs []entities.AdminLoginLog
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.AdminID != "" {
|
||||
dbQuery = dbQuery.Where("admin_id = ?", query.AdminID)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
logPtrs := make([]*entities.AdminLoginLog, len(logs))
|
||||
for i := range logs {
|
||||
logPtrs[i] = &logs[i]
|
||||
}
|
||||
|
||||
return logPtrs, total, nil
|
||||
}
|
||||
|
||||
// GetTodayLoginCount 获取今日登录次数
|
||||
func (r *GormAdminLoginLogRepository) GetTodayLoginCount(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("created_at >= ?", today).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetLoginCountByAdmin 获取指定管理员在指定天数内的登录次数
|
||||
func (r *GormAdminLoginLogRepository) GetLoginCountByAdmin(ctx context.Context, adminID string, days int) (int64, error) {
|
||||
var count int64
|
||||
startDate := time.Now().AddDate(0, 0, -days)
|
||||
err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("admin_id = ? AND created_at >= ?", adminID, startDate).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
@@ -0,0 +1,236 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/admin/entities"
|
||||
"tyapi-server/internal/domains/admin/repositories"
|
||||
"tyapi-server/internal/domains/admin/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormAdminOperationLogRepository 管理员操作日志GORM仓储实现
|
||||
type GormAdminOperationLogRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.AdminOperationLogRepository = (*GormAdminOperationLogRepository)(nil)
|
||||
|
||||
// NewGormAdminOperationLogRepository 创建管理员操作日志GORM仓储
|
||||
func NewGormAdminOperationLogRepository(db *gorm.DB, logger *zap.Logger) repositories.AdminOperationLogRepository {
|
||||
return &GormAdminOperationLogRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建操作日志
|
||||
func (r *GormAdminOperationLogRepository) Create(ctx context.Context, log entities.AdminOperationLog) (entities.AdminOperationLog, error) {
|
||||
r.logger.Info("创建管理员操作日志", zap.String("admin_id", log.AdminID), zap.String("action", log.Action))
|
||||
err := r.db.WithContext(ctx).Create(&log).Error
|
||||
return log, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取操作日志
|
||||
func (r *GormAdminOperationLogRepository) GetByID(ctx context.Context, id string) (entities.AdminOperationLog, error) {
|
||||
var log entities.AdminOperationLog
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&log).Error
|
||||
return log, err
|
||||
}
|
||||
|
||||
// Update 更新操作日志
|
||||
func (r *GormAdminOperationLogRepository) Update(ctx context.Context, log entities.AdminOperationLog) error {
|
||||
r.logger.Info("更新管理员操作日志", zap.String("id", log.ID))
|
||||
return r.db.WithContext(ctx).Save(&log).Error
|
||||
}
|
||||
|
||||
// Delete 删除操作日志
|
||||
func (r *GormAdminOperationLogRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除管理员操作日志", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminOperationLog{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除操作日志
|
||||
func (r *GormAdminOperationLogRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除管理员操作日志", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminOperationLog{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复操作日志
|
||||
func (r *GormAdminOperationLogRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复管理员操作日志", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.AdminOperationLog{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计操作日志数量
|
||||
func (r *GormAdminOperationLogRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("admin_id LIKE ? OR action LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查操作日志是否存在
|
||||
func (r *GormAdminOperationLogRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建操作日志
|
||||
func (r *GormAdminOperationLogRepository) CreateBatch(ctx context.Context, logs []entities.AdminOperationLog) error {
|
||||
r.logger.Info("批量创建管理员操作日志", zap.Int("count", len(logs)))
|
||||
return r.db.WithContext(ctx).Create(&logs).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取操作日志
|
||||
func (r *GormAdminOperationLogRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.AdminOperationLog, error) {
|
||||
var logs []entities.AdminOperationLog
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&logs).Error
|
||||
return logs, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新操作日志
|
||||
func (r *GormAdminOperationLogRepository) UpdateBatch(ctx context.Context, logs []entities.AdminOperationLog) error {
|
||||
r.logger.Info("批量更新管理员操作日志", zap.Int("count", len(logs)))
|
||||
return r.db.WithContext(ctx).Save(&logs).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除操作日志
|
||||
func (r *GormAdminOperationLogRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除管理员操作日志", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminOperationLog{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取操作日志列表
|
||||
func (r *GormAdminOperationLogRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.AdminOperationLog, error) {
|
||||
var logs []entities.AdminOperationLog
|
||||
query := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("admin_id LIKE ? OR action LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return logs, query.Find(&logs).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormAdminOperationLogRepository) WithTx(tx interface{}) interfaces.Repository[entities.AdminOperationLog] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormAdminOperationLogRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// ListLogs 获取操作日志列表(带分页和筛选)
|
||||
func (r *GormAdminOperationLogRepository) ListLogs(ctx context.Context, query *queries.ListAdminOperationLogQuery) ([]*entities.AdminOperationLog, int64, error) {
|
||||
var logs []entities.AdminOperationLog
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.AdminID != "" {
|
||||
dbQuery = dbQuery.Where("admin_id = ?", query.AdminID)
|
||||
}
|
||||
if query.Module != "" {
|
||||
dbQuery = dbQuery.Where("module = ?", query.Module)
|
||||
}
|
||||
if query.Action != "" {
|
||||
dbQuery = dbQuery.Where("action = ?", query.Action)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
logPtrs := make([]*entities.AdminOperationLog, len(logs))
|
||||
for i := range logs {
|
||||
logPtrs[i] = &logs[i]
|
||||
}
|
||||
|
||||
return logPtrs, total, nil
|
||||
}
|
||||
|
||||
// GetTotalOperations 获取总操作数
|
||||
func (r *GormAdminOperationLogRepository) GetTotalOperations(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// GetOperationsByAdmin 获取指定管理员在指定天数内的操作数
|
||||
func (r *GormAdminOperationLogRepository) GetOperationsByAdmin(ctx context.Context, adminID string, days int) (int64, error) {
|
||||
var count int64
|
||||
startDate := time.Now().AddDate(0, 0, -days)
|
||||
err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Where("admin_id = ? AND created_at >= ?", adminID, startDate).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建操作日志
|
||||
func (r *GormAdminOperationLogRepository) BatchCreate(ctx context.Context, logs []entities.AdminOperationLog) error {
|
||||
r.logger.Info("批量创建管理员操作日志", zap.Int("count", len(logs)))
|
||||
return r.db.WithContext(ctx).Create(&logs).Error
|
||||
}
|
||||
@@ -0,0 +1,222 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/admin/entities"
|
||||
"tyapi-server/internal/domains/admin/repositories"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormAdminPermissionRepository 管理员权限GORM仓储实现
|
||||
type GormAdminPermissionRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.AdminPermissionRepository = (*GormAdminPermissionRepository)(nil)
|
||||
|
||||
// NewGormAdminPermissionRepository 创建管理员权限GORM仓储
|
||||
func NewGormAdminPermissionRepository(db *gorm.DB, logger *zap.Logger) repositories.AdminPermissionRepository {
|
||||
return &GormAdminPermissionRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建权限
|
||||
func (r *GormAdminPermissionRepository) Create(ctx context.Context, permission entities.AdminPermission) (entities.AdminPermission, error) {
|
||||
r.logger.Info("创建管理员权限", zap.String("code", permission.Code))
|
||||
err := r.db.WithContext(ctx).Create(&permission).Error
|
||||
return permission, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取权限
|
||||
func (r *GormAdminPermissionRepository) GetByID(ctx context.Context, id string) (entities.AdminPermission, error) {
|
||||
var permission entities.AdminPermission
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&permission).Error
|
||||
return permission, err
|
||||
}
|
||||
|
||||
// Update 更新权限
|
||||
func (r *GormAdminPermissionRepository) Update(ctx context.Context, permission entities.AdminPermission) error {
|
||||
r.logger.Info("更新管理员权限", zap.String("id", permission.ID))
|
||||
return r.db.WithContext(ctx).Save(&permission).Error
|
||||
}
|
||||
|
||||
// Delete 删除权限
|
||||
func (r *GormAdminPermissionRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除管理员权限", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminPermission{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除权限
|
||||
func (r *GormAdminPermissionRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除管理员权限", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminPermission{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复权限
|
||||
func (r *GormAdminPermissionRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复管理员权限", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.AdminPermission{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计权限数量
|
||||
func (r *GormAdminPermissionRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.AdminPermission{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("code LIKE ? OR name LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查权限是否存在
|
||||
func (r *GormAdminPermissionRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.AdminPermission{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建权限
|
||||
func (r *GormAdminPermissionRepository) CreateBatch(ctx context.Context, permissions []entities.AdminPermission) error {
|
||||
r.logger.Info("批量创建管理员权限", zap.Int("count", len(permissions)))
|
||||
return r.db.WithContext(ctx).Create(&permissions).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取权限
|
||||
func (r *GormAdminPermissionRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.AdminPermission, error) {
|
||||
var permissions []entities.AdminPermission
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&permissions).Error
|
||||
return permissions, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新权限
|
||||
func (r *GormAdminPermissionRepository) UpdateBatch(ctx context.Context, permissions []entities.AdminPermission) error {
|
||||
r.logger.Info("批量更新管理员权限", zap.Int("count", len(permissions)))
|
||||
return r.db.WithContext(ctx).Save(&permissions).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除权限
|
||||
func (r *GormAdminPermissionRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除管理员权限", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.AdminPermission{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取权限列表
|
||||
func (r *GormAdminPermissionRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.AdminPermission, error) {
|
||||
var permissions []entities.AdminPermission
|
||||
query := r.db.WithContext(ctx).Model(&entities.AdminPermission{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("code LIKE ? OR name LIKE ? OR module LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return permissions, query.Find(&permissions).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormAdminPermissionRepository) WithTx(tx interface{}) interfaces.Repository[entities.AdminPermission] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormAdminPermissionRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// FindByCode 根据权限代码查找权限
|
||||
func (r *GormAdminPermissionRepository) FindByCode(ctx context.Context, code string) (*entities.AdminPermission, error) {
|
||||
var permission entities.AdminPermission
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&permission).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &permission, nil
|
||||
}
|
||||
|
||||
// FindByModule 根据模块查找权限
|
||||
func (r *GormAdminPermissionRepository) FindByModule(ctx context.Context, module string) ([]entities.AdminPermission, error) {
|
||||
var permissions []entities.AdminPermission
|
||||
err := r.db.WithContext(ctx).Where("module = ?", module).Find(&permissions).Error
|
||||
return permissions, err
|
||||
}
|
||||
|
||||
// ListActive 获取所有激活的权限
|
||||
func (r *GormAdminPermissionRepository) ListActive(ctx context.Context) ([]entities.AdminPermission, error) {
|
||||
var permissions []entities.AdminPermission
|
||||
err := r.db.WithContext(ctx).Where("is_active = ?", true).Find(&permissions).Error
|
||||
return permissions, err
|
||||
}
|
||||
|
||||
// GetPermissionsByRole 根据角色获取权限
|
||||
func (r *GormAdminPermissionRepository) GetPermissionsByRole(ctx context.Context, role entities.AdminRole) ([]entities.AdminPermission, error) {
|
||||
var permissions []entities.AdminPermission
|
||||
|
||||
query := r.db.WithContext(ctx).
|
||||
Joins("JOIN admin_role_permissions ON admin_permissions.id = admin_role_permissions.permission_id").
|
||||
Where("admin_role_permissions.role = ? AND admin_permissions.is_active = ?", role, true)
|
||||
|
||||
return permissions, query.Find(&permissions).Error
|
||||
}
|
||||
|
||||
// AssignPermissionsToRole 为角色分配权限
|
||||
func (r *GormAdminPermissionRepository) AssignPermissionsToRole(ctx context.Context, role entities.AdminRole, permissionIDs []string) error {
|
||||
// 先删除现有权限
|
||||
if err := r.db.WithContext(ctx).Where("role = ?", role).Delete(&entities.AdminRolePermission{}).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 批量插入新权限
|
||||
var rolePermissions []entities.AdminRolePermission
|
||||
for _, permissionID := range permissionIDs {
|
||||
rolePermissions = append(rolePermissions, entities.AdminRolePermission{
|
||||
Role: role,
|
||||
PermissionID: permissionID,
|
||||
})
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).Create(&rolePermissions).Error
|
||||
}
|
||||
|
||||
// RemovePermissionsFromRole 从角色移除权限
|
||||
func (r *GormAdminPermissionRepository) RemovePermissionsFromRole(ctx context.Context, role entities.AdminRole, permissionIDs []string) error {
|
||||
return r.db.WithContext(ctx).Where("role = ? AND permission_id IN ?", role, permissionIDs).Delete(&entities.AdminRolePermission{}).Error
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/admin/entities"
|
||||
"tyapi-server/internal/domains/admin/repositories"
|
||||
"tyapi-server/internal/domains/admin/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormAdminRepository 管理员GORM仓储实现
|
||||
type GormAdminRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.AdminRepository = (*GormAdminRepository)(nil)
|
||||
|
||||
// NewGormAdminRepository 创建管理员GORM仓储
|
||||
func NewGormAdminRepository(db *gorm.DB, logger *zap.Logger) *GormAdminRepository {
|
||||
return &GormAdminRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建管理员
|
||||
func (r *GormAdminRepository) Create(ctx context.Context, admin entities.Admin) (entities.Admin, error) {
|
||||
r.logger.Info("创建管理员", zap.String("username", admin.Username))
|
||||
err := r.db.WithContext(ctx).Create(&admin).Error
|
||||
return admin, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取管理员
|
||||
func (r *GormAdminRepository) GetByID(ctx context.Context, id string) (entities.Admin, error) {
|
||||
var admin entities.Admin
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&admin).Error
|
||||
return admin, err
|
||||
}
|
||||
|
||||
// Update 更新管理员
|
||||
func (r *GormAdminRepository) Update(ctx context.Context, admin entities.Admin) error {
|
||||
r.logger.Info("更新管理员", zap.String("id", admin.ID))
|
||||
return r.db.WithContext(ctx).Save(&admin).Error
|
||||
}
|
||||
|
||||
// Delete 删除管理员
|
||||
func (r *GormAdminRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除管理员", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Admin{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除管理员
|
||||
func (r *GormAdminRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除管理员", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Admin{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复管理员
|
||||
func (r *GormAdminRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复管理员", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Admin{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计管理员数量
|
||||
func (r *GormAdminRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.Admin{})
|
||||
|
||||
// 应用过滤条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("username LIKE ? OR email LIKE ? OR real_name LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查管理员是否存在
|
||||
func (r *GormAdminRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Admin{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建管理员
|
||||
func (r *GormAdminRepository) CreateBatch(ctx context.Context, admins []entities.Admin) error {
|
||||
r.logger.Info("批量创建管理员", zap.Int("count", len(admins)))
|
||||
return r.db.WithContext(ctx).Create(&admins).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取管理员
|
||||
func (r *GormAdminRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Admin, error) {
|
||||
var admins []entities.Admin
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&admins).Error
|
||||
return admins, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新管理员
|
||||
func (r *GormAdminRepository) UpdateBatch(ctx context.Context, admins []entities.Admin) error {
|
||||
r.logger.Info("批量更新管理员", zap.Int("count", len(admins)))
|
||||
return r.db.WithContext(ctx).Save(&admins).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除管理员
|
||||
func (r *GormAdminRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除管理员", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Admin{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取管理员列表
|
||||
func (r *GormAdminRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Admin, error) {
|
||||
var admins []entities.Admin
|
||||
query := r.db.WithContext(ctx).Model(&entities.Admin{})
|
||||
|
||||
// 应用过滤条件
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
// 应用搜索条件
|
||||
if options.Search != "" {
|
||||
query = query.Where("username LIKE ? OR email LIKE ? OR real_name LIKE ?",
|
||||
"%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
// 应用排序
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return admins, query.Find(&admins).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormAdminRepository) WithTx(tx interface{}) interfaces.Repository[entities.Admin] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormAdminRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FindByUsername 根据用户名查找管理员
|
||||
func (r *GormAdminRepository) FindByUsername(ctx context.Context, username string) (*entities.Admin, error) {
|
||||
var admin entities.Admin
|
||||
err := r.db.WithContext(ctx).Where("username = ?", username).First(&admin).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &admin, nil
|
||||
}
|
||||
|
||||
// FindByEmail 根据邮箱查找管理员
|
||||
func (r *GormAdminRepository) FindByEmail(ctx context.Context, email string) (*entities.Admin, error) {
|
||||
var admin entities.Admin
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&admin).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &admin, nil
|
||||
}
|
||||
|
||||
// ListAdmins 获取管理员列表(带分页和筛选)
|
||||
func (r *GormAdminRepository) ListAdmins(ctx context.Context, query *queries.ListAdminsQuery) ([]*entities.Admin, int64, error) {
|
||||
var admins []entities.Admin
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Admin{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Username != "" {
|
||||
dbQuery = dbQuery.Where("username LIKE ?", "%"+query.Username+"%")
|
||||
}
|
||||
if query.Email != "" {
|
||||
dbQuery = dbQuery.Where("email LIKE ?", "%"+query.Email+"%")
|
||||
}
|
||||
if query.Role != "" {
|
||||
dbQuery = dbQuery.Where("role = ?", query.Role)
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&admins).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
adminPtrs := make([]*entities.Admin, len(admins))
|
||||
for i := range admins {
|
||||
adminPtrs[i] = &admins[i]
|
||||
}
|
||||
|
||||
return adminPtrs, total, nil
|
||||
}
|
||||
|
||||
// GetStats 获取管理员统计信息
|
||||
func (r *GormAdminRepository) GetStats(ctx context.Context, query *queries.GetAdminInfoQuery) (*repositories.AdminStats, error) {
|
||||
var stats repositories.AdminStats
|
||||
|
||||
// 总管理员数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Admin{}).Count(&stats.TotalAdmins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 激活管理员数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Admin{}).Where("is_active = ?", true).Count(&stats.ActiveAdmins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.db.WithContext(ctx).Model(&entities.AdminLoginLog{}).Where("created_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总操作数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.AdminOperationLog{}).Count(&stats.TotalOperations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetPermissionsByRole 根据角色获取权限
|
||||
func (r *GormAdminRepository) GetPermissionsByRole(ctx context.Context, role entities.AdminRole) ([]entities.AdminPermission, error) {
|
||||
var permissions []entities.AdminPermission
|
||||
|
||||
query := r.db.WithContext(ctx).
|
||||
Joins("JOIN admin_role_permissions ON admin_permissions.id = admin_role_permissions.permission_id").
|
||||
Where("admin_role_permissions.role = ? AND admin_permissions.is_active = ?", role, true)
|
||||
|
||||
return permissions, query.Find(&permissions).Error
|
||||
}
|
||||
|
||||
// UpdatePermissions 更新管理员权限
|
||||
func (r *GormAdminRepository) UpdatePermissions(ctx context.Context, adminID string, permissions []string) error {
|
||||
permissionsJSON, err := json.Marshal(permissions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化权限失败: %w", err)
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.Admin{}).
|
||||
Where("id = ?", adminID).
|
||||
Update("permissions", string(permissionsJSON)).Error
|
||||
}
|
||||
|
||||
// UpdateLoginStats 更新登录统计
|
||||
func (r *GormAdminRepository) UpdateLoginStats(ctx context.Context, adminID string) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.Admin{}).
|
||||
Where("id = ?", adminID).
|
||||
Updates(map[string]interface{}{
|
||||
"last_login_at": time.Now(),
|
||||
"login_count": gorm.Expr("login_count + 1"),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdateReviewStats 更新审核统计
|
||||
func (r *GormAdminRepository) UpdateReviewStats(ctx context.Context, adminID string, approved bool) error {
|
||||
updates := map[string]interface{}{
|
||||
"review_count": gorm.Expr("review_count + 1"),
|
||||
}
|
||||
|
||||
if approved {
|
||||
updates["approved_count"] = gorm.Expr("approved_count + 1")
|
||||
} else {
|
||||
updates["rejected_count"] = gorm.Expr("rejected_count + 1")
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.Admin{}).
|
||||
Where("id = ?", adminID).
|
||||
Updates(updates).Error
|
||||
}
|
||||
@@ -0,0 +1,353 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/certification/entities"
|
||||
"tyapi-server/internal/domains/certification/repositories"
|
||||
"tyapi-server/internal/domains/certification/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormCertificationRepository GORM认证仓储实现
|
||||
type GormCertificationRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.CertificationRepository = (*GormCertificationRepository)(nil)
|
||||
|
||||
// NewGormCertificationRepository 创建GORM认证仓储
|
||||
func NewGormCertificationRepository(db *gorm.DB, logger *zap.Logger) repositories.CertificationRepository {
|
||||
return &GormCertificationRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建认证申请
|
||||
func (r *GormCertificationRepository) Create(ctx context.Context, cert entities.Certification) (entities.Certification, error) {
|
||||
r.logger.Info("创建认证申请", zap.String("user_id", cert.UserID))
|
||||
err := r.db.WithContext(ctx).Create(&cert).Error
|
||||
return cert, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取认证申请
|
||||
func (r *GormCertificationRepository) GetByID(ctx context.Context, id string) (entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&cert).Error
|
||||
return cert, err
|
||||
}
|
||||
|
||||
// Update 更新认证申请
|
||||
func (r *GormCertificationRepository) Update(ctx context.Context, cert entities.Certification) error {
|
||||
r.logger.Info("更新认证申请", zap.String("id", cert.ID))
|
||||
return r.db.WithContext(ctx).Save(&cert).Error
|
||||
}
|
||||
|
||||
// Delete 删除认证申请
|
||||
func (r *GormCertificationRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除认证申请", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Certification{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除认证申请
|
||||
func (r *GormCertificationRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除认证申请", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Certification{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复认证申请
|
||||
func (r *GormCertificationRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复认证申请", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Certification{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计认证申请数量
|
||||
func (r *GormCertificationRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.Certification{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查认证申请是否存在
|
||||
func (r *GormCertificationRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建认证申请
|
||||
func (r *GormCertificationRepository) CreateBatch(ctx context.Context, certs []entities.Certification) error {
|
||||
r.logger.Info("批量创建认证申请", zap.Int("count", len(certs)))
|
||||
return r.db.WithContext(ctx).Create(&certs).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取认证申请
|
||||
func (r *GormCertificationRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Certification, error) {
|
||||
var certs []entities.Certification
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&certs).Error
|
||||
return certs, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新认证申请
|
||||
func (r *GormCertificationRepository) UpdateBatch(ctx context.Context, certs []entities.Certification) error {
|
||||
r.logger.Info("批量更新认证申请", zap.Int("count", len(certs)))
|
||||
return r.db.WithContext(ctx).Save(&certs).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除认证申请
|
||||
func (r *GormCertificationRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除认证申请", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Certification{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取认证申请列表
|
||||
func (r *GormCertificationRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Certification, error) {
|
||||
var certs []entities.Certification
|
||||
query := r.db.WithContext(ctx).Model(&entities.Certification{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return certs, query.Find(&certs).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormCertificationRepository) WithTx(tx interface{}) interfaces.Repository[entities.Certification] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormCertificationRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// ListCertifications 获取认证申请列表(带分页和筛选)
|
||||
func (r *GormCertificationRepository) ListCertifications(ctx context.Context, query *queries.ListCertificationsQuery) ([]*entities.Certification, int64, error) {
|
||||
var certs []entities.Certification
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Certification{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
if query.AdminID != "" {
|
||||
dbQuery = dbQuery.Where("admin_id = ?", query.AdminID)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
if query.EnterpriseName != "" {
|
||||
dbQuery = dbQuery.Joins("JOIN enterprises ON certifications.enterprise_id = enterprises.id").
|
||||
Where("enterprises.enterprise_name LIKE ?", "%"+query.EnterpriseName+"%")
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&certs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
certPtrs := make([]*entities.Certification, len(certs))
|
||||
for i := range certs {
|
||||
certPtrs[i] = &certs[i]
|
||||
}
|
||||
|
||||
return certPtrs, total, nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取认证申请
|
||||
func (r *GormCertificationRepository) GetByUserID(ctx context.Context, userID string) (*entities.Certification, error) {
|
||||
var cert entities.Certification
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&cert).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &cert, nil
|
||||
}
|
||||
|
||||
// GetByStatus 根据状态获取认证申请列表
|
||||
func (r *GormCertificationRepository) GetByStatus(ctx context.Context, status string) ([]*entities.Certification, error) {
|
||||
var certs []entities.Certification
|
||||
err := r.db.WithContext(ctx).Where("status = ?", status).Find(&certs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certPtrs := make([]*entities.Certification, len(certs))
|
||||
for i := range certs {
|
||||
certPtrs[i] = &certs[i]
|
||||
}
|
||||
|
||||
return certPtrs, nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新认证状态
|
||||
func (r *GormCertificationRepository) UpdateStatus(ctx context.Context, certificationID string, status string, adminID *string, notes string) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if adminID != nil {
|
||||
updates["admin_id"] = *adminID
|
||||
}
|
||||
|
||||
if notes != "" {
|
||||
updates["approval_notes"] = notes
|
||||
}
|
||||
|
||||
// 根据状态设置相应的时间戳
|
||||
switch status {
|
||||
case "INFO_SUBMITTED":
|
||||
updates["info_submitted_at"] = time.Now()
|
||||
case "FACE_VERIFIED":
|
||||
updates["face_verified_at"] = time.Now()
|
||||
case "CONTRACT_APPLIED":
|
||||
updates["contract_applied_at"] = time.Now()
|
||||
case "CONTRACT_APPROVED":
|
||||
updates["contract_approved_at"] = time.Now()
|
||||
case "CONTRACT_SIGNED":
|
||||
updates["contract_signed_at"] = time.Now()
|
||||
case "COMPLETED":
|
||||
updates["completed_at"] = time.Now()
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&entities.Certification{}).
|
||||
Where("id = ?", certificationID).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// GetPendingCertifications 获取待审核的认证申请
|
||||
func (r *GormCertificationRepository) GetPendingCertifications(ctx context.Context) ([]*entities.Certification, error) {
|
||||
return r.GetByStatus(ctx, "CONTRACT_PENDING")
|
||||
}
|
||||
|
||||
// GetStats 获取认证统计信息
|
||||
func (r *GormCertificationRepository) GetStats(ctx context.Context) (*repositories.CertificationStats, error) {
|
||||
var stats repositories.CertificationStats
|
||||
|
||||
// 总认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Count(&stats.TotalCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 待审核认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ?", "CONTRACT_PENDING").Count(&stats.PendingCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已完成认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ?", "COMPLETED").Count(&stats.CompletedCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 被拒绝认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ?", "REJECTED").Count(&stats.RejectedCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日提交数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("created_at >= ?", today).Count(&stats.TodaySubmissions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetStatsByDateRange 根据日期范围获取认证统计信息
|
||||
func (r *GormCertificationRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.CertificationStats, error) {
|
||||
var stats repositories.CertificationStats
|
||||
|
||||
// 总认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("created_at BETWEEN ? AND ?", startDate, endDate).Count(&stats.TotalCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 待审核认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ? AND created_at BETWEEN ? AND ?", "CONTRACT_PENDING", startDate, endDate).Count(&stats.PendingCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 已完成认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ? AND created_at BETWEEN ? AND ?", "COMPLETED", startDate, endDate).Count(&stats.CompletedCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 被拒绝认证申请数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("status = ? AND created_at BETWEEN ? AND ?", "REJECTED", startDate, endDate).Count(&stats.RejectedCertifications).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日提交数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Certification{}).Where("created_at >= ?", today).Count(&stats.TodaySubmissions).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
@@ -0,0 +1,422 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/certification/entities"
|
||||
"tyapi-server/internal/domains/certification/repositories"
|
||||
"tyapi-server/internal/domains/certification/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormContractRecordRepository GORM合同记录仓储实现
|
||||
type GormContractRecordRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.ContractRecordRepository = (*GormContractRecordRepository)(nil)
|
||||
|
||||
// NewGormContractRecordRepository 创建GORM合同记录仓储
|
||||
func NewGormContractRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.ContractRecordRepository {
|
||||
return &GormContractRecordRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建合同记录
|
||||
func (r *GormContractRecordRepository) Create(ctx context.Context, record entities.ContractRecord) (entities.ContractRecord, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
|
||||
r.logger.Error("创建合同记录失败",
|
||||
zap.String("certification_id", record.CertificationID),
|
||||
zap.String("contract_type", record.ContractType),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.ContractRecord{}, fmt.Errorf("创建合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("合同记录创建成功",
|
||||
zap.String("id", record.ID),
|
||||
zap.String("contract_type", record.ContractType),
|
||||
)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取合同记录
|
||||
func (r *GormContractRecordRepository) GetByID(ctx context.Context, id string) (entities.ContractRecord, error) {
|
||||
var record entities.ContractRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entities.ContractRecord{}, fmt.Errorf("合同记录不存在")
|
||||
}
|
||||
r.logger.Error("获取合同记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.ContractRecord{}, fmt.Errorf("获取合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Update 更新合同记录
|
||||
func (r *GormContractRecordRepository) Update(ctx context.Context, record entities.ContractRecord) error {
|
||||
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
|
||||
r.logger.Error("更新合同记录失败",
|
||||
zap.String("id", record.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("更新合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除合同记录
|
||||
func (r *GormContractRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.ContractRecord{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除合同记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("删除合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除合同记录
|
||||
func (r *GormContractRecordRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// Restore 恢复合同记录
|
||||
func (r *GormContractRecordRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.ContractRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复合同记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("恢复合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("合同记录恢复成功", zap.String("id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计合同记录数量
|
||||
func (r *GormContractRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("contract_type LIKE ? OR contract_name LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查合同记录是否存在
|
||||
func (r *GormContractRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.ContractRecord{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建合同记录
|
||||
func (r *GormContractRecordRepository) CreateBatch(ctx context.Context, records []entities.ContractRecord) error {
|
||||
r.logger.Info("批量创建合同记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取合同记录
|
||||
func (r *GormContractRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.ContractRecord, error) {
|
||||
var records []entities.ContractRecord
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新合同记录
|
||||
func (r *GormContractRecordRepository) UpdateBatch(ctx context.Context, records []entities.ContractRecord) error {
|
||||
r.logger.Info("批量更新合同记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Save(&records).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除合同记录
|
||||
func (r *GormContractRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除合同记录", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.ContractRecord{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取合同记录列表
|
||||
func (r *GormContractRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.ContractRecord, error) {
|
||||
var records []entities.ContractRecord
|
||||
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("contract_type LIKE ? OR contract_name LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return records, query.Find(&records).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormContractRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.ContractRecord] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormContractRecordRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// GetByCertificationID 根据认证申请ID获取合同记录列表
|
||||
func (r *GormContractRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) ([]*entities.ContractRecord, error) {
|
||||
var records []entities.ContractRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("根据认证申请ID获取合同记录失败",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.ContractRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, nil
|
||||
}
|
||||
|
||||
// GetLatestByCertificationID 根据认证申请ID获取最新的合同记录
|
||||
func (r *GormContractRecordRepository) GetLatestByCertificationID(ctx context.Context, certificationID string) (*entities.ContractRecord, error) {
|
||||
var record entities.ContractRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").First(&record).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("合同记录不存在")
|
||||
}
|
||||
r.logger.Error("根据认证申请ID获取最新合同记录失败",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// ListRecords 获取合同记录列表(带分页和筛选)
|
||||
func (r *GormContractRecordRepository) ListRecords(ctx context.Context, query *queries.ListContractRecordsQuery) ([]*entities.ContractRecord, int64, error) {
|
||||
var records []entities.ContractRecord
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.ContractRecord{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.CertificationID != "" {
|
||||
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
|
||||
}
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&records).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.ContractRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, total, nil
|
||||
}
|
||||
|
||||
// UpdateContractStatus 更新合同状态
|
||||
func (r *GormContractRecordRepository) UpdateContractStatus(ctx context.Context, recordID string, status string, adminID *string, notes string) error {
|
||||
updates := map[string]interface{}{
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if adminID != nil {
|
||||
updates["admin_id"] = *adminID
|
||||
}
|
||||
|
||||
if notes != "" {
|
||||
updates["admin_notes"] = notes
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.ContractRecord{}).
|
||||
Where("id = ?", recordID).
|
||||
Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新合同状态失败",
|
||||
zap.String("record_id", recordID),
|
||||
zap.String("status", status),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("更新合同状态失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("合同状态更新成功",
|
||||
zap.String("record_id", recordID),
|
||||
zap.String("status", status),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取合同记录列表
|
||||
func (r *GormContractRecordRepository) GetByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.ContractRecord, int, error) {
|
||||
var records []entities.ContractRecord
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{}).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取用户合同记录总数失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取合同记录总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("获取用户合同记录列表失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取合同记录列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.ContractRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, int(total), nil
|
||||
}
|
||||
|
||||
// GetByStatus 根据状态获取合同记录列表
|
||||
func (r *GormContractRecordRepository) GetByStatus(ctx context.Context, status string, page, pageSize int) ([]*entities.ContractRecord, int, error) {
|
||||
var records []entities.ContractRecord
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.ContractRecord{}).Where("status = ?", status)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
r.logger.Error("根据状态获取合同记录总数失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取合同记录总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("根据状态获取合同记录列表失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取合同记录列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.ContractRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, int(total), nil
|
||||
}
|
||||
|
||||
// GetPendingContracts 获取待审核的合同记录
|
||||
func (r *GormContractRecordRepository) GetPendingContracts(ctx context.Context, page, pageSize int) ([]*entities.ContractRecord, int, error) {
|
||||
return r.GetByStatus(ctx, "PENDING", page, pageSize)
|
||||
}
|
||||
|
||||
// GetExpiredSigningContracts 获取签署链接已过期的合同记录
|
||||
func (r *GormContractRecordRepository) GetExpiredSigningContracts(ctx context.Context, limit int) ([]*entities.ContractRecord, error) {
|
||||
var records []entities.ContractRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("expires_at < NOW() AND status = ?", "APPROVED").
|
||||
Limit(limit).
|
||||
Order("expires_at ASC").
|
||||
Find(&records).Error; err != nil {
|
||||
r.logger.Error("获取过期签署合同记录失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取过期签署合同记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.ContractRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, nil
|
||||
}
|
||||
|
||||
// GetExpiredContracts 获取已过期的合同记录(通用方法)
|
||||
func (r *GormContractRecordRepository) GetExpiredContracts(ctx context.Context, limit int) ([]*entities.ContractRecord, error) {
|
||||
return r.GetExpiredSigningContracts(ctx, limit)
|
||||
}
|
||||
@@ -0,0 +1,394 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/certification/entities"
|
||||
"tyapi-server/internal/domains/certification/repositories"
|
||||
"tyapi-server/internal/domains/certification/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormFaceVerifyRecordRepository GORM人脸识别记录仓储实现
|
||||
type GormFaceVerifyRecordRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.FaceVerifyRecordRepository = (*GormFaceVerifyRecordRepository)(nil)
|
||||
|
||||
// NewGormFaceVerifyRecordRepository 创建GORM人脸识别记录仓储
|
||||
func NewGormFaceVerifyRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.FaceVerifyRecordRepository {
|
||||
return &GormFaceVerifyRecordRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) Create(ctx context.Context, record entities.FaceVerifyRecord) (entities.FaceVerifyRecord, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
|
||||
r.logger.Error("创建人脸识别记录失败",
|
||||
zap.String("certification_id", record.CertificationID),
|
||||
zap.String("certify_id", record.CertifyID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.FaceVerifyRecord{}, fmt.Errorf("创建人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("人脸识别记录创建成功",
|
||||
zap.String("id", record.ID),
|
||||
zap.String("certify_id", record.CertifyID),
|
||||
)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) GetByID(ctx context.Context, id string) (entities.FaceVerifyRecord, error) {
|
||||
var record entities.FaceVerifyRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entities.FaceVerifyRecord{}, fmt.Errorf("人脸识别记录不存在")
|
||||
}
|
||||
r.logger.Error("获取人脸识别记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.FaceVerifyRecord{}, fmt.Errorf("获取人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Update 更新人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) Update(ctx context.Context, record entities.FaceVerifyRecord) error {
|
||||
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
|
||||
r.logger.Error("更新人脸识别记录失败",
|
||||
zap.String("id", record.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("更新人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.FaceVerifyRecord{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除人脸识别记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("删除人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// Restore 恢复人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.FaceVerifyRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复人脸识别记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("恢复人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("人脸识别记录恢复成功", zap.String("id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计人脸识别记录数量
|
||||
func (r *GormFaceVerifyRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("certify_id LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查人脸识别记录是否存在
|
||||
func (r *GormFaceVerifyRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) CreateBatch(ctx context.Context, records []entities.FaceVerifyRecord) error {
|
||||
r.logger.Info("批量创建人脸识别记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.FaceVerifyRecord, error) {
|
||||
var records []entities.FaceVerifyRecord
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) UpdateBatch(ctx context.Context, records []entities.FaceVerifyRecord) error {
|
||||
r.logger.Info("批量更新人脸识别记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Save(&records).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除人脸识别记录", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.FaceVerifyRecord{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取人脸识别记录列表
|
||||
func (r *GormFaceVerifyRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.FaceVerifyRecord, error) {
|
||||
var records []entities.FaceVerifyRecord
|
||||
query := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("certify_id LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return records, query.Find(&records).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormFaceVerifyRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.FaceVerifyRecord] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormFaceVerifyRecordRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// GetByCertificationID 根据认证申请ID获取人脸识别记录列表
|
||||
func (r *GormFaceVerifyRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) ([]*entities.FaceVerifyRecord, error) {
|
||||
var records []entities.FaceVerifyRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("根据认证申请ID获取人脸识别记录失败",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, nil
|
||||
}
|
||||
|
||||
// GetLatestByCertificationID 根据认证申请ID获取最新的人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) GetLatestByCertificationID(ctx context.Context, certificationID string) (*entities.FaceVerifyRecord, error) {
|
||||
var record entities.FaceVerifyRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").First(&record).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("人脸识别记录不存在")
|
||||
}
|
||||
r.logger.Error("根据认证申请ID获取最新人脸识别记录失败",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// ListRecords 获取人脸识别记录列表(带分页和筛选)
|
||||
func (r *GormFaceVerifyRecordRepository) ListRecords(ctx context.Context, query *queries.ListFaceVerifyRecordsQuery) ([]*entities.FaceVerifyRecord, int64, error) {
|
||||
var records []entities.FaceVerifyRecord
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.CertificationID != "" {
|
||||
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
|
||||
}
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&records).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, total, nil
|
||||
}
|
||||
|
||||
// GetSuccessRate 获取成功率
|
||||
func (r *GormFaceVerifyRecordRepository) GetSuccessRate(ctx context.Context, days int) (float64, error) {
|
||||
var totalCount int64
|
||||
var successCount int64
|
||||
|
||||
// 计算指定天数前的日期
|
||||
startDate := fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d DAY)", days)
|
||||
|
||||
// 获取总数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).
|
||||
Where("created_at >= " + startDate).Count(&totalCount).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// 获取成功数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).
|
||||
Where("created_at >= "+startDate+" AND status = ?", "SUCCESS").Count(&successCount).Error; err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if totalCount == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
return float64(successCount) / float64(totalCount) * 100, nil
|
||||
}
|
||||
|
||||
// GetByCertifyID 根据认证ID获取人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) GetByCertifyID(ctx context.Context, certifyID string) (*entities.FaceVerifyRecord, error) {
|
||||
var record entities.FaceVerifyRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).First(&record, "certify_id = ?", certifyID).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("人脸识别记录不存在")
|
||||
}
|
||||
r.logger.Error("根据认证ID获取人脸识别记录失败",
|
||||
zap.String("certify_id", certifyID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取人脸识别记录列表
|
||||
func (r *GormFaceVerifyRecordRepository) GetByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.FaceVerifyRecord, int, error) {
|
||||
var records []entities.FaceVerifyRecord
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.FaceVerifyRecord{}).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取用户人脸识别记录总数失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取人脸识别记录总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("获取用户人脸识别记录列表失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取人脸识别记录列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, int(total), nil
|
||||
}
|
||||
|
||||
// GetExpiredRecords 获取已过期的人脸识别记录
|
||||
func (r *GormFaceVerifyRecordRepository) GetExpiredRecords(ctx context.Context, limit int) ([]*entities.FaceVerifyRecord, error) {
|
||||
var records []entities.FaceVerifyRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("expires_at < NOW() AND status = ?", "PROCESSING").
|
||||
Limit(limit).
|
||||
Order("expires_at ASC").
|
||||
Find(&records).Error; err != nil {
|
||||
r.logger.Error("获取过期人脸识别记录失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取过期人脸识别记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.FaceVerifyRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, nil
|
||||
}
|
||||
@@ -0,0 +1,374 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/certification/entities"
|
||||
"tyapi-server/internal/domains/certification/repositories"
|
||||
"tyapi-server/internal/domains/certification/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormLicenseUploadRecordRepository GORM营业执照上传记录仓储实现
|
||||
type GormLicenseUploadRecordRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.LicenseUploadRecordRepository = (*GormLicenseUploadRecordRepository)(nil)
|
||||
|
||||
// NewGormLicenseUploadRecordRepository 创建GORM营业执照上传记录仓储
|
||||
func NewGormLicenseUploadRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.LicenseUploadRecordRepository {
|
||||
return &GormLicenseUploadRecordRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) Create(ctx context.Context, record entities.LicenseUploadRecord) (entities.LicenseUploadRecord, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
|
||||
r.logger.Error("创建上传记录失败",
|
||||
zap.String("user_id", record.UserID),
|
||||
zap.String("file_name", record.OriginalFileName),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.LicenseUploadRecord{}, fmt.Errorf("创建上传记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("上传记录创建成功",
|
||||
zap.String("id", record.ID),
|
||||
zap.String("file_name", record.OriginalFileName),
|
||||
)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) GetByID(ctx context.Context, id string) (entities.LicenseUploadRecord, error) {
|
||||
var record entities.LicenseUploadRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entities.LicenseUploadRecord{}, fmt.Errorf("上传记录不存在")
|
||||
}
|
||||
r.logger.Error("获取上传记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.LicenseUploadRecord{}, fmt.Errorf("获取上传记录失败: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Update 更新上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) Update(ctx context.Context, record entities.LicenseUploadRecord) error {
|
||||
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
|
||||
r.logger.Error("更新上传记录失败",
|
||||
zap.String("id", record.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("更新上传记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.LicenseUploadRecord{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除上传记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("删除上传记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// Restore 恢复上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.LicenseUploadRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复上传记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("恢复上传记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("上传记录恢复成功", zap.String("id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计上传记录数量
|
||||
func (r *GormLicenseUploadRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("original_file_name LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查上传记录是否存在
|
||||
func (r *GormLicenseUploadRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) CreateBatch(ctx context.Context, records []entities.LicenseUploadRecord) error {
|
||||
r.logger.Info("批量创建上传记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.LicenseUploadRecord, error) {
|
||||
var records []entities.LicenseUploadRecord
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) UpdateBatch(ctx context.Context, records []entities.LicenseUploadRecord) error {
|
||||
r.logger.Info("批量更新上传记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Save(&records).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除上传记录", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.LicenseUploadRecord{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取上传记录列表
|
||||
func (r *GormLicenseUploadRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.LicenseUploadRecord, error) {
|
||||
var records []entities.LicenseUploadRecord
|
||||
query := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("original_file_name LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return records, query.Find(&records).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormLicenseUploadRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.LicenseUploadRecord] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormLicenseUploadRecordRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// GetByCertificationID 根据认证ID获取上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) (*entities.LicenseUploadRecord, error) {
|
||||
var record entities.LicenseUploadRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).First(&record, "certification_id = ?", certificationID).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("上传记录不存在")
|
||||
}
|
||||
r.logger.Error("根据认证ID获取上传记录失败",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取上传记录失败: %w", err)
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// ListRecords 获取上传记录列表(带分页和筛选)
|
||||
func (r *GormLicenseUploadRecordRepository) ListRecords(ctx context.Context, query *queries.ListLicenseUploadRecordsQuery) ([]*entities.LicenseUploadRecord, int64, error) {
|
||||
var records []entities.LicenseUploadRecord
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.CertificationID != "" {
|
||||
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
|
||||
}
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&records).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.LicenseUploadRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, total, nil
|
||||
}
|
||||
|
||||
// UpdateOCRResult 更新OCR结果
|
||||
func (r *GormLicenseUploadRecordRepository) UpdateOCRResult(ctx context.Context, recordID string, ocrResult string, confidence float64) error {
|
||||
updates := map[string]interface{}{
|
||||
"ocr_result": ocrResult,
|
||||
"ocr_confidence": confidence,
|
||||
"ocr_processed": true,
|
||||
"ocr_success": true,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.LicenseUploadRecord{}).
|
||||
Where("id = ?", recordID).
|
||||
Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新OCR结果失败",
|
||||
zap.String("record_id", recordID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("更新OCR结果失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("OCR结果更新成功",
|
||||
zap.String("record_id", recordID),
|
||||
zap.Float64("confidence", confidence),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取上传记录列表
|
||||
func (r *GormLicenseUploadRecordRepository) GetByUserID(ctx context.Context, userID string, page, pageSize int) ([]*entities.LicenseUploadRecord, int, error) {
|
||||
var records []entities.LicenseUploadRecord
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&entities.LicenseUploadRecord{}).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
r.logger.Error("获取用户上传记录总数失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取上传记录总数失败: %w", err)
|
||||
}
|
||||
|
||||
// 分页查询
|
||||
offset := (page - 1) * pageSize
|
||||
if err := query.Offset(offset).Limit(pageSize).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("获取用户上传记录列表失败", zap.Error(err))
|
||||
return nil, 0, fmt.Errorf("获取上传记录列表失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.LicenseUploadRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, int(total), nil
|
||||
}
|
||||
|
||||
// GetByQiNiuKey 根据七牛云Key获取上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) GetByQiNiuKey(ctx context.Context, key string) (*entities.LicenseUploadRecord, error) {
|
||||
var record entities.LicenseUploadRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).First(&record, "qiniu_key = ?", key).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("上传记录不存在")
|
||||
}
|
||||
r.logger.Error("根据七牛云Key获取上传记录失败",
|
||||
zap.String("qiniu_key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取上传记录失败: %w", err)
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// GetPendingOCR 获取待OCR处理的上传记录
|
||||
func (r *GormLicenseUploadRecordRepository) GetPendingOCR(ctx context.Context, limit int) ([]*entities.LicenseUploadRecord, error) {
|
||||
var records []entities.LicenseUploadRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("ocr_processed = ? OR (ocr_processed = ? AND ocr_success = ?)", false, true, false).
|
||||
Limit(limit).
|
||||
Order("created_at ASC").
|
||||
Find(&records).Error; err != nil {
|
||||
r.logger.Error("获取待OCR处理记录失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取待OCR处理记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.LicenseUploadRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, nil
|
||||
}
|
||||
@@ -0,0 +1,344 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/certification/entities"
|
||||
"tyapi-server/internal/domains/certification/repositories"
|
||||
"tyapi-server/internal/domains/certification/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormNotificationRecordRepository GORM通知记录仓储实现
|
||||
type GormNotificationRecordRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ repositories.NotificationRecordRepository = (*GormNotificationRecordRepository)(nil)
|
||||
|
||||
// NewGormNotificationRecordRepository 创建GORM通知记录仓储
|
||||
func NewGormNotificationRecordRepository(db *gorm.DB, logger *zap.Logger) repositories.NotificationRecordRepository {
|
||||
return &GormNotificationRecordRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建通知记录
|
||||
func (r *GormNotificationRecordRepository) Create(ctx context.Context, record entities.NotificationRecord) (entities.NotificationRecord, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&record).Error; err != nil {
|
||||
r.logger.Error("创建通知记录失败",
|
||||
zap.String("user_id", *record.UserID),
|
||||
zap.String("type", record.NotificationType),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.NotificationRecord{}, fmt.Errorf("创建通知记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("通知记录创建成功",
|
||||
zap.String("id", record.ID),
|
||||
zap.String("type", record.NotificationType),
|
||||
)
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取通知记录
|
||||
func (r *GormNotificationRecordRepository) GetByID(ctx context.Context, id string) (entities.NotificationRecord, error) {
|
||||
var record entities.NotificationRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entities.NotificationRecord{}, fmt.Errorf("通知记录不存在")
|
||||
}
|
||||
r.logger.Error("获取通知记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return entities.NotificationRecord{}, fmt.Errorf("获取通知记录失败: %w", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// Update 更新通知记录
|
||||
func (r *GormNotificationRecordRepository) Update(ctx context.Context, record entities.NotificationRecord) error {
|
||||
if err := r.db.WithContext(ctx).Save(&record).Error; err != nil {
|
||||
r.logger.Error("更新通知记录失败",
|
||||
zap.String("id", record.ID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("更新通知记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除通知记录
|
||||
func (r *GormNotificationRecordRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.NotificationRecord{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除通知记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("删除通知记录失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除通知记录
|
||||
func (r *GormNotificationRecordRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// Restore 恢复通知记录
|
||||
func (r *GormNotificationRecordRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.NotificationRecord{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复通知记录失败",
|
||||
zap.String("id", id),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("恢复通知记录失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("通知记录恢复成功", zap.String("id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计通知记录数量
|
||||
func (r *GormNotificationRecordRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.NotificationRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("title LIKE ? OR content LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查通知记录是否存在
|
||||
func (r *GormNotificationRecordRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.NotificationRecord{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建通知记录
|
||||
func (r *GormNotificationRecordRepository) CreateBatch(ctx context.Context, records []entities.NotificationRecord) error {
|
||||
r.logger.Info("批量创建通知记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取通知记录
|
||||
func (r *GormNotificationRecordRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.NotificationRecord, error) {
|
||||
var records []entities.NotificationRecord
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&records).Error
|
||||
return records, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新通知记录
|
||||
func (r *GormNotificationRecordRepository) UpdateBatch(ctx context.Context, records []entities.NotificationRecord) error {
|
||||
r.logger.Info("批量更新通知记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Save(&records).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除通知记录
|
||||
func (r *GormNotificationRecordRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除通知记录", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.NotificationRecord{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取通知记录列表
|
||||
func (r *GormNotificationRecordRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.NotificationRecord, error) {
|
||||
var records []entities.NotificationRecord
|
||||
query := r.db.WithContext(ctx).Model(&entities.NotificationRecord{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("title LIKE ? OR content LIKE ? OR user_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return records, query.Find(&records).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormNotificationRecordRepository) WithTx(tx interface{}) interfaces.Repository[entities.NotificationRecord] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormNotificationRecordRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// GetByCertificationID 根据认证申请ID获取通知记录列表
|
||||
func (r *GormNotificationRecordRepository) GetByCertificationID(ctx context.Context, certificationID string) ([]*entities.NotificationRecord, error) {
|
||||
var records []entities.NotificationRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("certification_id = ?", certificationID).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("根据认证申请ID获取通知记录失败",
|
||||
zap.String("certification_id", certificationID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取通知记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.NotificationRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, nil
|
||||
}
|
||||
|
||||
// GetUnreadByUserID 根据用户ID获取未读通知记录列表
|
||||
func (r *GormNotificationRecordRepository) GetUnreadByUserID(ctx context.Context, userID string) ([]*entities.NotificationRecord, error) {
|
||||
var records []entities.NotificationRecord
|
||||
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ? AND is_read = ?", userID, false).Order("created_at DESC").Find(&records).Error; err != nil {
|
||||
r.logger.Error("根据用户ID获取未读通知记录失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取未读通知记录失败: %w", err)
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.NotificationRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, nil
|
||||
}
|
||||
|
||||
// ListRecords 获取通知记录列表(带分页和筛选)
|
||||
func (r *GormNotificationRecordRepository) ListRecords(ctx context.Context, query *queries.ListNotificationRecordsQuery) ([]*entities.NotificationRecord, int64, error) {
|
||||
var records []entities.NotificationRecord
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.NotificationRecord{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.CertificationID != "" {
|
||||
dbQuery = dbQuery.Where("certification_id = ?", query.CertificationID)
|
||||
}
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.Type != "" {
|
||||
dbQuery = dbQuery.Where("type = ?", query.Type)
|
||||
}
|
||||
if query.IsRead != nil {
|
||||
dbQuery = dbQuery.Where("is_read = ?", *query.IsRead)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&records).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
recordPtrs := make([]*entities.NotificationRecord, len(records))
|
||||
for i := range records {
|
||||
recordPtrs[i] = &records[i]
|
||||
}
|
||||
|
||||
return recordPtrs, total, nil
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建通知记录
|
||||
func (r *GormNotificationRecordRepository) BatchCreate(ctx context.Context, records []entities.NotificationRecord) error {
|
||||
r.logger.Info("批量创建通知记录", zap.Int("count", len(records)))
|
||||
return r.db.WithContext(ctx).Create(&records).Error
|
||||
}
|
||||
|
||||
// MarkAsRead 标记通知记录为已读
|
||||
func (r *GormNotificationRecordRepository) MarkAsRead(ctx context.Context, recordIDs []string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.NotificationRecord{}).
|
||||
Where("id IN ?", recordIDs).
|
||||
Update("is_read", true).Error; err != nil {
|
||||
r.logger.Error("标记通知记录为已读失败",
|
||||
zap.Strings("record_ids", recordIDs),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("标记通知记录为已读失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("通知记录标记为已读成功", zap.Strings("record_ids", recordIDs))
|
||||
return nil
|
||||
}
|
||||
|
||||
// MarkAllAsReadByUser 标记用户所有通知记录为已读
|
||||
func (r *GormNotificationRecordRepository) MarkAllAsReadByUser(ctx context.Context, userID string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.NotificationRecord{}).
|
||||
Where("user_id = ? AND is_read = ?", userID, false).
|
||||
Update("is_read", true).Error; err != nil {
|
||||
r.logger.Error("标记用户所有通知记录为已读失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("标记用户所有通知记录为已读失败: %w", err)
|
||||
}
|
||||
|
||||
r.logger.Info("用户所有通知记录标记为已读成功", zap.String("user_id", userID))
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,663 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/finance/entities"
|
||||
domain_finance_repo "tyapi-server/internal/domains/finance/repositories"
|
||||
"tyapi-server/internal/domains/finance/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormWalletRepository 钱包GORM仓储实现
|
||||
type GormWalletRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ domain_finance_repo.WalletRepository = (*GormWalletRepository)(nil)
|
||||
|
||||
// NewGormWalletRepository 创建钱包GORM仓储
|
||||
func NewGormWalletRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.WalletRepository {
|
||||
return &GormWalletRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建钱包
|
||||
func (r *GormWalletRepository) Create(ctx context.Context, wallet entities.Wallet) (entities.Wallet, error) {
|
||||
r.logger.Info("创建钱包", zap.String("user_id", wallet.UserID))
|
||||
err := r.db.WithContext(ctx).Create(&wallet).Error
|
||||
return wallet, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取钱包
|
||||
func (r *GormWalletRepository) GetByID(ctx context.Context, id string) (entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&wallet).Error
|
||||
return wallet, err
|
||||
}
|
||||
|
||||
// Update 更新钱包
|
||||
func (r *GormWalletRepository) Update(ctx context.Context, wallet entities.Wallet) error {
|
||||
r.logger.Info("更新钱包", zap.String("id", wallet.ID))
|
||||
return r.db.WithContext(ctx).Save(&wallet).Error
|
||||
}
|
||||
|
||||
// Delete 删除钱包
|
||||
func (r *GormWalletRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除钱包", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除钱包
|
||||
func (r *GormWalletRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除钱包", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复钱包
|
||||
func (r *GormWalletRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复钱包", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.Wallet{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计钱包数量
|
||||
func (r *GormWalletRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查钱包是否存在
|
||||
func (r *GormWalletRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建钱包
|
||||
func (r *GormWalletRepository) CreateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
r.logger.Info("批量创建钱包", zap.Int("count", len(wallets)))
|
||||
return r.db.WithContext(ctx).Create(&wallets).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取钱包
|
||||
func (r *GormWalletRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&wallets).Error
|
||||
return wallets, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新钱包
|
||||
func (r *GormWalletRepository) UpdateBatch(ctx context.Context, wallets []entities.Wallet) error {
|
||||
r.logger.Info("批量更新钱包", zap.Int("count", len(wallets)))
|
||||
return r.db.WithContext(ctx).Save(&wallets).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除钱包
|
||||
func (r *GormWalletRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除钱包", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.Wallet{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取钱包列表
|
||||
func (r *GormWalletRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.Wallet, error) {
|
||||
var wallets []entities.Wallet
|
||||
query := r.db.WithContext(ctx).Model(&entities.Wallet{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ?", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return wallets, query.Find(&wallets).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormWalletRepository) WithTx(tx interface{}) interfaces.Repository[entities.Wallet] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormWalletRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找钱包
|
||||
func (r *GormWalletRepository) FindByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// ExistsByUserID 检查用户钱包是否存在
|
||||
func (r *GormWalletRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// GetTotalBalance 获取总余额
|
||||
func (r *GormWalletRepository) GetTotalBalance(ctx context.Context) (interface{}, error) {
|
||||
var total decimal.Decimal
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&total).Error
|
||||
return total, err
|
||||
}
|
||||
|
||||
// GetActiveWalletCount 获取激活钱包数量
|
||||
func (r *GormWalletRepository) GetActiveWalletCount(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// ================ 接口要求的方法 ================
|
||||
|
||||
// GetByUserID 根据用户ID获取钱包
|
||||
func (r *GormWalletRepository) GetByUserID(ctx context.Context, userID string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&wallet).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// GetByWalletAddress 根据钱包地址获取钱包
|
||||
func (r *GormWalletRepository) GetByWalletAddress(ctx context.Context, walletAddress string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("wallet_address = ?", walletAddress).First(&wallet).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// GetByWalletType 根据钱包类型获取钱包
|
||||
func (r *GormWalletRepository) GetByWalletType(ctx context.Context, userID string, walletType string) (*entities.Wallet, error) {
|
||||
var wallet entities.Wallet
|
||||
err := r.db.WithContext(ctx).Where("user_id = ? AND wallet_type = ?", userID, walletType).First(&wallet).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wallet, nil
|
||||
}
|
||||
|
||||
// ListWallets 获取钱包列表(带分页和筛选)
|
||||
func (r *GormWalletRepository) ListWallets(ctx context.Context, query *queries.ListWalletsQuery) ([]*entities.Wallet, int64, error) {
|
||||
var wallets []entities.Wallet
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.Wallet{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.WalletType != "" {
|
||||
dbQuery = dbQuery.Where("wallet_type = ?", query.WalletType)
|
||||
}
|
||||
if query.WalletAddress != "" {
|
||||
dbQuery = dbQuery.Where("wallet_address LIKE ?", "%"+query.WalletAddress+"%")
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&wallets).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
walletPtrs := make([]*entities.Wallet, len(wallets))
|
||||
for i := range wallets {
|
||||
walletPtrs[i] = &wallets[i]
|
||||
}
|
||||
|
||||
return walletPtrs, total, nil
|
||||
}
|
||||
|
||||
// UpdateBalance 更新钱包余额
|
||||
func (r *GormWalletRepository) UpdateBalance(ctx context.Context, walletID string, balance string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", balance).Error
|
||||
}
|
||||
|
||||
// AddBalance 增加钱包余额
|
||||
func (r *GormWalletRepository) AddBalance(ctx context.Context, walletID string, amount string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance + ?", amount)).Error
|
||||
}
|
||||
|
||||
// SubtractBalance 减少钱包余额
|
||||
func (r *GormWalletRepository) SubtractBalance(ctx context.Context, walletID string, amount string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("balance", gorm.Expr("balance - ?", amount)).Error
|
||||
}
|
||||
|
||||
// ActivateWallet 激活钱包
|
||||
func (r *GormWalletRepository) ActivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", true).Error
|
||||
}
|
||||
|
||||
// DeactivateWallet 停用钱包
|
||||
func (r *GormWalletRepository) DeactivateWallet(ctx context.Context, walletID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("id = ?", walletID).Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// GetStats 获取财务统计信息
|
||||
func (r *GormWalletRepository) GetStats(ctx context.Context) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 总钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 激活钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("is_active = ?", true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetUserWalletStats 获取用户钱包统计信息
|
||||
func (r *GormWalletRepository) GetUserWalletStats(ctx context.Context, userID string) (*domain_finance_repo.FinanceStats, error) {
|
||||
var stats domain_finance_repo.FinanceStats
|
||||
|
||||
// 用户钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Count(&stats.TotalWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户激活钱包数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ? AND is_active = ?", userID, true).Count(&stats.ActiveWallets).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 用户总余额
|
||||
var totalBalance decimal.Decimal
|
||||
if err := r.db.WithContext(ctx).Model(&entities.Wallet{}).Where("user_id = ?", userID).Select("COALESCE(SUM(balance), 0)").Scan(&totalBalance).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalBalance = totalBalance.String()
|
||||
|
||||
// 用户今日交易数(这里需要根据实际业务逻辑实现)
|
||||
stats.TodayTransactions = 0
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GormUserSecretsRepository 用户密钥GORM仓储实现
|
||||
type GormUserSecretsRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// 编译时检查接口实现
|
||||
var _ domain_finance_repo.UserSecretsRepository = (*GormUserSecretsRepository)(nil)
|
||||
|
||||
// NewGormUserSecretsRepository 创建用户密钥GORM仓储
|
||||
func NewGormUserSecretsRepository(db *gorm.DB, logger *zap.Logger) domain_finance_repo.UserSecretsRepository {
|
||||
return &GormUserSecretsRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建用户密钥
|
||||
func (r *GormUserSecretsRepository) Create(ctx context.Context, secrets entities.UserSecrets) (entities.UserSecrets, error) {
|
||||
r.logger.Info("创建用户密钥", zap.String("user_id", secrets.UserID))
|
||||
err := r.db.WithContext(ctx).Create(&secrets).Error
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetByID(ctx context.Context, id string) (entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("id = ?", id).First(&secrets).Error
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// Update 更新用户密钥
|
||||
func (r *GormUserSecretsRepository) Update(ctx context.Context, secrets entities.UserSecrets) error {
|
||||
r.logger.Info("更新用户密钥", zap.String("id", secrets.ID))
|
||||
return r.db.WithContext(ctx).Save(&secrets).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户密钥
|
||||
func (r *GormUserSecretsRepository) Delete(ctx context.Context, id string) error {
|
||||
r.logger.Info("删除用户密钥", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// SoftDelete 软删除用户密钥
|
||||
func (r *GormUserSecretsRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
r.logger.Info("软删除用户密钥", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id = ?", id).Error
|
||||
}
|
||||
|
||||
// Restore 恢复用户密钥
|
||||
func (r *GormUserSecretsRepository) Restore(ctx context.Context, id string) error {
|
||||
r.logger.Info("恢复用户密钥", zap.String("id", id))
|
||||
return r.db.WithContext(ctx).Unscoped().Model(&entities.UserSecrets{}).Where("id = ?", id).Update("deleted_at", nil).Error
|
||||
}
|
||||
|
||||
// Count 统计用户密钥数量
|
||||
func (r *GormUserSecretsRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
return count, query.Count(&count).Error
|
||||
}
|
||||
|
||||
// Exists 检查用户密钥是否存在
|
||||
func (r *GormUserSecretsRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建用户密钥
|
||||
func (r *GormUserSecretsRepository) CreateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
|
||||
r.logger.Info("批量创建用户密钥", zap.Int("count", len(secrets)))
|
||||
return r.db.WithContext(ctx).Create(&secrets).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.UserSecrets, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&secrets).Error
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新用户密钥
|
||||
func (r *GormUserSecretsRepository) UpdateBatch(ctx context.Context, secrets []entities.UserSecrets) error {
|
||||
r.logger.Info("批量更新用户密钥", zap.Int("count", len(secrets)))
|
||||
return r.db.WithContext(ctx).Save(&secrets).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除用户密钥
|
||||
func (r *GormUserSecretsRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除用户密钥", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.UserSecrets{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取用户密钥列表
|
||||
func (r *GormUserSecretsRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.UserSecrets, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
query := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("user_id LIKE ? OR access_id LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return secrets, query.Find(&secrets).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormUserSecretsRepository) WithTx(tx interface{}) interfaces.Repository[entities.UserSecrets] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormUserSecretsRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FindByUserID 根据用户ID查找密钥
|
||||
func (r *GormUserSecretsRepository) FindByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// FindByAccessID 根据访问ID查找密钥
|
||||
func (r *GormUserSecretsRepository) FindByAccessID(ctx context.Context, accessID string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("access_id = ?", accessID).First(&secrets).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// ExistsByUserID 检查用户密钥是否存在
|
||||
func (r *GormUserSecretsRepository) ExistsByUserID(ctx context.Context, userID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ExistsByAccessID 检查访问ID是否存在
|
||||
func (r *GormUserSecretsRepository) ExistsByAccessID(ctx context.Context, accessID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// UpdateLastUsedAt 更新最后使用时间
|
||||
func (r *GormUserSecretsRepository) UpdateLastUsedAt(ctx context.Context, accessID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("access_id = ?", accessID).Update("last_used_at", time.Now()).Error
|
||||
}
|
||||
|
||||
// DeactivateByUserID 停用用户密钥
|
||||
func (r *GormUserSecretsRepository) DeactivateByUserID(ctx context.Context, userID string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Update("is_active", false).Error
|
||||
}
|
||||
|
||||
// RegenerateAccessKey 重新生成访问密钥
|
||||
func (r *GormUserSecretsRepository) RegenerateAccessKey(ctx context.Context, userID string, accessID, accessKey string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).Where("user_id = ?", userID).Updates(map[string]interface{}{
|
||||
"access_id": accessID,
|
||||
"access_key": accessKey,
|
||||
"updated_at": time.Now(),
|
||||
}).Error
|
||||
}
|
||||
|
||||
// GetExpiredSecrets 获取过期的密钥
|
||||
func (r *GormUserSecretsRepository) GetExpiredSecrets(ctx context.Context) ([]entities.UserSecrets, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Find(&secrets).Error
|
||||
return secrets, err
|
||||
}
|
||||
|
||||
// DeleteExpiredSecrets 删除过期的密钥
|
||||
func (r *GormUserSecretsRepository) DeleteExpiredSecrets(ctx context.Context) error {
|
||||
return r.db.WithContext(ctx).Where("expires_at < ?", time.Now()).Delete(&entities.UserSecrets{}).Error
|
||||
}
|
||||
|
||||
// ================ 接口要求的方法 ================
|
||||
|
||||
// GetByUserID 根据用户ID获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetByUserID(ctx context.Context, userID string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&secrets).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// GetBySecretType 根据密钥类型获取用户密钥
|
||||
func (r *GormUserSecretsRepository) GetBySecretType(ctx context.Context, userID string, secretType string) (*entities.UserSecrets, error) {
|
||||
var secrets entities.UserSecrets
|
||||
err := r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).First(&secrets).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &secrets, nil
|
||||
}
|
||||
|
||||
// ListUserSecrets 获取用户密钥列表(带分页和筛选)
|
||||
func (r *GormUserSecretsRepository) ListUserSecrets(ctx context.Context, query *queries.ListUserSecretsQuery) ([]*entities.UserSecrets, int64, error) {
|
||||
var secrets []entities.UserSecrets
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.UserSecrets{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.UserID != "" {
|
||||
dbQuery = dbQuery.Where("user_id = ?", query.UserID)
|
||||
}
|
||||
if query.SecretType != "" {
|
||||
dbQuery = dbQuery.Where("secret_type = ?", query.SecretType)
|
||||
}
|
||||
if query.IsActive != nil {
|
||||
dbQuery = dbQuery.Where("is_active = ?", *query.IsActive)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&secrets).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
secretPtrs := make([]*entities.UserSecrets, len(secrets))
|
||||
for i := range secrets {
|
||||
secretPtrs[i] = &secrets[i]
|
||||
}
|
||||
|
||||
return secretPtrs, total, nil
|
||||
}
|
||||
|
||||
// UpdateSecret 更新密钥
|
||||
func (r *GormUserSecretsRepository) UpdateSecret(ctx context.Context, userID string, secretType string, secretValue string) error {
|
||||
return r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
|
||||
Where("user_id = ? AND secret_type = ?", userID, secretType).
|
||||
Update("secret_value", secretValue).Error
|
||||
}
|
||||
|
||||
// DeleteSecret 删除密钥
|
||||
func (r *GormUserSecretsRepository) DeleteSecret(ctx context.Context, userID string, secretType string) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ? AND secret_type = ?", userID, secretType).
|
||||
Delete(&entities.UserSecrets{}).Error
|
||||
}
|
||||
|
||||
// ValidateSecret 验证密钥
|
||||
func (r *GormUserSecretsRepository) ValidateSecret(ctx context.Context, userID string, secretType string, secretValue string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.UserSecrets{}).
|
||||
Where("user_id = ? AND secret_type = ? AND secret_value = ?", userID, secretType, secretValue).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// GormEnterpriseInfoRepository 企业信息GORM仓储实现
|
||||
type GormEnterpriseInfoRepository struct {
|
||||
db *gorm.DB
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewGormEnterpriseInfoRepository 创建企业信息GORM仓储
|
||||
func NewGormEnterpriseInfoRepository(db *gorm.DB, logger *zap.Logger) repositories.EnterpriseInfoRepository {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: db,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Create(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) (entities.EnterpriseInfo, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("创建企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("创建企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByID(ctx context.Context, id string) (entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&enterpriseInfo).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return entities.EnterpriseInfo{}, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// Update 更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Update(ctx context.Context, enterpriseInfo entities.EnterpriseInfo) error {
|
||||
// 检查企业信息是否已认证完成,认证完成后不可修改
|
||||
if enterpriseInfo.IsReadOnly() {
|
||||
return fmt.Errorf("企业信息已认证完成,不可修改")
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Save(&enterpriseInfo).Error; err != nil {
|
||||
r.logger.Error("更新企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("更新企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("软删除企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("软删除企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的企业信息
|
||||
func (r *GormEnterpriseInfoRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复企业信息失败", zap.Error(err))
|
||||
return fmt.Errorf("恢复企业信息失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUserID(ctx context.Context, userID string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("user_id = ?", userID).First(&enterpriseInfo).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// GetByUnifiedSocialCode 根据统一社会信用代码获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByUnifiedSocialCode(ctx context.Context, unifiedSocialCode string) (*entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfo entities.EnterpriseInfo
|
||||
if err := r.db.WithContext(ctx).Where("unified_social_code = ?", unifiedSocialCode).First(&enterpriseInfo).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("企业信息不存在")
|
||||
}
|
||||
r.logger.Error("获取企业信息失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("获取企业信息失败: %w", err)
|
||||
}
|
||||
return &enterpriseInfo, nil
|
||||
}
|
||||
|
||||
// CheckUnifiedSocialCodeExists 检查统一社会信用代码是否已存在
|
||||
func (r *GormEnterpriseInfoRepository) CheckUnifiedSocialCodeExists(ctx context.Context, unifiedSocialCode string, excludeUserID string) (bool, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("unified_social_code = ?", unifiedSocialCode)
|
||||
|
||||
if excludeUserID != "" {
|
||||
query = query.Where("user_id != ?", excludeUserID)
|
||||
}
|
||||
|
||||
if err := query.Count(&count).Error; err != nil {
|
||||
r.logger.Error("检查统一社会信用代码失败", zap.Error(err))
|
||||
return false, fmt.Errorf("检查统一社会信用代码失败: %w", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// UpdateVerificationStatus 更新验证状态
|
||||
func (r *GormEnterpriseInfoRepository) UpdateVerificationStatus(ctx context.Context, userID string, isOCRVerified, isFaceVerified, isCertified bool) error {
|
||||
updates := map[string]interface{}{
|
||||
"is_ocr_verified": isOCRVerified,
|
||||
"is_face_verified": isFaceVerified,
|
||||
"is_certified": isCertified,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新验证状态失败", zap.Error(err))
|
||||
return fmt.Errorf("更新验证状态失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateOCRData 更新OCR数据
|
||||
func (r *GormEnterpriseInfoRepository) UpdateOCRData(ctx context.Context, userID string, rawData string, confidence float64) error {
|
||||
updates := map[string]interface{}{
|
||||
"ocr_raw_data": rawData,
|
||||
"ocr_confidence": confidence,
|
||||
"is_ocr_verified": true,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("更新OCR数据失败", zap.Error(err))
|
||||
return fmt.Errorf("更新OCR数据失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CompleteCertification 完成认证
|
||||
func (r *GormEnterpriseInfoRepository) CompleteCertification(ctx context.Context, userID string) error {
|
||||
now := time.Now()
|
||||
updates := map[string]interface{}{
|
||||
"is_certified": true,
|
||||
"certified_at": &now,
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("user_id = ?", userID).Updates(updates).Error; err != nil {
|
||||
r.logger.Error("完成认证失败", zap.Error(err))
|
||||
return fmt.Errorf("完成认证失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计企业信息数量
|
||||
func (r *GormEnterpriseInfoRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查企业信息是否存在
|
||||
func (r *GormEnterpriseInfoRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建企业信息
|
||||
func (r *GormEnterpriseInfoRepository) CreateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Create(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取企业信息
|
||||
func (r *GormEnterpriseInfoRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新企业信息
|
||||
func (r *GormEnterpriseInfoRepository) UpdateBatch(ctx context.Context, enterpriseInfos []entities.EnterpriseInfo) error {
|
||||
return r.db.WithContext(ctx).Save(&enterpriseInfos).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除企业信息
|
||||
func (r *GormEnterpriseInfoRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
return r.db.WithContext(ctx).Delete(&entities.EnterpriseInfo{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取企业信息列表
|
||||
func (r *GormEnterpriseInfoRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.EnterpriseInfo, error) {
|
||||
var enterpriseInfos []entities.EnterpriseInfo
|
||||
query := r.db.WithContext(ctx).Model(&entities.EnterpriseInfo{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("company_name LIKE ? OR unified_social_code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
err := query.Find(&enterpriseInfos).Error
|
||||
return enterpriseInfos, err
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormEnterpriseInfoRepository) WithTx(tx interface{}) interfaces.Repository[entities.EnterpriseInfo] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormEnterpriseInfoRepository{
|
||||
db: gormTx,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
@@ -0,0 +1,499 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/domains/user/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// SMSCodeRepository 短信验证码仓储
|
||||
type GormSMSCodeRepository struct {
|
||||
db *gorm.DB
|
||||
cache interfaces.CacheService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewGormSMSCodeRepository 创建短信验证码仓储
|
||||
func NewGormSMSCodeRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) repositories.SMSCodeRepository {
|
||||
return &GormSMSCodeRepository{
|
||||
db: db,
|
||||
cache: cache,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 GormSMSCodeRepository 实现了 SMSCodeRepository 接口
|
||||
var _ repositories.SMSCodeRepository = (*GormSMSCodeRepository)(nil)
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建短信验证码记录
|
||||
func (r *GormSMSCodeRepository) Create(ctx context.Context, smsCode entities.SMSCode) (entities.SMSCode, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&smsCode).Error; err != nil {
|
||||
r.logger.Error("创建短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
// 缓存验证码
|
||||
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
|
||||
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
|
||||
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByID(ctx context.Context, id string) (entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&smsCode).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return entities.SMSCode{}, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
r.logger.Error("获取短信验证码失败", zap.Error(err))
|
||||
return entities.SMSCode{}, err
|
||||
}
|
||||
|
||||
return smsCode, nil
|
||||
}
|
||||
|
||||
// Update 更新验证码记录
|
||||
func (r *GormSMSCodeRepository) Update(ctx context.Context, smsCode entities.SMSCode) error {
|
||||
if err := r.db.WithContext(ctx).Save(&smsCode).Error; err != nil {
|
||||
r.logger.Error("更新验证码记录失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新缓存
|
||||
cacheKey := r.buildCacheKey(smsCode.Phone, smsCode.Scene)
|
||||
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
|
||||
|
||||
r.logger.Info("验证码记录更新成功", zap.String("code_id", smsCode.ID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除短信验证码
|
||||
func (r *GormSMSCodeRepository) Delete(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除短信验证码失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("短信验证码删除成功", zap.String("id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除短信验证码
|
||||
func (r *GormSMSCodeRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
return r.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// Restore 恢复短信验证码
|
||||
func (r *GormSMSCodeRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.SMSCode{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复短信验证码失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("短信验证码恢复成功", zap.String("id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计短信验证码数量
|
||||
func (r *GormSMSCodeRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ? OR code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查短信验证码是否存在
|
||||
func (r *GormSMSCodeRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.SMSCode{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建短信验证码
|
||||
func (r *GormSMSCodeRepository) CreateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
r.logger.Info("批量创建短信验证码", zap.Int("count", len(smsCodes)))
|
||||
return r.db.WithContext(ctx).Create(&smsCodes).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&smsCodes).Error
|
||||
return smsCodes, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新短信验证码
|
||||
func (r *GormSMSCodeRepository) UpdateBatch(ctx context.Context, smsCodes []entities.SMSCode) error {
|
||||
r.logger.Info("批量更新短信验证码", zap.Int("count", len(smsCodes)))
|
||||
return r.db.WithContext(ctx).Save(&smsCodes).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除短信验证码
|
||||
func (r *GormSMSCodeRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除短信验证码", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.SMSCode{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取短信验证码列表
|
||||
func (r *GormSMSCodeRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.SMSCode, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
query := r.db.WithContext(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ? OR code LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return smsCodes, query.Find(&smsCodes).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormSMSCodeRepository) WithTx(tx interface{}) interfaces.Repository[entities.SMSCode] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormSMSCodeRepository{
|
||||
db: gormTx,
|
||||
cache: r.cache,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// GetByPhone 根据手机号获取短信验证码
|
||||
func (r *GormSMSCodeRepository) GetByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).Where("phone = ?", phone).Order("created_at DESC").First(&smsCode).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return nil, fmt.Errorf("短信验证码不存在")
|
||||
}
|
||||
r.logger.Error("根据手机号获取短信验证码失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// GetLatestByPhone 根据手机号获取最新的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetLatestByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
return r.GetByPhone(ctx, phone)
|
||||
}
|
||||
|
||||
// GetValidByPhone 根据手机号获取有效的短信验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhone(ctx context.Context, phone string) (*entities.SMSCode, error) {
|
||||
return r.GetValidCode(ctx, phone, "")
|
||||
}
|
||||
|
||||
// GetValidByPhoneAndScene 根据手机号和场景获取有效的验证码
|
||||
func (r *GormSMSCodeRepository) GetValidByPhoneAndScene(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
||||
return r.GetValidCode(ctx, phone, scene)
|
||||
}
|
||||
|
||||
// ListSMSCodes 获取短信验证码列表(带分页和筛选)
|
||||
func (r *GormSMSCodeRepository) ListSMSCodes(ctx context.Context, query *queries.ListSMSCodesQuery) ([]*entities.SMSCode, int64, error) {
|
||||
var smsCodes []entities.SMSCode
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.SMSCode{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
dbQuery = dbQuery.Where("phone = ?", query.Phone)
|
||||
}
|
||||
if query.Purpose != "" {
|
||||
dbQuery = dbQuery.Where("scene = ?", query.Purpose)
|
||||
}
|
||||
if query.Status != "" {
|
||||
dbQuery = dbQuery.Where("status = ?", query.Status)
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&smsCodes).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
smsCodePtrs := make([]*entities.SMSCode, len(smsCodes))
|
||||
for i := range smsCodes {
|
||||
smsCodePtrs[i] = &smsCodes[i]
|
||||
}
|
||||
|
||||
return smsCodePtrs, total, nil
|
||||
}
|
||||
|
||||
// CreateCode 创建验证码
|
||||
func (r *GormSMSCodeRepository) CreateCode(ctx context.Context, phone string, code string, purpose string) (entities.SMSCode, error) {
|
||||
smsCode := entities.SMSCode{
|
||||
Phone: phone,
|
||||
Code: code,
|
||||
Scene: entities.SMSScene(purpose),
|
||||
ExpiresAt: time.Now().Add(5 * time.Minute), // 5分钟过期
|
||||
Used: false,
|
||||
}
|
||||
|
||||
return r.Create(ctx, smsCode)
|
||||
}
|
||||
|
||||
// ValidateCode 验证验证码
|
||||
func (r *GormSMSCodeRepository) ValidateCode(ctx context.Context, phone string, code string, purpose string) (bool, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("phone = ? AND code = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
|
||||
phone, code, purpose, time.Now()).
|
||||
First(&smsCode).Error; err != nil {
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return false, nil
|
||||
}
|
||||
r.logger.Error("验证验证码失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
// 标记为已使用
|
||||
if err := r.MarkAsUsed(ctx, smsCode.ID); err != nil {
|
||||
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// InvalidateCode 使验证码失效
|
||||
func (r *GormSMSCodeRepository) InvalidateCode(ctx context.Context, phone string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND used_at IS NULL", phone).
|
||||
Update("used_at", time.Now()).Error; err != nil {
|
||||
r.logger.Error("使验证码失效失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除缓存
|
||||
cacheKey := r.buildCacheKey(phone, "")
|
||||
r.cache.Delete(ctx, cacheKey)
|
||||
|
||||
r.logger.Info("验证码已失效", zap.String("phone", phone))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckSendFrequency 检查发送频率
|
||||
func (r *GormSMSCodeRepository) CheckSendFrequency(ctx context.Context, phone string, purpose string) (bool, error) {
|
||||
// 检查最近1分钟内是否已发送
|
||||
oneMinuteAgo := time.Now().Add(-1 * time.Minute)
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND scene = ? AND created_at > ?", phone, purpose, oneMinuteAgo).
|
||||
Count(&count).Error; err != nil {
|
||||
r.logger.Error("检查发送频率失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count == 0, nil
|
||||
}
|
||||
|
||||
// GetTodaySendCount 获取今日发送次数
|
||||
func (r *GormSMSCodeRepository) GetTodaySendCount(ctx context.Context, phone string) (int64, error) {
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&count).Error; err != nil {
|
||||
r.logger.Error("获取今日发送次数失败", zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// GetCodeStats 获取验证码统计信息
|
||||
func (r *GormSMSCodeRepository) GetCodeStats(ctx context.Context, phone string, days int) (*repositories.SMSCodeStats, error) {
|
||||
var stats repositories.SMSCodeStats
|
||||
|
||||
// 计算指定天数前的日期
|
||||
startDate := time.Now().AddDate(0, 0, -days)
|
||||
|
||||
// 总发送数
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, startDate).
|
||||
Count(&stats.TotalSent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 总验证数
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ? AND used_at IS NOT NULL", phone, startDate).
|
||||
Count(&stats.TotalValidated).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 成功率
|
||||
if stats.TotalSent > 0 {
|
||||
stats.SuccessRate = float64(stats.TotalValidated) / float64(stats.TotalSent) * 100
|
||||
}
|
||||
|
||||
// 今日发送数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND created_at >= ?", phone, today).
|
||||
Count(&stats.TodaySent).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetValidCode 获取有效的验证码
|
||||
func (r *GormSMSCodeRepository) GetValidCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
||||
// 先从缓存查找
|
||||
cacheKey := r.buildCacheKey(phone, scene)
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.cache.Get(ctx, cacheKey, &smsCode); err == nil {
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// 从数据库查找最新的有效验证码
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("phone = ? AND scene = ? AND expires_at > ? AND used_at IS NULL",
|
||||
phone, scene, time.Now()).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 缓存结果
|
||||
r.cache.Set(ctx, cacheKey, &smsCode, 5*time.Minute)
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// MarkAsUsed 标记验证码为已使用
|
||||
func (r *GormSMSCodeRepository) MarkAsUsed(ctx context.Context, id string) error {
|
||||
now := time.Now()
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("id = ?", id).
|
||||
Update("used_at", now).Error; err != nil {
|
||||
r.logger.Error("标记验证码为已使用失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("验证码已标记为使用", zap.String("code_id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRecentCode 获取最近的验证码记录(不限制有效性)
|
||||
func (r *GormSMSCodeRepository) GetRecentCode(ctx context.Context, phone string, scene entities.SMSScene) (*entities.SMSCode, error) {
|
||||
var smsCode entities.SMSCode
|
||||
if err := r.db.WithContext(ctx).
|
||||
Where("phone = ? AND scene = ?", phone, scene).
|
||||
Order("created_at DESC").
|
||||
First(&smsCode).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &smsCode, nil
|
||||
}
|
||||
|
||||
// CleanupExpired 清理过期的验证码
|
||||
func (r *GormSMSCodeRepository) CleanupExpired(ctx context.Context) error {
|
||||
result := r.db.WithContext(ctx).
|
||||
Where("expires_at < ?", time.Now()).
|
||||
Delete(&entities.SMSCode{})
|
||||
|
||||
if result.Error != nil {
|
||||
r.logger.Error("清理过期验证码失败", zap.Error(result.Error))
|
||||
return result.Error
|
||||
}
|
||||
|
||||
if result.RowsAffected > 0 {
|
||||
r.logger.Info("清理过期验证码完成", zap.Int64("count", result.RowsAffected))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CountRecentCodes 统计最近发送的验证码数量
|
||||
func (r *GormSMSCodeRepository) CountRecentCodes(ctx context.Context, phone string, scene entities.SMSScene, duration time.Duration) (int64, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.SMSCode{}).
|
||||
Where("phone = ? AND scene = ? AND created_at > ?",
|
||||
phone, scene, time.Now().Add(-duration)).
|
||||
Count(&count).Error; err != nil {
|
||||
r.logger.Error("统计最近验证码数量失败", zap.Error(err))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// buildCacheKey 构建缓存键
|
||||
func (r *GormSMSCodeRepository) buildCacheKey(phone string, scene entities.SMSScene) string {
|
||||
return fmt.Sprintf("sms_code:%s:%s", phone, string(scene))
|
||||
}
|
||||
@@ -0,0 +1,511 @@
|
||||
//go:build !test
|
||||
// +build !test
|
||||
|
||||
package repositories
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"tyapi-server/internal/domains/user/entities"
|
||||
"tyapi-server/internal/domains/user/repositories"
|
||||
"tyapi-server/internal/domains/user/repositories/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// 定义错误常量
|
||||
var (
|
||||
// ErrUserNotFound 用户不存在错误
|
||||
ErrUserNotFound = errors.New("用户不存在")
|
||||
)
|
||||
|
||||
// UserRepository 用户仓储实现
|
||||
type GormUserRepository struct {
|
||||
db *gorm.DB
|
||||
cache interfaces.CacheService
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewGormUserRepository 创建用户仓储
|
||||
func NewGormUserRepository(db *gorm.DB, cache interfaces.CacheService, logger *zap.Logger) repositories.UserRepository {
|
||||
return &GormUserRepository{
|
||||
db: db,
|
||||
cache: cache,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// 确保 GormUserRepository 实现了 UserRepository 接口
|
||||
var _ repositories.UserRepository = (*GormUserRepository)(nil)
|
||||
|
||||
// ================ 基础CRUD操作 ================
|
||||
|
||||
// Create 创建用户
|
||||
func (r *GormUserRepository) Create(ctx context.Context, user entities.User) (entities.User, error) {
|
||||
if err := r.db.WithContext(ctx).Create(&user).Error; err != nil {
|
||||
r.logger.Error("创建用户失败", zap.Error(err))
|
||||
return entities.User{}, err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByPhone(ctx, user.Phone)
|
||||
|
||||
r.logger.Info("用户创建成功", zap.String("user_id", user.ID))
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (r *GormUserRepository) GetByID(ctx context.Context, id string) (entities.User, error) {
|
||||
// 尝试从缓存获取
|
||||
cacheKey := fmt.Sprintf("user:id:%s", id)
|
||||
var userCache entities.UserCache
|
||||
if err := r.cache.Get(ctx, cacheKey, &userCache); err == nil {
|
||||
var user entities.User
|
||||
user.FromCache(&userCache)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// 从数据库查询
|
||||
var user entities.User
|
||||
if err := r.db.WithContext(ctx).Where("id = ?", id).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return entities.User{}, ErrUserNotFound
|
||||
}
|
||||
r.logger.Error("根据ID查询用户失败", zap.Error(err))
|
||||
return entities.User{}, err
|
||||
}
|
||||
|
||||
// 缓存结果
|
||||
r.cache.Set(ctx, cacheKey, user.ToCache(), 10*time.Minute)
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (r *GormUserRepository) Update(ctx context.Context, user entities.User) error {
|
||||
if err := r.db.WithContext(ctx).Save(&user).Error; err != nil {
|
||||
r.logger.Error("更新用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, user.ID)
|
||||
r.deleteCacheByPhone(ctx, user.Phone)
|
||||
|
||||
r.logger.Info("用户更新成功", zap.String("user_id", user.ID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (r *GormUserRepository) Delete(ctx context.Context, id string) error {
|
||||
// 先获取用户信息用于清除缓存
|
||||
user, err := r.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("删除用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, id)
|
||||
r.deleteCacheByPhone(ctx, user.Phone)
|
||||
|
||||
r.logger.Info("用户删除成功", zap.String("user_id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// SoftDelete 软删除用户
|
||||
func (r *GormUserRepository) SoftDelete(ctx context.Context, id string) error {
|
||||
// 先获取用户信息用于清除缓存
|
||||
user, err := r.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.db.WithContext(ctx).Delete(&entities.User{}, "id = ?", id).Error; err != nil {
|
||||
r.logger.Error("软删除用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, id)
|
||||
r.deleteCacheByPhone(ctx, user.Phone)
|
||||
|
||||
r.logger.Info("用户软删除成功", zap.String("user_id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Restore 恢复软删除的用户
|
||||
func (r *GormUserRepository) Restore(ctx context.Context, id string) error {
|
||||
if err := r.db.WithContext(ctx).Unscoped().Model(&entities.User{}).Where("id = ?", id).Update("deleted_at", nil).Error; err != nil {
|
||||
r.logger.Error("恢复用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, id)
|
||||
|
||||
r.logger.Info("用户恢复成功", zap.String("user_id", id))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count 统计用户数量
|
||||
func (r *GormUserRepository) Count(ctx context.Context, options interfaces.CountOptions) (int64, error) {
|
||||
var count int64
|
||||
query := r.db.WithContext(ctx).Model(&entities.User{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ? OR nickname LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
err := query.Count(&count).Error
|
||||
return count, err
|
||||
}
|
||||
|
||||
// Exists 检查用户是否存在
|
||||
func (r *GormUserRepository) Exists(ctx context.Context, id string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&entities.User{}).Where("id = ?", id).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CreateBatch 批量创建用户
|
||||
func (r *GormUserRepository) CreateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.logger.Info("批量创建用户", zap.Int("count", len(users)))
|
||||
return r.db.WithContext(ctx).Create(&users).Error
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表获取用户
|
||||
func (r *GormUserRepository) GetByIDs(ctx context.Context, ids []string) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&users).Error
|
||||
return users, err
|
||||
}
|
||||
|
||||
// UpdateBatch 批量更新用户
|
||||
func (r *GormUserRepository) UpdateBatch(ctx context.Context, users []entities.User) error {
|
||||
r.logger.Info("批量更新用户", zap.Int("count", len(users)))
|
||||
return r.db.WithContext(ctx).Save(&users).Error
|
||||
}
|
||||
|
||||
// DeleteBatch 批量删除用户
|
||||
func (r *GormUserRepository) DeleteBatch(ctx context.Context, ids []string) error {
|
||||
r.logger.Info("批量删除用户", zap.Strings("ids", ids))
|
||||
return r.db.WithContext(ctx).Delete(&entities.User{}, "id IN ?", ids).Error
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (r *GormUserRepository) List(ctx context.Context, options interfaces.ListOptions) ([]entities.User, error) {
|
||||
var users []entities.User
|
||||
query := r.db.WithContext(ctx).Model(&entities.User{})
|
||||
|
||||
if options.Filters != nil {
|
||||
for key, value := range options.Filters {
|
||||
query = query.Where(key+" = ?", value)
|
||||
}
|
||||
}
|
||||
|
||||
if options.Search != "" {
|
||||
query = query.Where("phone LIKE ? OR nickname LIKE ?", "%"+options.Search+"%", "%"+options.Search+"%")
|
||||
}
|
||||
|
||||
if options.Sort != "" {
|
||||
order := "ASC"
|
||||
if options.Order != "" {
|
||||
order = options.Order
|
||||
}
|
||||
query = query.Order(options.Sort + " " + order)
|
||||
}
|
||||
|
||||
if options.Page > 0 && options.PageSize > 0 {
|
||||
offset := (options.Page - 1) * options.PageSize
|
||||
query = query.Offset(offset).Limit(options.PageSize)
|
||||
}
|
||||
|
||||
return users, query.Find(&users).Error
|
||||
}
|
||||
|
||||
// WithTx 使用事务
|
||||
func (r *GormUserRepository) WithTx(tx interface{}) interfaces.Repository[entities.User] {
|
||||
if gormTx, ok := tx.(*gorm.DB); ok {
|
||||
return &GormUserRepository{
|
||||
db: gormTx,
|
||||
cache: r.cache,
|
||||
logger: r.logger,
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ================ 业务方法 ================
|
||||
|
||||
// GetByPhone 根据手机号获取用户
|
||||
func (r *GormUserRepository) GetByPhone(ctx context.Context, phone string) (*entities.User, error) {
|
||||
// 尝试从缓存获取
|
||||
cacheKey := fmt.Sprintf("user:phone:%s", phone)
|
||||
var userCache entities.UserCache
|
||||
if err := r.cache.Get(ctx, cacheKey, &userCache); err == nil {
|
||||
var user entities.User
|
||||
user.FromCache(&userCache)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// 从数据库查询
|
||||
var user entities.User
|
||||
if err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
r.logger.Error("根据手机号查询用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 缓存结果
|
||||
r.cache.Set(ctx, cacheKey, user.ToCache(), 10*time.Minute)
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// ListUsers 获取用户列表(带分页和筛选)
|
||||
func (r *GormUserRepository) ListUsers(ctx context.Context, query *queries.ListUsersQuery) ([]*entities.User, int64, error) {
|
||||
var users []entities.User
|
||||
var total int64
|
||||
|
||||
dbQuery := r.db.WithContext(ctx).Model(&entities.User{})
|
||||
|
||||
// 应用筛选条件
|
||||
if query.Phone != "" {
|
||||
dbQuery = dbQuery.Where("phone LIKE ?", "%"+query.Phone+"%")
|
||||
}
|
||||
if query.StartDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at >= ?", query.StartDate)
|
||||
}
|
||||
if query.EndDate != "" {
|
||||
dbQuery = dbQuery.Where("created_at <= ?", query.EndDate)
|
||||
}
|
||||
|
||||
// 统计总数
|
||||
if err := dbQuery.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 应用分页
|
||||
offset := (query.Page - 1) * query.PageSize
|
||||
dbQuery = dbQuery.Offset(offset).Limit(query.PageSize)
|
||||
|
||||
// 默认排序
|
||||
dbQuery = dbQuery.Order("created_at DESC")
|
||||
|
||||
// 查询数据
|
||||
if err := dbQuery.Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 转换为指针切片
|
||||
userPtrs := make([]*entities.User, len(users))
|
||||
for i := range users {
|
||||
userPtrs[i] = &users[i]
|
||||
}
|
||||
|
||||
return userPtrs, total, nil
|
||||
}
|
||||
|
||||
// ValidateUser 验证用户
|
||||
func (r *GormUserRepository) ValidateUser(ctx context.Context, phone, password string) (*entities.User, error) {
|
||||
var user entities.User
|
||||
if err := r.db.WithContext(ctx).Where("phone = ? AND password = ?", phone, password).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, fmt.Errorf("手机号或密码错误")
|
||||
}
|
||||
r.logger.Error("验证用户失败", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// UpdateLastLogin 更新最后登录时间
|
||||
func (r *GormUserRepository) UpdateLastLogin(ctx context.Context, userID string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("last_login_at", time.Now()).Error; err != nil {
|
||||
r.logger.Error("更新最后登录时间失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, userID)
|
||||
|
||||
r.logger.Info("最后登录时间更新成功", zap.String("user_id", userID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatePassword 更新密码
|
||||
func (r *GormUserRepository) UpdatePassword(ctx context.Context, userID string, newPassword string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("password", newPassword).Error; err != nil {
|
||||
r.logger.Error("更新密码失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, userID)
|
||||
|
||||
r.logger.Info("密码更新成功", zap.String("user_id", userID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckPassword 检查密码
|
||||
func (r *GormUserRepository) CheckPassword(ctx context.Context, userID string, password string) (bool, error) {
|
||||
var user entities.User
|
||||
if err := r.db.WithContext(ctx).Where("id = ? AND password = ?", userID, password).First(&user).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
r.logger.Error("检查密码失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// ActivateUser 激活用户
|
||||
func (r *GormUserRepository) ActivateUser(ctx context.Context, userID string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("status", "ACTIVE").Error; err != nil {
|
||||
r.logger.Error("激活用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, userID)
|
||||
|
||||
r.logger.Info("用户激活成功", zap.String("user_id", userID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeactivateUser 停用用户
|
||||
func (r *GormUserRepository) DeactivateUser(ctx context.Context, userID string) error {
|
||||
if err := r.db.WithContext(ctx).
|
||||
Model(&entities.User{}).
|
||||
Where("id = ?", userID).
|
||||
Update("status", "INACTIVE").Error; err != nil {
|
||||
r.logger.Error("停用用户失败", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
// 清除相关缓存
|
||||
r.deleteCacheByID(ctx, userID)
|
||||
|
||||
r.logger.Info("用户停用成功", zap.String("user_id", userID))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStats 获取用户统计信息
|
||||
func (r *GormUserRepository) GetStats(ctx context.Context) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
// 总用户数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("status = ?", "ACTIVE").Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// GetStatsByDateRange 根据日期范围获取用户统计信息
|
||||
func (r *GormUserRepository) GetStatsByDateRange(ctx context.Context, startDate, endDate string) (*repositories.UserStats, error) {
|
||||
var stats repositories.UserStats
|
||||
|
||||
// 总用户数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at BETWEEN ? AND ?", startDate, endDate).Count(&stats.TotalUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 活跃用户数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("status = ? AND created_at BETWEEN ? AND ?", "ACTIVE", startDate, endDate).Count(&stats.ActiveUsers).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日注册数
|
||||
today := time.Now().Truncate(24 * time.Hour)
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("created_at >= ?", today).Count(&stats.TodayRegistrations).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 今日登录数
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("last_login_at >= ?", today).Count(&stats.TodayLogins).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &stats, nil
|
||||
}
|
||||
|
||||
// FindByPhone 根据手机号查找用户(兼容旧方法)
|
||||
func (r *GormUserRepository) FindByPhone(ctx context.Context, phone string) (*entities.User, error) {
|
||||
return r.GetByPhone(ctx, phone)
|
||||
}
|
||||
|
||||
// ExistsByPhone 检查手机号是否存在
|
||||
func (r *GormUserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
||||
var count int64
|
||||
if err := r.db.WithContext(ctx).Model(&entities.User{}).Where("phone = ?", phone).Count(&count).Error; err != nil {
|
||||
r.logger.Error("检查手机号是否存在失败", zap.Error(err))
|
||||
return false, err
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// 私有辅助方法
|
||||
|
||||
// deleteCacheByID 根据ID删除缓存
|
||||
func (r *GormUserRepository) deleteCacheByID(ctx context.Context, id string) {
|
||||
cacheKey := fmt.Sprintf("user:id:%s", id)
|
||||
if err := r.cache.Delete(ctx, cacheKey); err != nil {
|
||||
r.logger.Warn("删除用户ID缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// deleteCacheByPhone 根据手机号删除缓存
|
||||
func (r *GormUserRepository) deleteCacheByPhone(ctx context.Context, phone string) {
|
||||
cacheKey := fmt.Sprintf("user:phone:%s", phone)
|
||||
if err := r.cache.Delete(ctx, cacheKey); err != nil {
|
||||
r.logger.Warn("删除用户手机号缓存失败", zap.String("cache_key", cacheKey), zap.Error(err))
|
||||
}
|
||||
}
|
||||
515
internal/infrastructure/external/notification/wechat_work_service.go
vendored
Normal file
515
internal/infrastructure/external/notification/wechat_work_service.go
vendored
Normal file
@@ -0,0 +1,515 @@
|
||||
package notification
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// WeChatWorkService 企业微信通知服务
|
||||
type WeChatWorkService struct {
|
||||
webhookURL string
|
||||
secret string
|
||||
timeout time.Duration
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// WechatWorkConfig 企业微信配置
|
||||
type WechatWorkConfig struct {
|
||||
WebhookURL string `yaml:"webhook_url"`
|
||||
Timeout time.Duration `yaml:"timeout"`
|
||||
}
|
||||
|
||||
// WechatWorkMessage 企业微信消息
|
||||
type WechatWorkMessage struct {
|
||||
MsgType string `json:"msgtype"`
|
||||
Text *WechatWorkText `json:"text,omitempty"`
|
||||
Markdown *WechatWorkMarkdown `json:"markdown,omitempty"`
|
||||
}
|
||||
|
||||
// WechatWorkText 文本消息
|
||||
type WechatWorkText struct {
|
||||
Content string `json:"content"`
|
||||
MentionedList []string `json:"mentioned_list,omitempty"`
|
||||
MentionedMobileList []string `json:"mentioned_mobile_list,omitempty"`
|
||||
}
|
||||
|
||||
// WechatWorkMarkdown Markdown消息
|
||||
type WechatWorkMarkdown struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// NewWeChatWorkService 创建企业微信通知服务
|
||||
func NewWeChatWorkService(webhookURL, secret string, logger *zap.Logger) *WeChatWorkService {
|
||||
return &WeChatWorkService{
|
||||
webhookURL: webhookURL,
|
||||
secret: secret,
|
||||
timeout: 30 * time.Second,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendTextMessage 发送文本消息
|
||||
func (s *WeChatWorkService) SendTextMessage(ctx context.Context, content string, mentionedList []string, mentionedMobileList []string) error {
|
||||
s.logger.Info("发送企业微信文本消息",
|
||||
zap.String("content", content),
|
||||
zap.Strings("mentioned_list", mentionedList),
|
||||
)
|
||||
|
||||
message := map[string]interface{}{
|
||||
"msgtype": "text",
|
||||
"text": map[string]interface{}{
|
||||
"content": content,
|
||||
"mentioned_list": mentionedList,
|
||||
"mentioned_mobile_list": mentionedMobileList,
|
||||
},
|
||||
}
|
||||
|
||||
return s.sendMessage(ctx, message)
|
||||
}
|
||||
|
||||
// SendMarkdownMessage 发送Markdown消息
|
||||
func (s *WeChatWorkService) SendMarkdownMessage(ctx context.Context, content string) error {
|
||||
s.logger.Info("发送企业微信Markdown消息", zap.String("content", content))
|
||||
|
||||
message := map[string]interface{}{
|
||||
"msgtype": "markdown",
|
||||
"markdown": map[string]interface{}{
|
||||
"content": content,
|
||||
},
|
||||
}
|
||||
|
||||
return s.sendMessage(ctx, message)
|
||||
}
|
||||
|
||||
// SendCardMessage 发送卡片消息
|
||||
func (s *WeChatWorkService) SendCardMessage(ctx context.Context, title, description, url string, btnText string) error {
|
||||
s.logger.Info("发送企业微信卡片消息",
|
||||
zap.String("title", title),
|
||||
zap.String("description", description),
|
||||
)
|
||||
|
||||
message := map[string]interface{}{
|
||||
"msgtype": "template_card",
|
||||
"template_card": map[string]interface{}{
|
||||
"card_type": "text_notice",
|
||||
"source": map[string]interface{}{
|
||||
"icon_url": "https://example.com/icon.png",
|
||||
"desc": "企业认证系统",
|
||||
},
|
||||
"main_title": map[string]interface{}{
|
||||
"title": title,
|
||||
},
|
||||
"horizontal_content_list": []map[string]interface{}{
|
||||
{
|
||||
"keyname": "描述",
|
||||
"value": description,
|
||||
},
|
||||
},
|
||||
"jump_list": []map[string]interface{}{
|
||||
{
|
||||
"type": "1",
|
||||
"title": btnText,
|
||||
"url": url,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return s.sendMessage(ctx, message)
|
||||
}
|
||||
|
||||
// SendCertificationNotification 发送认证相关通知
|
||||
func (s *WeChatWorkService) SendCertificationNotification(ctx context.Context, notificationType string, data map[string]interface{}) error {
|
||||
s.logger.Info("发送认证通知", zap.String("type", notificationType))
|
||||
|
||||
switch notificationType {
|
||||
case "new_application":
|
||||
return s.sendNewApplicationNotification(ctx, data)
|
||||
case "ocr_success":
|
||||
return s.sendOCRSuccessNotification(ctx, data)
|
||||
case "ocr_failed":
|
||||
return s.sendOCRFailedNotification(ctx, data)
|
||||
case "face_verify_success":
|
||||
return s.sendFaceVerifySuccessNotification(ctx, data)
|
||||
case "face_verify_failed":
|
||||
return s.sendFaceVerifyFailedNotification(ctx, data)
|
||||
case "admin_approved":
|
||||
return s.sendAdminApprovedNotification(ctx, data)
|
||||
case "admin_rejected":
|
||||
return s.sendAdminRejectedNotification(ctx, data)
|
||||
case "contract_signed":
|
||||
return s.sendContractSignedNotification(ctx, data)
|
||||
case "certification_completed":
|
||||
return s.sendCertificationCompletedNotification(ctx, data)
|
||||
default:
|
||||
return fmt.Errorf("不支持的通知类型: %s", notificationType)
|
||||
}
|
||||
}
|
||||
|
||||
// sendNewApplicationNotification 发送新申请通知
|
||||
func (s *WeChatWorkService) sendNewApplicationNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicantName := data["applicant_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 🆕 新的企业认证申请
|
||||
|
||||
**企业名称**: %s
|
||||
**申请人**: %s
|
||||
**申请ID**: %s
|
||||
**申请时间**: %s
|
||||
|
||||
请管理员及时审核处理。`,
|
||||
companyName,
|
||||
applicantName,
|
||||
applicationID,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendOCRSuccessNotification 发送OCR识别成功通知
|
||||
func (s *WeChatWorkService) sendOCRSuccessNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
confidence := data["confidence"].(float64)
|
||||
applicationID := data["application_id"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## ✅ OCR识别成功
|
||||
|
||||
**企业名称**: %s
|
||||
**识别置信度**: %.2f%%
|
||||
**申请ID**: %s
|
||||
**识别时间**: %s
|
||||
|
||||
营业执照信息已自动提取,请用户确认信息。`,
|
||||
companyName,
|
||||
confidence*100,
|
||||
applicationID,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendOCRFailedNotification 发送OCR识别失败通知
|
||||
func (s *WeChatWorkService) sendOCRFailedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
applicationID := data["application_id"].(string)
|
||||
errorMsg := data["error_message"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## ❌ OCR识别失败
|
||||
|
||||
**申请ID**: %s
|
||||
**错误信息**: %s
|
||||
**失败时间**: %s
|
||||
|
||||
请检查营业执照图片质量或联系技术支持。`,
|
||||
applicationID,
|
||||
errorMsg,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendFaceVerifySuccessNotification 发送人脸识别成功通知
|
||||
func (s *WeChatWorkService) sendFaceVerifySuccessNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
applicantName := data["applicant_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
confidence := data["confidence"].(float64)
|
||||
|
||||
content := fmt.Sprintf(`## ✅ 人脸识别成功
|
||||
|
||||
**申请人**: %s
|
||||
**申请ID**: %s
|
||||
**识别置信度**: %.2f%%
|
||||
**识别时间**: %s
|
||||
|
||||
身份验证通过,可以进行下一步操作。`,
|
||||
applicantName,
|
||||
applicationID,
|
||||
confidence*100,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendFaceVerifyFailedNotification 发送人脸识别失败通知
|
||||
func (s *WeChatWorkService) sendFaceVerifyFailedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
applicantName := data["applicant_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
errorMsg := data["error_message"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## ❌ 人脸识别失败
|
||||
|
||||
**申请人**: %s
|
||||
**申请ID**: %s
|
||||
**错误信息**: %s
|
||||
**失败时间**: %s
|
||||
|
||||
请重新进行人脸识别或联系技术支持。`,
|
||||
applicantName,
|
||||
applicationID,
|
||||
errorMsg,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendAdminApprovedNotification 发送管理员审核通过通知
|
||||
func (s *WeChatWorkService) sendAdminApprovedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
adminName := data["admin_name"].(string)
|
||||
comment := data["comment"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## ✅ 管理员审核通过
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**审核人**: %s
|
||||
**审核意见**: %s
|
||||
**审核时间**: %s
|
||||
|
||||
认证申请已通过审核,请用户签署电子合同。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
adminName,
|
||||
comment,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendAdminRejectedNotification 发送管理员审核拒绝通知
|
||||
func (s *WeChatWorkService) sendAdminRejectedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
adminName := data["admin_name"].(string)
|
||||
reason := data["reason"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## ❌ 管理员审核拒绝
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**审核人**: %s
|
||||
**拒绝原因**: %s
|
||||
**审核时间**: %s
|
||||
|
||||
认证申请被拒绝,请根据反馈意见重新提交。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
adminName,
|
||||
reason,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendContractSignedNotification 发送合同签署通知
|
||||
func (s *WeChatWorkService) sendContractSignedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
signerName := data["signer_name"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 📝 电子合同已签署
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**签署人**: %s
|
||||
**签署时间**: %s
|
||||
|
||||
电子合同签署完成,系统将自动生成钱包和Access Key。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
signerName,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendCertificationCompletedNotification 发送认证完成通知
|
||||
func (s *WeChatWorkService) sendCertificationCompletedNotification(ctx context.Context, data map[string]interface{}) error {
|
||||
companyName := data["company_name"].(string)
|
||||
applicationID := data["application_id"].(string)
|
||||
walletAddress := data["wallet_address"].(string)
|
||||
|
||||
content := fmt.Sprintf(`## 🎉 企业认证完成
|
||||
|
||||
**企业名称**: %s
|
||||
**申请ID**: %s
|
||||
**钱包地址**: %s
|
||||
**完成时间**: %s
|
||||
|
||||
恭喜!企业认证流程已完成,钱包和Access Key已生成。`,
|
||||
companyName,
|
||||
applicationID,
|
||||
walletAddress,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// sendMessage 发送消息到企业微信
|
||||
func (s *WeChatWorkService) sendMessage(ctx context.Context, message map[string]interface{}) error {
|
||||
// 生成签名URL
|
||||
signedURL := s.generateSignedURL()
|
||||
|
||||
// 序列化消息
|
||||
messageBytes, err := json.Marshal(message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: s.timeout,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", signedURL, bytes.NewBuffer(messageBytes))
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "tyapi-server/1.0")
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var response map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
|
||||
return fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误码
|
||||
if errCode, ok := response["errcode"].(float64); ok && errCode != 0 {
|
||||
errmsg := response["errmsg"].(string)
|
||||
return fmt.Errorf("企业微信API错误: %d - %s", int(errCode), errmsg)
|
||||
}
|
||||
|
||||
s.logger.Info("企业微信消息发送成功", zap.Any("response", response))
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateSignedURL 生成带签名的URL
|
||||
func (s *WeChatWorkService) generateSignedURL() string {
|
||||
if s.secret == "" {
|
||||
return s.webhookURL
|
||||
}
|
||||
|
||||
// 生成时间戳
|
||||
timestamp := time.Now().Unix()
|
||||
|
||||
// 生成随机字符串(这里简化处理,实际应该使用随机字符串)
|
||||
nonce := fmt.Sprintf("%d", timestamp)
|
||||
|
||||
// 构建签名字符串
|
||||
signStr := fmt.Sprintf("%d\n%s", timestamp, s.secret)
|
||||
|
||||
// 计算签名
|
||||
h := hmac.New(sha256.New, []byte(s.secret))
|
||||
h.Write([]byte(signStr))
|
||||
signature := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
// 构建签名URL
|
||||
return fmt.Sprintf("%s×tamp=%d&nonce=%s&sign=%s",
|
||||
s.webhookURL, timestamp, nonce, signature)
|
||||
}
|
||||
|
||||
// SendSystemAlert 发送系统告警
|
||||
func (s *WeChatWorkService) SendSystemAlert(ctx context.Context, level, title, message string) error {
|
||||
s.logger.Info("发送系统告警",
|
||||
zap.String("level", level),
|
||||
zap.String("title", title),
|
||||
)
|
||||
|
||||
// 根据告警级别选择图标
|
||||
var icon string
|
||||
switch level {
|
||||
case "info":
|
||||
icon = "ℹ️"
|
||||
case "warning":
|
||||
icon = "⚠️"
|
||||
case "error":
|
||||
icon = "🚨"
|
||||
case "critical":
|
||||
icon = "💥"
|
||||
default:
|
||||
icon = "📢"
|
||||
}
|
||||
|
||||
content := fmt.Sprintf(`## %s 系统告警
|
||||
|
||||
**级别**: %s
|
||||
**标题**: %s
|
||||
**消息**: %s
|
||||
**时间**: %s
|
||||
|
||||
请相关人员及时处理。`,
|
||||
icon,
|
||||
level,
|
||||
title,
|
||||
message,
|
||||
time.Now().Format("2006-01-02 15:04:05"))
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
|
||||
// SendDailyReport 发送每日报告
|
||||
func (s *WeChatWorkService) SendDailyReport(ctx context.Context, reportData map[string]interface{}) error {
|
||||
s.logger.Info("发送每日报告")
|
||||
|
||||
content := fmt.Sprintf(`## 📊 企业认证系统每日报告
|
||||
|
||||
**报告日期**: %s
|
||||
|
||||
### 统计数据
|
||||
- **新增申请**: %d
|
||||
- **OCR识别成功**: %d
|
||||
- **OCR识别失败**: %d
|
||||
- **人脸识别成功**: %d
|
||||
- **人脸识别失败**: %d
|
||||
- **审核通过**: %d
|
||||
- **审核拒绝**: %d
|
||||
- **认证完成**: %d
|
||||
|
||||
### 系统状态
|
||||
- **系统运行时间**: %s
|
||||
- **API调用次数**: %d
|
||||
- **错误次数**: %d
|
||||
|
||||
祝您工作愉快!`,
|
||||
time.Now().Format("2006-01-02"),
|
||||
reportData["new_applications"],
|
||||
reportData["ocr_success"],
|
||||
reportData["ocr_failed"],
|
||||
reportData["face_verify_success"],
|
||||
reportData["face_verify_failed"],
|
||||
reportData["admin_approved"],
|
||||
reportData["admin_rejected"],
|
||||
reportData["certification_completed"],
|
||||
reportData["uptime"],
|
||||
reportData["api_calls"],
|
||||
reportData["errors"])
|
||||
|
||||
return s.SendMarkdownMessage(ctx, content)
|
||||
}
|
||||
505
internal/infrastructure/external/ocr/baidu_ocr_service.go
vendored
Normal file
505
internal/infrastructure/external/ocr/baidu_ocr_service.go
vendored
Normal file
@@ -0,0 +1,505 @@
|
||||
package ocr
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/certification/dto/responses"
|
||||
)
|
||||
|
||||
// BaiduOCRService 百度OCR服务
|
||||
type BaiduOCRService struct {
|
||||
apiKey string
|
||||
secretKey string
|
||||
endpoint string
|
||||
timeout time.Duration
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewBaiduOCRService 创建百度OCR服务
|
||||
func NewBaiduOCRService(apiKey, secretKey string, logger *zap.Logger) *BaiduOCRService {
|
||||
return &BaiduOCRService{
|
||||
apiKey: apiKey,
|
||||
secretKey: secretKey,
|
||||
endpoint: "https://aip.baidubce.com",
|
||||
timeout: 30 * time.Second,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// RecognizeBusinessLicense 识别营业执照
|
||||
func (s *BaiduOCRService) RecognizeBusinessLicense(ctx context.Context, imageBytes []byte) (*responses.BusinessLicenseResult, error) {
|
||||
s.logger.Info("开始识别营业执照", zap.Int("image_size", len(imageBytes)))
|
||||
|
||||
// 获取访问令牌
|
||||
accessToken, err := s.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 将图片转换为base64并进行URL编码
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
|
||||
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
|
||||
|
||||
// 构建请求URL(只包含access_token)
|
||||
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/business_license?access_token=%s", s.endpoint, accessToken)
|
||||
|
||||
// 构建POST请求体
|
||||
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
|
||||
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("营业执照识别请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
|
||||
errorMsg := result["error_msg"].(string)
|
||||
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
|
||||
}
|
||||
|
||||
// 解析识别结果
|
||||
licenseResult := s.parseBusinessLicenseResult(result)
|
||||
|
||||
s.logger.Info("营业执照识别成功",
|
||||
zap.String("company_name", licenseResult.CompanyName),
|
||||
zap.String("legal_representative", licenseResult.LegalPersonName),
|
||||
zap.String("registered_capital", licenseResult.RegisteredCapital),
|
||||
)
|
||||
|
||||
return licenseResult, nil
|
||||
}
|
||||
|
||||
// RecognizeIDCard 识别身份证
|
||||
func (s *BaiduOCRService) RecognizeIDCard(ctx context.Context, imageBytes []byte, side string) (*responses.IDCardResult, error) {
|
||||
s.logger.Info("开始识别身份证", zap.String("side", side), zap.Int("image_size", len(imageBytes)))
|
||||
|
||||
// 获取访问令牌
|
||||
accessToken, err := s.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 将图片转换为base64并进行URL编码
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
|
||||
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
|
||||
|
||||
// 构建请求URL(只包含access_token)
|
||||
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/idcard?access_token=%s", s.endpoint, accessToken)
|
||||
|
||||
// 构建POST请求体
|
||||
payload := strings.NewReader(fmt.Sprintf("image=%s&side=%s", imageBase64UrlEncoded, side))
|
||||
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("身份证识别请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
|
||||
errorMsg := result["error_msg"].(string)
|
||||
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
|
||||
}
|
||||
|
||||
// 解析识别结果
|
||||
idCardResult := s.parseIDCardResult(result, side)
|
||||
|
||||
s.logger.Info("身份证识别成功",
|
||||
zap.String("name", idCardResult.Name),
|
||||
zap.String("id_number", idCardResult.IDCardNumber),
|
||||
zap.String("side", side),
|
||||
)
|
||||
|
||||
return idCardResult, nil
|
||||
}
|
||||
|
||||
// RecognizeGeneralText 通用文字识别
|
||||
func (s *BaiduOCRService) RecognizeGeneralText(ctx context.Context, imageBytes []byte) (*responses.GeneralTextResult, error) {
|
||||
s.logger.Info("开始通用文字识别", zap.Int("image_size", len(imageBytes)))
|
||||
|
||||
// 获取访问令牌
|
||||
accessToken, err := s.getAccessToken(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("获取访问令牌失败: %w", err)
|
||||
}
|
||||
|
||||
// 将图片转换为base64并进行URL编码
|
||||
imageBase64 := base64.StdEncoding.EncodeToString(imageBytes)
|
||||
imageBase64UrlEncoded := url.QueryEscape(imageBase64)
|
||||
|
||||
// 构建请求URL(只包含access_token)
|
||||
apiURL := fmt.Sprintf("%s/rest/2.0/ocr/v1/general_basic?access_token=%s", s.endpoint, accessToken)
|
||||
|
||||
// 构建POST请求体
|
||||
payload := strings.NewReader(fmt.Sprintf("image=%s", imageBase64UrlEncoded))
|
||||
resp, err := s.sendRequest(ctx, "POST", apiURL, payload)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("通用文字识别请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return nil, fmt.Errorf("解析响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error_code"].(float64); ok && errCode != 0 {
|
||||
errorMsg := result["error_msg"].(string)
|
||||
return nil, fmt.Errorf("OCR识别失败: %s", errorMsg)
|
||||
}
|
||||
|
||||
// 解析识别结果
|
||||
textResult := s.parseGeneralTextResult(result)
|
||||
|
||||
s.logger.Info("通用文字识别成功",
|
||||
zap.Int("word_count", len(textResult.Words)),
|
||||
zap.Float64("confidence", textResult.Confidence),
|
||||
)
|
||||
|
||||
return textResult, nil
|
||||
}
|
||||
|
||||
// RecognizeFromURL 从URL识别图片
|
||||
func (s *BaiduOCRService) RecognizeFromURL(ctx context.Context, imageURL string, ocrType string) (interface{}, error) {
|
||||
s.logger.Info("从URL识别图片", zap.String("url", imageURL), zap.String("type", ocrType))
|
||||
|
||||
// 下载图片
|
||||
imageBytes, err := s.downloadImage(ctx, imageURL)
|
||||
if err != nil {
|
||||
s.logger.Error("下载图片失败", zap.Error(err))
|
||||
return nil, fmt.Errorf("下载图片失败: %w", err)
|
||||
}
|
||||
|
||||
// 根据类型调用相应的识别方法
|
||||
switch ocrType {
|
||||
case "business_license":
|
||||
return s.RecognizeBusinessLicense(ctx, imageBytes)
|
||||
case "idcard_front":
|
||||
return s.RecognizeIDCard(ctx, imageBytes, "front")
|
||||
case "idcard_back":
|
||||
return s.RecognizeIDCard(ctx, imageBytes, "back")
|
||||
case "general_text":
|
||||
return s.RecognizeGeneralText(ctx, imageBytes)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的OCR类型: %s", ocrType)
|
||||
}
|
||||
}
|
||||
|
||||
// getAccessToken 获取百度API访问令牌
|
||||
func (s *BaiduOCRService) getAccessToken(ctx context.Context) (string, error) {
|
||||
// 构建获取访问令牌的URL
|
||||
tokenURL := fmt.Sprintf("%s/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
|
||||
s.endpoint, s.apiKey, s.secretKey)
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.sendRequest(ctx, "POST", tokenURL, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("获取访问令牌请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(resp, &result); err != nil {
|
||||
return "", fmt.Errorf("解析访问令牌响应失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查错误
|
||||
if errCode, ok := result["error"].(string); ok && errCode != "" {
|
||||
errorDesc := result["error_description"].(string)
|
||||
return "", fmt.Errorf("获取访问令牌失败: %s - %s", errCode, errorDesc)
|
||||
}
|
||||
|
||||
// 提取访问令牌
|
||||
accessToken, ok := result["access_token"].(string)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("响应中未找到访问令牌")
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// sendRequest 发送HTTP请求
|
||||
func (s *BaiduOCRService) sendRequest(ctx context.Context, method, url string, body io.Reader) ([]byte, error) {
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: s.timeout,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("User-Agent", "tyapi-server/1.0")
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("发送请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("请求失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应内容
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取响应内容失败: %w", err)
|
||||
}
|
||||
|
||||
return responseBody, nil
|
||||
}
|
||||
|
||||
// parseBusinessLicenseResult 解析营业执照识别结果
|
||||
func (s *BaiduOCRService) parseBusinessLicenseResult(result map[string]interface{}) *responses.BusinessLicenseResult {
|
||||
wordsResult := result["words_result"].(map[string]interface{})
|
||||
|
||||
// 提取企业信息
|
||||
companyName := ""
|
||||
if companyNameObj, ok := wordsResult["单位名称"].(map[string]interface{}); ok {
|
||||
companyName = companyNameObj["words"].(string)
|
||||
}
|
||||
|
||||
unifiedSocialCode := ""
|
||||
if socialCreditCodeObj, ok := wordsResult["社会信用代码"].(map[string]interface{}); ok {
|
||||
unifiedSocialCode = socialCreditCodeObj["words"].(string)
|
||||
}
|
||||
|
||||
legalPersonName := ""
|
||||
if legalPersonObj, ok := wordsResult["法人"].(map[string]interface{}); ok {
|
||||
legalPersonName = legalPersonObj["words"].(string)
|
||||
}
|
||||
|
||||
// 提取注册资本等其他信息
|
||||
registeredCapital := ""
|
||||
if registeredCapitalObj, ok := wordsResult["注册资本"].(map[string]interface{}); ok {
|
||||
registeredCapital = registeredCapitalObj["words"].(string)
|
||||
}
|
||||
|
||||
// 计算置信度(这里简化处理,实际应该从OCR结果中获取)
|
||||
confidence := 0.9 // 默认置信度
|
||||
|
||||
return &responses.BusinessLicenseResult{
|
||||
CompanyName: companyName,
|
||||
UnifiedSocialCode: unifiedSocialCode,
|
||||
LegalPersonName: legalPersonName,
|
||||
RegisteredCapital: registeredCapital,
|
||||
Confidence: confidence,
|
||||
}
|
||||
}
|
||||
|
||||
// parseIDCardResult 解析身份证识别结果
|
||||
func (s *BaiduOCRService) parseIDCardResult(result map[string]interface{}, side string) *responses.IDCardResult {
|
||||
wordsResult := result["words_result"].(map[string]interface{})
|
||||
|
||||
idCardResult := &responses.IDCardResult{
|
||||
Side: side,
|
||||
Confidence: s.extractConfidence(result),
|
||||
}
|
||||
|
||||
if side == "front" {
|
||||
if name, ok := wordsResult["姓名"]; ok {
|
||||
if word, ok := name.(map[string]interface{}); ok {
|
||||
idCardResult.Name = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if gender, ok := wordsResult["性别"]; ok {
|
||||
if word, ok := gender.(map[string]interface{}); ok {
|
||||
idCardResult.Gender = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if nation, ok := wordsResult["民族"]; ok {
|
||||
if word, ok := nation.(map[string]interface{}); ok {
|
||||
idCardResult.Nation = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if birthday, ok := wordsResult["出生"]; ok {
|
||||
if word, ok := birthday.(map[string]interface{}); ok {
|
||||
idCardResult.Birthday = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if address, ok := wordsResult["住址"]; ok {
|
||||
if word, ok := address.(map[string]interface{}); ok {
|
||||
idCardResult.Address = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if idNumber, ok := wordsResult["公民身份号码"]; ok {
|
||||
if word, ok := idNumber.(map[string]interface{}); ok {
|
||||
idCardResult.IDCardNumber = word["words"].(string)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if issuingAgency, ok := wordsResult["签发机关"]; ok {
|
||||
if word, ok := issuingAgency.(map[string]interface{}); ok {
|
||||
idCardResult.IssuingAgency = word["words"].(string)
|
||||
}
|
||||
}
|
||||
if validPeriod, ok := wordsResult["有效期限"]; ok {
|
||||
if word, ok := validPeriod.(map[string]interface{}); ok {
|
||||
idCardResult.ValidPeriod = word["words"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return idCardResult
|
||||
}
|
||||
|
||||
// parseGeneralTextResult 解析通用文字识别结果
|
||||
func (s *BaiduOCRService) parseGeneralTextResult(result map[string]interface{}) *responses.GeneralTextResult {
|
||||
wordsResult := result["words_result"].([]interface{})
|
||||
|
||||
textResult := &responses.GeneralTextResult{
|
||||
Confidence: s.extractConfidence(result),
|
||||
Words: make([]responses.TextLine, 0, len(wordsResult)),
|
||||
}
|
||||
|
||||
for _, word := range wordsResult {
|
||||
if wordMap, ok := word.(map[string]interface{}); ok {
|
||||
line := responses.TextLine{
|
||||
Text: wordMap["words"].(string),
|
||||
Confidence: 1.0, // 百度返回的通用文字识别没有单独置信度
|
||||
}
|
||||
textResult.Words = append(textResult.Words, line)
|
||||
}
|
||||
}
|
||||
|
||||
return textResult
|
||||
}
|
||||
|
||||
// extractConfidence 提取置信度
|
||||
func (s *BaiduOCRService) extractConfidence(result map[string]interface{}) float64 {
|
||||
if confidence, ok := result["confidence"].(float64); ok {
|
||||
return confidence
|
||||
}
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// extractWords 提取识别的文字
|
||||
func (s *BaiduOCRService) extractWords(result map[string]interface{}) []string {
|
||||
words := make([]string, 0)
|
||||
|
||||
if wordsResult, ok := result["words_result"]; ok {
|
||||
switch v := wordsResult.(type) {
|
||||
case map[string]interface{}:
|
||||
// 营业执照等结构化文档
|
||||
for _, word := range v {
|
||||
if wordMap, ok := word.(map[string]interface{}); ok {
|
||||
if wordsStr, ok := wordMap["words"].(string); ok {
|
||||
words = append(words, wordsStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
case []interface{}:
|
||||
// 通用文字识别
|
||||
for _, word := range v {
|
||||
if wordMap, ok := word.(map[string]interface{}); ok {
|
||||
if wordsStr, ok := wordMap["words"].(string); ok {
|
||||
words = append(words, wordsStr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return words
|
||||
}
|
||||
|
||||
// downloadImage 下载图片
|
||||
func (s *BaiduOCRService) downloadImage(ctx context.Context, imageURL string) ([]byte, error) {
|
||||
// 创建HTTP客户端
|
||||
client := &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", imageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建请求失败: %w", err)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("下载图片失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
// 检查响应状态
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("下载图片失败,状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// 读取响应内容
|
||||
imageBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取图片内容失败: %w", err)
|
||||
}
|
||||
|
||||
return imageBytes, nil
|
||||
}
|
||||
|
||||
// ValidateBusinessLicense 验证营业执照识别结果
|
||||
func (s *BaiduOCRService) ValidateBusinessLicense(result *responses.BusinessLicenseResult) error {
|
||||
if result.Confidence < 0.8 {
|
||||
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
|
||||
}
|
||||
if result.CompanyName == "" {
|
||||
return fmt.Errorf("未能识别公司名称")
|
||||
}
|
||||
if result.LegalPersonName == "" {
|
||||
return fmt.Errorf("未能识别法定代表人")
|
||||
}
|
||||
if result.UnifiedSocialCode == "" {
|
||||
return fmt.Errorf("未能识别统一社会信用代码")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateIDCard 验证身份证识别结果
|
||||
func (s *BaiduOCRService) ValidateIDCard(result *responses.IDCardResult) error {
|
||||
if result.Confidence < 0.8 {
|
||||
return fmt.Errorf("识别置信度过低: %.2f", result.Confidence)
|
||||
}
|
||||
if result.Side == "front" {
|
||||
if result.Name == "" {
|
||||
return fmt.Errorf("未能识别姓名")
|
||||
}
|
||||
if result.IDCardNumber == "" {
|
||||
return fmt.Errorf("未能识别身份证号码")
|
||||
}
|
||||
} else {
|
||||
if result.IssuingAgency == "" {
|
||||
return fmt.Errorf("未能识别签发机关")
|
||||
}
|
||||
if result.ValidPeriod == "" {
|
||||
return fmt.Errorf("未能识别有效期限")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
123
internal/infrastructure/external/sms/sms_service.go
vendored
Normal file
123
internal/infrastructure/external/sms/sms_service.go
vendored
Normal file
@@ -0,0 +1,123 @@
|
||||
package sms
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
|
||||
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/config"
|
||||
)
|
||||
|
||||
// AliSMSService 阿里云短信服务
|
||||
type AliSMSService struct {
|
||||
client *dysmsapi.Client
|
||||
config config.SMSConfig
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAliSMSService 创建阿里云短信服务
|
||||
func NewAliSMSService(cfg config.SMSConfig, logger *zap.Logger) (*AliSMSService, error) {
|
||||
client, err := dysmsapi.NewClientWithAccessKey("cn-hangzhou", cfg.AccessKeyID, cfg.AccessKeySecret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建短信客户端失败: %w", err)
|
||||
}
|
||||
return &AliSMSService{
|
||||
client: client,
|
||||
config: cfg,
|
||||
logger: logger,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendVerificationCode 发送验证码
|
||||
func (s *AliSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
|
||||
request := dysmsapi.CreateSendSmsRequest()
|
||||
request.Scheme = "https"
|
||||
request.PhoneNumbers = phone
|
||||
request.SignName = s.config.SignName
|
||||
request.TemplateCode = s.config.TemplateCode
|
||||
request.TemplateParam = fmt.Sprintf(`{"code":"%s"}`, code)
|
||||
|
||||
response, err := s.client.SendSms(request)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to send SMS",
|
||||
zap.String("phone", phone),
|
||||
zap.Error(err))
|
||||
return fmt.Errorf("短信发送失败: %w", err)
|
||||
}
|
||||
|
||||
if response.Code != "OK" {
|
||||
s.logger.Error("SMS send failed",
|
||||
zap.String("phone", phone),
|
||||
zap.String("code", response.Code),
|
||||
zap.String("message", response.Message))
|
||||
return fmt.Errorf("短信发送失败: %s - %s", response.Code, response.Message)
|
||||
}
|
||||
|
||||
s.logger.Info("SMS sent successfully",
|
||||
zap.String("phone", phone),
|
||||
zap.String("bizId", response.BizId))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode 生成验证码
|
||||
func (s *AliSMSService) GenerateCode(length int) string {
|
||||
if length <= 0 {
|
||||
length = 6
|
||||
}
|
||||
|
||||
// 生成指定长度的数字验证码
|
||||
max := big.NewInt(int64(pow10(length)))
|
||||
n, _ := rand.Int(rand.Reader, max)
|
||||
|
||||
// 格式化为指定长度,不足时前面补0
|
||||
format := fmt.Sprintf("%%0%dd", length)
|
||||
return fmt.Sprintf(format, n.Int64())
|
||||
}
|
||||
|
||||
// pow10 计算10的n次方
|
||||
func pow10(n int) int {
|
||||
result := 1
|
||||
for i := 0; i < n; i++ {
|
||||
result *= 10
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MockSMSService 模拟短信服务(用于开发和测试)
|
||||
type MockSMSService struct {
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewMockSMSService 创建模拟短信服务
|
||||
func NewMockSMSService(logger *zap.Logger) *MockSMSService {
|
||||
return &MockSMSService{
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// SendVerificationCode 模拟发送验证码
|
||||
func (s *MockSMSService) SendVerificationCode(ctx context.Context, phone string, code string) error {
|
||||
s.logger.Info("Mock SMS sent",
|
||||
zap.String("phone", phone),
|
||||
zap.String("code", code))
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateCode 生成验证码
|
||||
func (s *MockSMSService) GenerateCode(length int) string {
|
||||
if length <= 0 {
|
||||
length = 6
|
||||
}
|
||||
|
||||
// 开发环境使用固定验证码便于测试
|
||||
result := ""
|
||||
for i := 0; i < length; i++ {
|
||||
result += "1"
|
||||
}
|
||||
return result
|
||||
}
|
||||
281
internal/infrastructure/external/storage/qiniu_storage_service.go
vendored
Normal file
281
internal/infrastructure/external/storage/qiniu_storage_service.go
vendored
Normal file
@@ -0,0 +1,281 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha1"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/qiniu/go-sdk/v7/auth/qbox"
|
||||
"github.com/qiniu/go-sdk/v7/storage"
|
||||
"go.uber.org/zap"
|
||||
|
||||
sharedStorage "tyapi-server/internal/shared/storage"
|
||||
)
|
||||
|
||||
// QiNiuStorageService 七牛云存储服务
|
||||
type QiNiuStorageService struct {
|
||||
accessKey string
|
||||
secretKey string
|
||||
bucket string
|
||||
domain string
|
||||
logger *zap.Logger
|
||||
mac *qbox.Mac
|
||||
bucketManager *storage.BucketManager
|
||||
}
|
||||
|
||||
// QiNiuStorageConfig 七牛云存储配置
|
||||
type QiNiuStorageConfig struct {
|
||||
AccessKey string `yaml:"access_key"`
|
||||
SecretKey string `yaml:"secret_key"`
|
||||
Bucket string `yaml:"bucket"`
|
||||
Domain string `yaml:"domain"`
|
||||
}
|
||||
|
||||
// NewQiNiuStorageService 创建七牛云存储服务
|
||||
func NewQiNiuStorageService(accessKey, secretKey, bucket, domain string, logger *zap.Logger) *QiNiuStorageService {
|
||||
mac := qbox.NewMac(accessKey, secretKey)
|
||||
// 使用默认配置,不需要指定region
|
||||
cfg := storage.Config{}
|
||||
bucketManager := storage.NewBucketManager(mac, &cfg)
|
||||
|
||||
return &QiNiuStorageService{
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
bucket: bucket,
|
||||
domain: domain,
|
||||
logger: logger,
|
||||
mac: mac,
|
||||
bucketManager: bucketManager,
|
||||
}
|
||||
}
|
||||
|
||||
// UploadFile 上传文件到七牛云
|
||||
func (s *QiNiuStorageService) UploadFile(ctx context.Context, fileBytes []byte, fileName string) (*sharedStorage.UploadResult, error) {
|
||||
s.logger.Info("开始上传文件到七牛云",
|
||||
zap.String("file_name", fileName),
|
||||
zap.Int("file_size", len(fileBytes)),
|
||||
)
|
||||
|
||||
// 生成唯一的文件key
|
||||
key := s.generateFileKey(fileName)
|
||||
|
||||
// 创建上传凭证
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: s.bucket,
|
||||
}
|
||||
upToken := putPolicy.UploadToken(s.mac)
|
||||
|
||||
// 配置上传参数
|
||||
cfg := storage.Config{}
|
||||
formUploader := storage.NewFormUploader(&cfg)
|
||||
ret := storage.PutRet{}
|
||||
|
||||
// 上传文件
|
||||
err := formUploader.Put(ctx, &ret, upToken, key, strings.NewReader(string(fileBytes)), int64(len(fileBytes)), &storage.PutExtra{})
|
||||
if err != nil {
|
||||
s.logger.Error("文件上传失败",
|
||||
zap.String("file_name", fileName),
|
||||
zap.String("key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("文件上传失败: %w", err)
|
||||
}
|
||||
|
||||
// 构建文件URL
|
||||
fileURL := s.GetFileURL(ctx, key)
|
||||
|
||||
s.logger.Info("文件上传成功",
|
||||
zap.String("file_name", fileName),
|
||||
zap.String("key", key),
|
||||
zap.String("url", fileURL),
|
||||
)
|
||||
|
||||
return &sharedStorage.UploadResult{
|
||||
Key: key,
|
||||
URL: fileURL,
|
||||
MimeType: s.getMimeType(fileName),
|
||||
Size: int64(len(fileBytes)),
|
||||
Hash: ret.Hash,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GenerateUploadToken 生成上传凭证
|
||||
func (s *QiNiuStorageService) GenerateUploadToken(ctx context.Context, key string) (string, error) {
|
||||
putPolicy := storage.PutPolicy{
|
||||
Scope: s.bucket,
|
||||
// 设置过期时间(1小时)
|
||||
Expires: uint64(time.Now().Add(time.Hour).Unix()),
|
||||
}
|
||||
|
||||
token := putPolicy.UploadToken(s.mac)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// GetFileURL 获取文件访问URL
|
||||
func (s *QiNiuStorageService) GetFileURL(ctx context.Context, key string) string {
|
||||
// 如果是私有空间,需要生成带签名的URL
|
||||
if s.isPrivateBucket() {
|
||||
deadline := time.Now().Add(time.Hour).Unix() // 1小时过期
|
||||
privateAccessURL := storage.MakePrivateURL(s.mac, s.domain, key, deadline)
|
||||
return privateAccessURL
|
||||
}
|
||||
|
||||
// 公开空间直接返回URL
|
||||
return fmt.Sprintf("%s/%s", s.domain, key)
|
||||
}
|
||||
|
||||
// GetPrivateFileURL 获取私有文件访问URL
|
||||
func (s *QiNiuStorageService) GetPrivateFileURL(ctx context.Context, key string, expires int64) (string, error) {
|
||||
baseURL := s.GetFileURL(ctx, key)
|
||||
|
||||
// TODO: 实际集成七牛云SDK生成私有URL
|
||||
s.logger.Info("生成七牛云私有文件URL",
|
||||
zap.String("key", key),
|
||||
zap.Int64("expires", expires),
|
||||
)
|
||||
|
||||
// 模拟返回私有URL
|
||||
return fmt.Sprintf("%s?token=mock_private_token&expires=%d", baseURL, expires), nil
|
||||
}
|
||||
|
||||
// DeleteFile 删除文件
|
||||
func (s *QiNiuStorageService) DeleteFile(ctx context.Context, key string) error {
|
||||
s.logger.Info("删除七牛云文件", zap.String("key", key))
|
||||
|
||||
err := s.bucketManager.Delete(s.bucket, key)
|
||||
if err != nil {
|
||||
s.logger.Error("删除文件失败",
|
||||
zap.String("key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("删除文件失败: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info("文件删除成功", zap.String("key", key))
|
||||
return nil
|
||||
}
|
||||
|
||||
// FileExists 检查文件是否存在
|
||||
func (s *QiNiuStorageService) FileExists(ctx context.Context, key string) (bool, error) {
|
||||
// TODO: 实际集成七牛云SDK检查文件存在性
|
||||
s.logger.Info("检查七牛云文件存在性", zap.String("key", key))
|
||||
|
||||
// 模拟文件存在
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetFileInfo 获取文件信息
|
||||
func (s *QiNiuStorageService) GetFileInfo(ctx context.Context, key string) (*sharedStorage.FileInfo, error) {
|
||||
fileInfo, err := s.bucketManager.Stat(s.bucket, key)
|
||||
if err != nil {
|
||||
s.logger.Error("获取文件信息失败",
|
||||
zap.String("key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("获取文件信息失败: %w", err)
|
||||
}
|
||||
|
||||
return &sharedStorage.FileInfo{
|
||||
Key: key,
|
||||
Size: fileInfo.Fsize,
|
||||
MimeType: fileInfo.MimeType,
|
||||
Hash: fileInfo.Hash,
|
||||
PutTime: fileInfo.PutTime,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListFiles 列出文件
|
||||
func (s *QiNiuStorageService) ListFiles(ctx context.Context, prefix string, limit int) ([]*sharedStorage.FileInfo, error) {
|
||||
entries, _, _, hasMore, err := s.bucketManager.ListFiles(s.bucket, prefix, "", "", limit)
|
||||
if err != nil {
|
||||
s.logger.Error("列出文件失败",
|
||||
zap.String("prefix", prefix),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("列出文件失败: %w", err)
|
||||
}
|
||||
|
||||
var fileInfos []*sharedStorage.FileInfo
|
||||
for _, entry := range entries {
|
||||
fileInfo := &sharedStorage.FileInfo{
|
||||
Key: entry.Key,
|
||||
Size: entry.Fsize,
|
||||
MimeType: entry.MimeType,
|
||||
Hash: entry.Hash,
|
||||
PutTime: entry.PutTime,
|
||||
}
|
||||
fileInfos = append(fileInfos, fileInfo)
|
||||
}
|
||||
|
||||
_ = hasMore // 暂时忽略hasMore
|
||||
return fileInfos, nil
|
||||
}
|
||||
|
||||
// generateFileKey 生成文件key
|
||||
func (s *QiNiuStorageService) generateFileKey(fileName string) string {
|
||||
// 生成时间戳
|
||||
timestamp := time.Now().Format("20060102_150405")
|
||||
// 生成随机字符串
|
||||
randomStr := fmt.Sprintf("%d", time.Now().UnixNano()%1000000)
|
||||
// 获取文件扩展名
|
||||
ext := filepath.Ext(fileName)
|
||||
// 构建key: 日期/时间戳_随机数.扩展名
|
||||
key := fmt.Sprintf("certification/%s/%s_%s%s",
|
||||
time.Now().Format("20060102"), timestamp, randomStr, ext)
|
||||
|
||||
return key
|
||||
}
|
||||
|
||||
// getMimeType 根据文件名获取MIME类型
|
||||
func (s *QiNiuStorageService) getMimeType(fileName string) string {
|
||||
ext := strings.ToLower(filepath.Ext(fileName))
|
||||
switch ext {
|
||||
case ".jpg", ".jpeg":
|
||||
return "image/jpeg"
|
||||
case ".png":
|
||||
return "image/png"
|
||||
case ".pdf":
|
||||
return "application/pdf"
|
||||
case ".gif":
|
||||
return "image/gif"
|
||||
case ".bmp":
|
||||
return "image/bmp"
|
||||
case ".webp":
|
||||
return "image/webp"
|
||||
default:
|
||||
return "application/octet-stream"
|
||||
}
|
||||
}
|
||||
|
||||
// isPrivateBucket 判断是否为私有空间
|
||||
func (s *QiNiuStorageService) isPrivateBucket() bool {
|
||||
// 这里可以根据配置或域名特征判断
|
||||
// 私有空间的域名通常包含特定标识
|
||||
return strings.Contains(s.domain, "private") ||
|
||||
strings.Contains(s.domain, "auth") ||
|
||||
strings.Contains(s.domain, "secure")
|
||||
}
|
||||
|
||||
// generateSignature 生成签名(用于私有空间访问)
|
||||
func (s *QiNiuStorageService) generateSignature(data string) string {
|
||||
h := hmac.New(sha1.New, []byte(s.secretKey))
|
||||
h.Write([]byte(data))
|
||||
return base64.URLEncoding.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// UploadFromReader 从Reader上传文件
|
||||
func (s *QiNiuStorageService) UploadFromReader(ctx context.Context, reader io.Reader, fileName string, fileSize int64) (*sharedStorage.UploadResult, error) {
|
||||
// 读取文件内容
|
||||
fileBytes, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取文件失败: %w", err)
|
||||
}
|
||||
|
||||
return s.UploadFile(ctx, fileBytes, fileName)
|
||||
}
|
||||
280
internal/infrastructure/http/handlers/admin_handler.go
Normal file
280
internal/infrastructure/http/handlers/admin_handler.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/admin"
|
||||
"tyapi-server/internal/application/admin/dto/commands"
|
||||
"tyapi-server/internal/application/admin/dto/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// AdminHandler 管理员HTTP处理器
|
||||
type AdminHandler struct {
|
||||
appService admin.AdminApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAdminHandler 创建管理员HTTP处理器
|
||||
func NewAdminHandler(
|
||||
appService admin.AdminApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
) *AdminHandler {
|
||||
return &AdminHandler{
|
||||
appService: appService,
|
||||
responseBuilder: responseBuilder,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Login 管理员登录
|
||||
// @Summary 管理员登录
|
||||
// @Description 使用用户名和密码进行管理员登录,返回JWT令牌
|
||||
// @Tags 管理员认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.AdminLoginCommand true "管理员登录请求"
|
||||
// @Success 200 {object} responses.AdminLoginResponse "登录成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "用户名或密码错误"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/auth/login [post]
|
||||
func (h *AdminHandler) Login(c *gin.Context) {
|
||||
var cmd commands.AdminLoginCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.Login(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("管理员登录失败", zap.Error(err))
|
||||
h.responseBuilder.Unauthorized(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "登录成功")
|
||||
}
|
||||
|
||||
// CreateAdmin 创建管理员
|
||||
// @Summary 创建管理员
|
||||
// @Description 创建新的管理员账户,需要超级管理员权限
|
||||
// @Tags 管理员管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateAdminCommand true "创建管理员请求"
|
||||
// @Success 201 {object} map[string]interface{} "管理员创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin [post]
|
||||
func (h *AdminHandler) CreateAdmin(c *gin.Context) {
|
||||
var cmd commands.CreateAdminCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
cmd.OperatorID = h.getCurrentAdminID(c)
|
||||
|
||||
if err := h.appService.CreateAdmin(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("创建管理员失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, nil, "管理员创建成功")
|
||||
}
|
||||
|
||||
// UpdateAdmin 更新管理员
|
||||
// @Summary 更新管理员信息
|
||||
// @Description 更新指定管理员的基本信息
|
||||
// @Tags 管理员管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "管理员ID"
|
||||
// @Param request body commands.UpdateAdminCommand true "更新管理员请求"
|
||||
// @Success 200 {object} map[string]interface{} "管理员更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 404 {object} map[string]interface{} "管理员不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/{id} [put]
|
||||
func (h *AdminHandler) UpdateAdmin(c *gin.Context) {
|
||||
var cmd commands.UpdateAdminCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
cmd.AdminID = c.Param("id")
|
||||
cmd.OperatorID = h.getCurrentAdminID(c)
|
||||
|
||||
if err := h.appService.UpdateAdmin(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("更新管理员失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "管理员更新成功")
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
// @Summary 修改管理员密码
|
||||
// @Description 修改当前登录管理员的密码
|
||||
// @Tags 管理员管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.ChangeAdminPasswordCommand true "修改密码请求"
|
||||
// @Success 200 {object} map[string]interface{} "密码修改成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/change-password [post]
|
||||
func (h *AdminHandler) ChangePassword(c *gin.Context) {
|
||||
var cmd commands.ChangeAdminPasswordCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
cmd.AdminID = h.getCurrentAdminID(c)
|
||||
|
||||
if err := h.appService.ChangePassword(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("修改密码失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "密码修改成功")
|
||||
}
|
||||
|
||||
// ListAdmins 获取管理员列表
|
||||
// @Summary 获取管理员列表
|
||||
// @Description 分页获取管理员列表,支持搜索和筛选
|
||||
// @Tags 管理员管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param page query int false "页码" default(1)
|
||||
// @Param size query int false "每页数量" default(10)
|
||||
// @Param keyword query string false "搜索关键词"
|
||||
// @Param status query string false "状态筛选"
|
||||
// @Success 200 {object} responses.AdminListResponse "获取管理员列表成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin [get]
|
||||
func (h *AdminHandler) ListAdmins(c *gin.Context) {
|
||||
var query queries.ListAdminsQuery
|
||||
if err := c.ShouldBindQuery(&query); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.ListAdmins(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取管理员列表失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取管理员列表失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, response, "获取管理员列表成功")
|
||||
}
|
||||
|
||||
// GetAdminByID 根据ID获取管理员
|
||||
// @Summary 获取管理员详情
|
||||
// @Description 根据管理员ID获取详细信息
|
||||
// @Tags 管理员管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "管理员ID"
|
||||
// @Success 200 {object} responses.AdminInfoResponse "获取管理员详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "管理员不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/{id} [get]
|
||||
func (h *AdminHandler) GetAdminByID(c *gin.Context) {
|
||||
var query queries.GetAdminInfoQuery
|
||||
if err := c.ShouldBindUri(&query); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
admin, err := h.appService.GetAdminByID(c.Request.Context(), &query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取管理员详情失败", zap.Error(err))
|
||||
h.responseBuilder.NotFound(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, admin, "获取管理员详情成功")
|
||||
}
|
||||
|
||||
// DeleteAdmin 删除管理员
|
||||
// @Summary 删除管理员
|
||||
// @Description 删除指定的管理员账户
|
||||
// @Tags 管理员管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param id path string true "管理员ID"
|
||||
// @Success 200 {object} map[string]interface{} "管理员删除成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 403 {object} map[string]interface{} "权限不足"
|
||||
// @Failure 404 {object} map[string]interface{} "管理员不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/{id} [delete]
|
||||
func (h *AdminHandler) DeleteAdmin(c *gin.Context) {
|
||||
var cmd commands.DeleteAdminCommand
|
||||
cmd.AdminID = c.Param("id")
|
||||
cmd.OperatorID = h.getCurrentAdminID(c)
|
||||
|
||||
if err := h.appService.DeleteAdmin(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("删除管理员失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "管理员删除成功")
|
||||
}
|
||||
|
||||
// GetAdminStats 获取管理员统计信息
|
||||
// @Summary 获取管理员统计信息
|
||||
// @Description 获取管理员相关的统计数据
|
||||
// @Tags 管理员管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.AdminStatsResponse "获取统计信息成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/admin/stats [get]
|
||||
func (h *AdminHandler) GetAdminStats(c *gin.Context) {
|
||||
stats, err := h.appService.GetAdminStats(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取管理员统计失败", zap.Error(err))
|
||||
h.responseBuilder.InternalError(c, "获取统计信息失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, stats, "获取统计信息成功")
|
||||
}
|
||||
|
||||
// getCurrentAdminID 获取当前管理员ID
|
||||
func (h *AdminHandler) getCurrentAdminID(c *gin.Context) string {
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
472
internal/infrastructure/http/handlers/certification_handler.go
Normal file
472
internal/infrastructure/http/handlers/certification_handler.go
Normal file
@@ -0,0 +1,472 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/certification"
|
||||
"tyapi-server/internal/application/certification/dto/commands"
|
||||
"tyapi-server/internal/application/certification/dto/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// CertificationHandler 认证处理器
|
||||
type CertificationHandler struct {
|
||||
appService certification.CertificationApplicationService
|
||||
response interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCertificationHandler 创建认证处理器
|
||||
func NewCertificationHandler(
|
||||
appService certification.CertificationApplicationService,
|
||||
response interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
) *CertificationHandler {
|
||||
return &CertificationHandler{
|
||||
appService: appService,
|
||||
response: response,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateCertification 创建认证申请
|
||||
// @Summary 创建认证申请
|
||||
// @Description 为用户创建新的企业认证申请
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.CertificationResponse "认证申请创建成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification [post]
|
||||
func (h *CertificationHandler) CreateCertification(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
cmd := &commands.CreateCertificationCommand{UserID: userID}
|
||||
result, err := h.appService.CreateCertification(c.Request.Context(), cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("创建认证申请失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.InternalError(c, "创建认证申请失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "认证申请创建成功")
|
||||
}
|
||||
|
||||
// UploadBusinessLicense 上传营业执照并同步OCR识别
|
||||
// @Summary 上传营业执照并同步OCR识别
|
||||
// @Description 上传营业执照文件,立即进行OCR识别并返回结果
|
||||
// @Tags 企业认证
|
||||
// @Accept multipart/form-data
|
||||
// @Produce json
|
||||
// @Param file formData file true "营业执照文件"
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.UploadLicenseResponse "上传成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未授权"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/upload-license [post]
|
||||
func (h *CertificationHandler) UploadBusinessLicense(c *gin.Context) {
|
||||
// 获取当前用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取上传的文件
|
||||
file, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
h.response.BadRequest(c, "文件上传失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 读取文件内容
|
||||
openedFile, err := file.Open()
|
||||
if err != nil {
|
||||
h.response.BadRequest(c, "无法读取文件")
|
||||
return
|
||||
}
|
||||
defer openedFile.Close()
|
||||
|
||||
fileBytes, err := io.ReadAll(openedFile)
|
||||
if err != nil {
|
||||
h.response.BadRequest(c, "文件读取失败")
|
||||
return
|
||||
}
|
||||
|
||||
// 调用应用服务
|
||||
response, err := h.appService.UploadBusinessLicense(c.Request.Context(), userID.(string), fileBytes, file.Filename)
|
||||
if err != nil {
|
||||
h.logger.Error("营业执照上传失败", zap.Error(err))
|
||||
h.response.InternalError(c, "营业执照上传失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, response, "营业执照上传成功")
|
||||
}
|
||||
|
||||
// GetCertificationStatus 获取认证状态
|
||||
// @Summary 获取认证状态
|
||||
// @Description 获取当前用户的认证申请状态
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.CertificationResponse "获取认证状态成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/status [get]
|
||||
func (h *CertificationHandler) GetCertificationStatus(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
query := &queries.GetCertificationStatusQuery{UserID: userID}
|
||||
result, err := h.appService.GetCertificationStatus(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证状态失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取认证状态成功")
|
||||
}
|
||||
|
||||
// GetProgressStats 获取进度统计
|
||||
// @Summary 获取进度统计
|
||||
// @Description 获取认证申请的进度统计数据
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} map[string]interface{} "获取进度统计成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/stats [get]
|
||||
func (h *CertificationHandler) GetProgressStats(c *gin.Context) {
|
||||
// 这里应该实现获取进度统计的逻辑
|
||||
// 暂时返回空数据
|
||||
h.response.Success(c, map[string]interface{}{
|
||||
"total_applications": 0,
|
||||
"pending": 0,
|
||||
"in_progress": 0,
|
||||
"completed": 0,
|
||||
"rejected": 0,
|
||||
}, "获取进度统计成功")
|
||||
}
|
||||
|
||||
// GetCertificationProgress 获取认证进度
|
||||
// @Summary 获取认证进度
|
||||
// @Description 获取当前用户的认证申请详细进度信息
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} map[string]interface{} "获取认证进度成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证申请不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/progress [get]
|
||||
func (h *CertificationHandler) GetCertificationProgress(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetCertificationProgress(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证进度失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取认证进度成功")
|
||||
}
|
||||
|
||||
// SubmitEnterpriseInfo 提交企业信息
|
||||
// @Summary 提交企业信息
|
||||
// @Description 提交企业基本信息,包括企业名称、统一社会信用代码、法定代表人信息等
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.SubmitEnterpriseInfoCommand true "企业信息"
|
||||
// @Success 200 {object} responses.CertificationResponse "企业信息提交成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/enterprise-info [post]
|
||||
func (h *CertificationHandler) SubmitEnterpriseInfo(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.SubmitEnterpriseInfoCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.logger.Error("参数绑定失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "请求参数格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.SubmitEnterpriseInfo(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("提交企业信息失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "企业信息提交成功")
|
||||
}
|
||||
|
||||
// InitiateFaceVerify 发起人脸验证
|
||||
// @Summary 发起人脸验证
|
||||
// @Description 发起企业法人人脸验证流程
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.InitiateFaceVerifyCommand true "人脸验证请求"
|
||||
// @Success 200 {object} responses.FaceVerifyResponse "人脸验证发起成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/face-verify [post]
|
||||
func (h *CertificationHandler) InitiateFaceVerify(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.InitiateFaceVerifyCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.logger.Error("参数绑定失败", zap.Error(err))
|
||||
h.response.BadRequest(c, "请求参数格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 根据用户ID获取认证申请
|
||||
query := &queries.GetCertificationStatusQuery{UserID: userID}
|
||||
certification, err := h.appService.GetCertificationStatus(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证申请失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 如果用户没有认证申请,返回错误
|
||||
if certification.ID == "" {
|
||||
h.response.BadRequest(c, "用户尚未创建认证申请")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.CertificationID = certification.ID
|
||||
|
||||
result, err := h.appService.InitiateFaceVerify(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("发起人脸验证失败",
|
||||
zap.String("certification_id", certification.ID),
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "人脸验证发起成功")
|
||||
}
|
||||
|
||||
// ApplyContract 申请合同
|
||||
// @Summary 申请合同
|
||||
// @Description 申请企业认证合同
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.CertificationResponse "合同申请成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/contract [post]
|
||||
func (h *CertificationHandler) ApplyContract(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.ApplyContract(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("申请合同失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "合同申请成功")
|
||||
}
|
||||
|
||||
// GetCertificationDetails 获取认证详情
|
||||
// @Summary 获取认证详情
|
||||
// @Description 获取当前用户的认证申请详细信息
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.CertificationResponse "获取认证详情成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "认证申请不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/details [get]
|
||||
func (h *CertificationHandler) GetCertificationDetails(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
query := &queries.GetCertificationDetailsQuery{
|
||||
UserID: userID,
|
||||
}
|
||||
|
||||
result, err := h.appService.GetCertificationDetails(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取认证详情失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取认证详情成功")
|
||||
}
|
||||
|
||||
// RetryStep 重试步骤
|
||||
// @Summary 重试认证步骤
|
||||
// @Description 重新执行指定的认证步骤
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param step path string true "步骤名称"
|
||||
// @Success 200 {object} map[string]interface{} "步骤重试成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/retry/{step} [post]
|
||||
func (h *CertificationHandler) RetryStep(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
step := c.Param("step")
|
||||
if step == "" {
|
||||
h.response.BadRequest(c, "步骤名称不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
var result interface{}
|
||||
var err error
|
||||
|
||||
switch step {
|
||||
case "face_verify":
|
||||
result, err = h.appService.RetryFaceVerify(c.Request.Context(), userID)
|
||||
case "contract_sign":
|
||||
result, err = h.appService.RetryContractSign(c.Request.Context(), userID)
|
||||
default:
|
||||
h.response.BadRequest(c, "不支持的步骤类型")
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
h.logger.Error("重试认证步骤失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("step", step),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "认证步骤重试成功")
|
||||
}
|
||||
|
||||
// GetLicenseOCRResult 获取营业执照OCR识别结果
|
||||
// @Summary 获取营业执照OCR识别结果
|
||||
// @Description 根据上传记录ID获取OCR识别结果
|
||||
// @Tags 企业认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param record_id path string true "上传记录ID"
|
||||
// @Success 200 {object} responses.UploadLicenseResponse "获取OCR结果成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "记录不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/certification/license/{record_id}/ocr-result [get]
|
||||
func (h *CertificationHandler) GetLicenseOCRResult(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
recordID := c.Param("record_id")
|
||||
if recordID == "" {
|
||||
h.response.BadRequest(c, "上传记录ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetLicenseOCRResult(c.Request.Context(), recordID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取OCR结果失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.String("record_id", recordID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, result, "获取OCR结果成功")
|
||||
}
|
||||
429
internal/infrastructure/http/handlers/finance_handler.go
Normal file
429
internal/infrastructure/http/handlers/finance_handler.go
Normal file
@@ -0,0 +1,429 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/finance"
|
||||
"tyapi-server/internal/application/finance/dto/commands"
|
||||
"tyapi-server/internal/application/finance/dto/queries"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
)
|
||||
|
||||
// FinanceHandler 财务HTTP处理器
|
||||
type FinanceHandler struct {
|
||||
appService finance.FinanceApplicationService
|
||||
responseBuilder interfaces.ResponseBuilder
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewFinanceHandler 创建财务HTTP处理器
|
||||
func NewFinanceHandler(
|
||||
appService finance.FinanceApplicationService,
|
||||
responseBuilder interfaces.ResponseBuilder,
|
||||
logger *zap.Logger,
|
||||
) *FinanceHandler {
|
||||
return &FinanceHandler{
|
||||
appService: appService,
|
||||
responseBuilder: responseBuilder,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWallet 创建钱包
|
||||
// @Summary 创建钱包
|
||||
// @Description 为用户创建新的钱包账户
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.CreateWalletCommand true "创建钱包请求"
|
||||
// @Success 201 {object} responses.WalletResponse "钱包创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 409 {object} map[string]interface{} "钱包已存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet [post]
|
||||
func (h *FinanceHandler) CreateWallet(c *gin.Context) {
|
||||
var cmd commands.CreateWalletCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
response, err := h.appService.CreateWallet(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("创建钱包失败", zap.Error(err))
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, response, "钱包创建成功")
|
||||
}
|
||||
|
||||
// GetWallet 获取钱包信息
|
||||
// @Summary 获取钱包信息
|
||||
// @Description 获取当前用户的钱包详细信息
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.WalletResponse "获取钱包信息成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "钱包不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet [get]
|
||||
func (h *FinanceHandler) GetWallet(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
query := &queries.GetWalletInfoQuery{UserID: userID}
|
||||
result, err := h.appService.GetWallet(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取钱包信息失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取钱包信息成功")
|
||||
}
|
||||
|
||||
// UpdateWallet 更新钱包
|
||||
// @Summary 更新钱包信息
|
||||
// @Description 更新当前用户的钱包基本信息
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.UpdateWalletCommand true "更新钱包请求"
|
||||
// @Success 200 {object} map[string]interface{} "钱包更新成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet [put]
|
||||
func (h *FinanceHandler) UpdateWallet(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.UpdateWalletCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
err := h.appService.UpdateWallet(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("更新钱包失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "钱包更新成功")
|
||||
}
|
||||
|
||||
// Recharge 充值
|
||||
// @Summary 钱包充值
|
||||
// @Description 为钱包进行充值操作
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.RechargeWalletCommand true "充值请求"
|
||||
// @Success 200 {object} responses.TransactionResponse "充值成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/recharge [post]
|
||||
func (h *FinanceHandler) Recharge(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.RechargeWalletCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.Recharge(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("充值失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "充值成功")
|
||||
}
|
||||
|
||||
// Withdraw 提现
|
||||
// @Summary 钱包提现
|
||||
// @Description 从钱包进行提现操作
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.WithdrawWalletCommand true "提现请求"
|
||||
// @Success 200 {object} responses.TransactionResponse "提现申请已提交"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/withdraw [post]
|
||||
func (h *FinanceHandler) Withdraw(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.WithdrawWalletCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.Withdraw(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("提现失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "提现申请已提交")
|
||||
}
|
||||
|
||||
// WalletTransaction 钱包交易
|
||||
// @Summary 钱包交易
|
||||
// @Description 执行钱包内部交易操作
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.WalletTransactionCommand true "交易请求"
|
||||
// @Success 200 {object} responses.TransactionResponse "交易成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/transaction [post]
|
||||
func (h *FinanceHandler) WalletTransaction(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.WalletTransactionCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.WalletTransaction(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("钱包交易失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "交易成功")
|
||||
}
|
||||
|
||||
// GetWalletStats 获取钱包统计
|
||||
// @Summary 获取钱包统计
|
||||
// @Description 获取钱包相关的统计数据
|
||||
// @Tags 钱包管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.WalletStatsResponse "获取钱包统计成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/wallet/stats [get]
|
||||
func (h *FinanceHandler) GetWalletStats(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.appService.GetWalletStats(c.Request.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("获取钱包统计失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.InternalError(c, "获取钱包统计失败")
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取钱包统计成功")
|
||||
}
|
||||
|
||||
// CreateUserSecrets 创建用户密钥
|
||||
// @Summary 创建用户密钥
|
||||
// @Description 为用户创建API访问密钥
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.CreateUserSecretsCommand true "创建密钥请求"
|
||||
// @Success 201 {object} responses.UserSecretsResponse "用户密钥创建成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 409 {object} map[string]interface{} "密钥已存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets [post]
|
||||
func (h *FinanceHandler) CreateUserSecrets(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.CreateUserSecretsCommand
|
||||
if err := c.ShouldBindJSON(&cmd); err != nil {
|
||||
h.responseBuilder.BadRequest(c, "请求参数错误")
|
||||
return
|
||||
}
|
||||
|
||||
cmd.UserID = userID
|
||||
|
||||
result, err := h.appService.CreateUserSecrets(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("创建用户密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Created(c, result, "用户密钥创建成功")
|
||||
}
|
||||
|
||||
// GetUserSecrets 获取用户密钥
|
||||
// @Summary 获取用户密钥
|
||||
// @Description 获取当前用户的API访问密钥信息
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.UserSecretsResponse "获取用户密钥成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "密钥不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets [get]
|
||||
func (h *FinanceHandler) GetUserSecrets(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
query := &queries.GetUserSecretsQuery{UserID: userID}
|
||||
result, err := h.appService.GetUserSecrets(c.Request.Context(), query)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "获取用户密钥成功")
|
||||
}
|
||||
|
||||
// RegenerateAccessKey 重新生成访问密钥
|
||||
// @Summary 重新生成访问密钥
|
||||
// @Description 重新生成用户的API访问密钥
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.UserSecretsResponse "访问密钥重新生成成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "密钥不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets/regenerate [post]
|
||||
func (h *FinanceHandler) RegenerateAccessKey(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
cmd := &commands.RegenerateAccessKeyCommand{UserID: userID}
|
||||
result, err := h.appService.RegenerateAccessKey(c.Request.Context(), cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("重新生成访问密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, result, "访问密钥重新生成成功")
|
||||
}
|
||||
|
||||
// DeactivateUserSecrets 停用用户密钥
|
||||
// @Summary 停用用户密钥
|
||||
// @Description 停用用户的API访问密钥
|
||||
// @Tags 用户密钥管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} map[string]interface{} "用户密钥停用成功"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "密钥不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/finance/secrets/deactivate [post]
|
||||
func (h *FinanceHandler) DeactivateUserSecrets(c *gin.Context) {
|
||||
userID := c.GetString("user_id")
|
||||
if userID == "" {
|
||||
h.responseBuilder.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
cmd := &commands.DeactivateUserSecretsCommand{UserID: userID}
|
||||
err := h.appService.DeactivateUserSecrets(c.Request.Context(), cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("停用用户密钥失败",
|
||||
zap.String("user_id", userID),
|
||||
zap.Error(err),
|
||||
)
|
||||
h.responseBuilder.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.responseBuilder.Success(c, nil, "用户密钥停用成功")
|
||||
}
|
||||
224
internal/infrastructure/http/handlers/user_handler.go
Normal file
224
internal/infrastructure/http/handlers/user_handler.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"tyapi-server/internal/application/user"
|
||||
"tyapi-server/internal/application/user/dto/commands"
|
||||
"tyapi-server/internal/shared/interfaces"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
)
|
||||
|
||||
// UserHandler 用户HTTP处理器
|
||||
type UserHandler struct {
|
||||
appService user.UserApplicationService
|
||||
response interfaces.ResponseBuilder
|
||||
validator interfaces.RequestValidator
|
||||
logger *zap.Logger
|
||||
jwtAuth *middleware.JWTAuthMiddleware
|
||||
}
|
||||
|
||||
// NewUserHandler 创建用户处理器
|
||||
func NewUserHandler(
|
||||
appService user.UserApplicationService,
|
||||
response interfaces.ResponseBuilder,
|
||||
validator interfaces.RequestValidator,
|
||||
logger *zap.Logger,
|
||||
jwtAuth *middleware.JWTAuthMiddleware,
|
||||
) *UserHandler {
|
||||
return &UserHandler{
|
||||
appService: appService,
|
||||
response: response,
|
||||
validator: validator,
|
||||
logger: logger,
|
||||
jwtAuth: jwtAuth,
|
||||
}
|
||||
}
|
||||
|
||||
// SendCode 发送验证码
|
||||
// @Summary 发送短信验证码
|
||||
// @Description 向指定手机号发送验证码,支持注册、登录、修改密码等场景
|
||||
// @Tags 用户认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.SendCodeCommand true "发送验证码请求"
|
||||
// @Success 200 {object} map[string]interface{} "验证码发送成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 429 {object} map[string]interface{} "请求频率限制"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/send-code [post]
|
||||
func (h *UserHandler) SendCode(c *gin.Context) {
|
||||
var cmd commands.SendCodeCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
|
||||
if err := h.appService.SendCode(c.Request.Context(), &cmd, clientIP, userAgent); err != nil {
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, nil, "验证码发送成功")
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
// @Summary 用户注册
|
||||
// @Description 使用手机号、密码和验证码进行用户注册,需要确认密码
|
||||
// @Tags 用户认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.RegisterUserCommand true "用户注册请求"
|
||||
// @Success 201 {object} responses.RegisterUserResponse "注册成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
|
||||
// @Failure 409 {object} map[string]interface{} "手机号已存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/register [post]
|
||||
func (h *UserHandler) Register(c *gin.Context) {
|
||||
var cmd commands.RegisterUserCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.appService.Register(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("注册用户失败", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Created(c, resp, "用户注册成功")
|
||||
}
|
||||
|
||||
// LoginWithPassword 密码登录
|
||||
// @Summary 用户密码登录
|
||||
// @Description 使用手机号和密码进行用户登录,返回JWT令牌
|
||||
// @Tags 用户认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.LoginWithPasswordCommand true "密码登录请求"
|
||||
// @Success 200 {object} responses.LoginUserResponse "登录成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误"
|
||||
// @Failure 401 {object} map[string]interface{} "用户名或密码错误"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/login-password [post]
|
||||
func (h *UserHandler) LoginWithPassword(c *gin.Context) {
|
||||
var cmd commands.LoginWithPasswordCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.appService.LoginWithPassword(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("密码登录失败", zap.Error(err))
|
||||
h.response.Unauthorized(c, "用户名或密码错误")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, resp, "登录成功")
|
||||
}
|
||||
|
||||
// LoginWithSMS 短信验证码登录
|
||||
// @Summary 用户短信验证码登录
|
||||
// @Description 使用手机号和短信验证码进行用户登录,返回JWT令牌
|
||||
// @Tags 用户认证
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Param request body commands.LoginWithSMSCommand true "短信登录请求"
|
||||
// @Success 200 {object} responses.LoginUserResponse "登录成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
|
||||
// @Failure 401 {object} map[string]interface{} "认证失败"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/login-sms [post]
|
||||
func (h *UserHandler) LoginWithSMS(c *gin.Context) {
|
||||
var cmd commands.LoginWithSMSCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.appService.LoginWithSMS(c.Request.Context(), &cmd)
|
||||
if err != nil {
|
||||
h.logger.Error("短信登录失败", zap.Error(err))
|
||||
h.response.Unauthorized(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, resp, "登录成功")
|
||||
}
|
||||
|
||||
// GetProfile 获取当前用户信息
|
||||
// @Summary 获取当前用户信息
|
||||
// @Description 根据JWT令牌获取当前登录用户的详细信息
|
||||
// @Tags 用户管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Success 200 {object} responses.UserProfileResponse "用户信息"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 404 {object} map[string]interface{} "用户不存在"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/me [get]
|
||||
func (h *UserHandler) GetProfile(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.appService.GetUserProfile(c.Request.Context(), userID)
|
||||
if err != nil {
|
||||
h.logger.Error("获取用户资料失败", zap.Error(err))
|
||||
h.response.NotFound(c, "用户不存在")
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, resp, "获取用户资料成功")
|
||||
}
|
||||
|
||||
// ChangePassword 修改密码
|
||||
// @Summary 修改密码
|
||||
// @Description 使用旧密码、新密码确认和验证码修改当前用户的密码
|
||||
// @Tags 用户管理
|
||||
// @Accept json
|
||||
// @Produce json
|
||||
// @Security Bearer
|
||||
// @Param request body commands.ChangePasswordCommand true "修改密码请求"
|
||||
// @Success 200 {object} map[string]interface{} "密码修改成功"
|
||||
// @Failure 400 {object} map[string]interface{} "请求参数错误或验证码无效"
|
||||
// @Failure 401 {object} map[string]interface{} "未认证"
|
||||
// @Failure 500 {object} map[string]interface{} "服务器内部错误"
|
||||
// @Router /api/v1/users/me/password [put]
|
||||
func (h *UserHandler) ChangePassword(c *gin.Context) {
|
||||
userID := h.getCurrentUserID(c)
|
||||
if userID == "" {
|
||||
h.response.Unauthorized(c, "用户未认证")
|
||||
return
|
||||
}
|
||||
|
||||
var cmd commands.ChangePasswordCommand
|
||||
if err := h.validator.BindAndValidate(c, &cmd); err != nil {
|
||||
return
|
||||
}
|
||||
cmd.UserID = userID
|
||||
|
||||
if err := h.appService.ChangePassword(c.Request.Context(), &cmd); err != nil {
|
||||
h.logger.Error("修改密码失败", zap.Error(err))
|
||||
h.response.BadRequest(c, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
h.response.Success(c, nil, "密码修改成功")
|
||||
}
|
||||
|
||||
// getCurrentUserID 获取当前用户ID
|
||||
func (h *UserHandler) getCurrentUserID(c *gin.Context) string {
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if id, ok := userID.(string); ok {
|
||||
return id
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
58
internal/infrastructure/http/routes/admin_routes.go
Normal file
58
internal/infrastructure/http/routes/admin_routes.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/infrastructure/http/handlers"
|
||||
sharedhttp "tyapi-server/internal/shared/http"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// AdminRoutes 管理员路由注册器
|
||||
type AdminRoutes struct {
|
||||
handler *handlers.AdminHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewAdminRoutes 创建管理员路由注册器
|
||||
func NewAdminRoutes(
|
||||
handler *handlers.AdminHandler,
|
||||
authMiddleware *middleware.JWTAuthMiddleware,
|
||||
logger *zap.Logger,
|
||||
) *AdminRoutes {
|
||||
return &AdminRoutes{
|
||||
handler: handler,
|
||||
authMiddleware: authMiddleware,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册管理员相关路由
|
||||
func (r *AdminRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// 管理员路由组
|
||||
engine := router.GetEngine()
|
||||
adminGroup := engine.Group("/api/v1/admin")
|
||||
{
|
||||
// 认证相关路由(无需认证)
|
||||
authGroup := adminGroup.Group("/auth")
|
||||
{
|
||||
authGroup.POST("/login", r.handler.Login)
|
||||
}
|
||||
|
||||
// 管理员管理路由(需要认证)
|
||||
authenticated := adminGroup.Group("")
|
||||
authenticated.Use(r.authMiddleware.Handle())
|
||||
{
|
||||
authenticated.POST("", r.handler.CreateAdmin) // 创建管理员
|
||||
authenticated.GET("", r.handler.ListAdmins) // 获取管理员列表
|
||||
authenticated.GET("/stats", r.handler.GetAdminStats) // 获取统计信息
|
||||
authenticated.GET("/:id", r.handler.GetAdminByID) // 获取管理员详情
|
||||
authenticated.PUT("/:id", r.handler.UpdateAdmin) // 更新管理员
|
||||
authenticated.DELETE("/:id", r.handler.DeleteAdmin) // 删除管理员
|
||||
authenticated.POST("/change-password", r.handler.ChangePassword) // 修改密码
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("管理员路由注册完成")
|
||||
}
|
||||
73
internal/infrastructure/http/routes/certification_routes.go
Normal file
73
internal/infrastructure/http/routes/certification_routes.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/infrastructure/http/handlers"
|
||||
sharedhttp "tyapi-server/internal/shared/http"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// CertificationRoutes 认证路由注册器
|
||||
type CertificationRoutes struct {
|
||||
certificationHandler *handlers.CertificationHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewCertificationRoutes 创建认证路由注册器
|
||||
func NewCertificationRoutes(
|
||||
certificationHandler *handlers.CertificationHandler,
|
||||
authMiddleware *middleware.JWTAuthMiddleware,
|
||||
logger *zap.Logger,
|
||||
) *CertificationRoutes {
|
||||
return &CertificationRoutes{
|
||||
certificationHandler: certificationHandler,
|
||||
authMiddleware: authMiddleware,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册认证相关路由
|
||||
func (r *CertificationRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// 认证相关路由组,需要用户认证
|
||||
engine := router.GetEngine()
|
||||
certificationGroup := engine.Group("/api/v1/certification")
|
||||
certificationGroup.Use(r.authMiddleware.Handle())
|
||||
{
|
||||
// 创建认证申请
|
||||
certificationGroup.POST("", r.certificationHandler.CreateCertification)
|
||||
|
||||
// 营业执照上传
|
||||
certificationGroup.POST("/upload-license", r.certificationHandler.UploadBusinessLicense)
|
||||
|
||||
// 获取OCR识别结果
|
||||
certificationGroup.GET("/license/:record_id/ocr-result", r.certificationHandler.GetLicenseOCRResult)
|
||||
|
||||
// 获取认证状态
|
||||
certificationGroup.GET("/status", r.certificationHandler.GetCertificationStatus)
|
||||
|
||||
// 获取进度统计
|
||||
certificationGroup.GET("/stats", r.certificationHandler.GetProgressStats)
|
||||
|
||||
// 获取认证进度
|
||||
certificationGroup.GET("/progress", r.certificationHandler.GetCertificationProgress)
|
||||
|
||||
// 提交企业信息
|
||||
certificationGroup.POST("/enterprise-info", r.certificationHandler.SubmitEnterpriseInfo)
|
||||
|
||||
// 发起人脸识别验证
|
||||
certificationGroup.POST("/face-verify", r.certificationHandler.InitiateFaceVerify)
|
||||
|
||||
// 申请合同签署
|
||||
certificationGroup.POST("/contract", r.certificationHandler.ApplyContract)
|
||||
|
||||
// 获取认证详情
|
||||
certificationGroup.GET("/details", r.certificationHandler.GetCertificationDetails)
|
||||
|
||||
// 重试认证步骤
|
||||
certificationGroup.POST("/retry/:step", r.certificationHandler.RetryStep)
|
||||
}
|
||||
|
||||
r.logger.Info("认证路由注册完成")
|
||||
}
|
||||
61
internal/infrastructure/http/routes/finance_routes.go
Normal file
61
internal/infrastructure/http/routes/finance_routes.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/infrastructure/http/handlers"
|
||||
sharedhttp "tyapi-server/internal/shared/http"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// FinanceRoutes 财务路由注册器
|
||||
type FinanceRoutes struct {
|
||||
financeHandler *handlers.FinanceHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewFinanceRoutes 创建财务路由注册器
|
||||
func NewFinanceRoutes(
|
||||
financeHandler *handlers.FinanceHandler,
|
||||
authMiddleware *middleware.JWTAuthMiddleware,
|
||||
logger *zap.Logger,
|
||||
) *FinanceRoutes {
|
||||
return &FinanceRoutes{
|
||||
financeHandler: financeHandler,
|
||||
authMiddleware: authMiddleware,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册财务相关路由
|
||||
func (r *FinanceRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// 财务路由组,需要用户认证
|
||||
engine := router.GetEngine()
|
||||
financeGroup := engine.Group("/api/v1/finance")
|
||||
financeGroup.Use(r.authMiddleware.Handle())
|
||||
{
|
||||
// 钱包相关路由
|
||||
walletGroup := financeGroup.Group("/wallet")
|
||||
{
|
||||
walletGroup.POST("", r.financeHandler.CreateWallet) // 创建钱包
|
||||
walletGroup.GET("", r.financeHandler.GetWallet) // 获取钱包信息
|
||||
walletGroup.PUT("", r.financeHandler.UpdateWallet) // 更新钱包
|
||||
walletGroup.POST("/recharge", r.financeHandler.Recharge) // 充值
|
||||
walletGroup.POST("/withdraw", r.financeHandler.Withdraw) // 提现
|
||||
walletGroup.POST("/transaction", r.financeHandler.WalletTransaction) // 钱包交易
|
||||
walletGroup.GET("/stats", r.financeHandler.GetWalletStats) // 获取钱包统计
|
||||
}
|
||||
|
||||
// 用户密钥相关路由
|
||||
secretsGroup := financeGroup.Group("/secrets")
|
||||
{
|
||||
secretsGroup.POST("", r.financeHandler.CreateUserSecrets) // 创建用户密钥
|
||||
secretsGroup.GET("", r.financeHandler.GetUserSecrets) // 获取用户密钥
|
||||
secretsGroup.POST("/regenerate", r.financeHandler.RegenerateAccessKey) // 重新生成访问密钥
|
||||
secretsGroup.POST("/deactivate", r.financeHandler.DeactivateUserSecrets) // 停用用户密钥
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("财务路由注册完成")
|
||||
}
|
||||
53
internal/infrastructure/http/routes/user_routes.go
Normal file
53
internal/infrastructure/http/routes/user_routes.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"tyapi-server/internal/infrastructure/http/handlers"
|
||||
sharedhttp "tyapi-server/internal/shared/http"
|
||||
"tyapi-server/internal/shared/middleware"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// UserRoutes 用户路由注册器
|
||||
type UserRoutes struct {
|
||||
handler *handlers.UserHandler
|
||||
authMiddleware *middleware.JWTAuthMiddleware
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
// NewUserRoutes 创建用户路由注册器
|
||||
func NewUserRoutes(
|
||||
handler *handlers.UserHandler,
|
||||
authMiddleware *middleware.JWTAuthMiddleware,
|
||||
logger *zap.Logger,
|
||||
) *UserRoutes {
|
||||
return &UserRoutes{
|
||||
handler: handler,
|
||||
authMiddleware: authMiddleware,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 注册用户相关路由
|
||||
func (r *UserRoutes) Register(router *sharedhttp.GinRouter) {
|
||||
// 用户域路由组
|
||||
engine := router.GetEngine()
|
||||
usersGroup := engine.Group("/api/v1/users")
|
||||
{
|
||||
// 公开路由(不需要认证)
|
||||
usersGroup.POST("/send-code", r.handler.SendCode) // 发送验证码
|
||||
usersGroup.POST("/register", r.handler.Register) // 用户注册
|
||||
usersGroup.POST("/login-password", r.handler.LoginWithPassword) // 密码登录
|
||||
usersGroup.POST("/login-sms", r.handler.LoginWithSMS) // 短信验证码登录
|
||||
|
||||
// 需要认证的路由
|
||||
authenticated := usersGroup.Group("")
|
||||
authenticated.Use(r.authMiddleware.Handle())
|
||||
{
|
||||
authenticated.GET("/me", r.handler.GetProfile) // 获取当前用户信息
|
||||
authenticated.PUT("/me/password", r.handler.ChangePassword) // 修改密码
|
||||
}
|
||||
}
|
||||
|
||||
r.logger.Info("用户路由注册完成")
|
||||
}
|
||||
Reference in New Issue
Block a user